using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;
using Pfz.Caching;
using Pfz.Extensions;
using Pfz.Serialization;
using Pfz.Threading;
namespace Pfz.Remoting
{
/// <summary>
/// Class responsible for creating many channels inside another stream.
/// This is used by the remoting framework, so each thread has its own
/// channel inside a single tcp/ip connection.
/// </summary>
public class StreamChanneller:
ThreadSafeExceptionAwareDisposable,
IChanneller,
IGarbageCollectionAware
{
#region Private and internal fields
// Underlying transport shared by all channels; read by _Read and written by the _Writer thread.
private Stream _stream;
// Live channels keyed by their local id. Start() registers the main channel under id 0.
private Dictionary<int, StreamChannel> _channels = new Dictionary<int, StreamChannel>();
// Events keyed by local channel id. CreateChannel waits on its entry until
// ChannelAssociated.Run sets it.
private Dictionary<int, ManagedManualResetEvent> _awaitingChannels = new Dictionary<int, ManagedManualResetEvent>();
// Serializer used to write control messages (created/associated/removed) to the main channel.
private readonly BinarySerializer _mainChannelSerializer = _CreateSerializer();
// Control channel created by Start(); null before Start() is called.
private StreamChannel _mainChannel;
// Buffer size given to the constructor (at least 256). _Reader and _Discard never
// read more than this many bytes in a single chunk.
internal int _channelBufferSize;
// Outgoing buffers, dequeued and written to the stream by the _Writer thread.
internal readonly Queue<byte[]> _buffersToSend = new Queue<byte[]>();
// Last local channel id handed out; advanced with Interlocked.Increment.
private int _nextChannelId;
#endregion
#region Constructors
/// <summary>
/// Initializes a channeller that multiplexes channels over the given stream.
/// For a tcp/ip stream, pass the larger of the receive and send buffer
/// sizes as the per-channel buffer size. No threads are started here;
/// call <see cref="Start"/> to begin processing.
/// </summary>
/// <param name="stream">The stream that will carry all channels. Can't be null.</param>
/// <param name="bufferSizePerChannel">Buffer size used per channel; must be at least 256.</param>
/// <param name="localEndpoint">Value exposed through <see cref="LocalEndpoint"/>.</param>
/// <param name="remoteEndpoint">Value exposed through <see cref="RemoteEndpoint"/>.</param>
/// <exception cref="ArgumentNullException">stream is null.</exception>
/// <exception cref="ArgumentException">bufferSizePerChannel is less than 256.</exception>
public StreamChanneller(Stream stream, int bufferSizePerChannel, string localEndpoint, string remoteEndpoint)
{
    const int minimumBufferSize = 256;

    if (stream == null)
    {
        throw new ArgumentNullException("stream");
    }

    if (bufferSizePerChannel < minimumBufferSize)
    {
        throw new ArgumentException("bufferSizePerChannel can't be less than 256 bytes", "bufferSizePerChannel");
    }

    LocalEndpoint = localEndpoint;
    RemoteEndpoint = remoteEndpoint;
    _channelBufferSize = bufferSizePerChannel;
    _stream = stream;
}
// Set by Start() so that a second call is rejected.
private bool _started;
/// <summary>
/// Starts this channeller: registers the main channel under id 0, launches
/// the reader, writer and main-channel background threads (in that order)
/// and registers this instance for garbage-collection notifications.
/// </summary>
/// <exception cref="RemotingException">The channeller was already started.</exception>
public void Start()
{
    CheckUndisposed();

    if (_started)
    {
        throw new RemotingException("This channeller is already started.");
    }

    _started = true;

    _mainChannel = new StreamChannel(this);
    _channels.Add(0, _mainChannel);

    var reader = new Thread(_Reader) { IsBackground = true, Name = "StreamChanneller reader." };
    reader.Start();

    var writer = new Thread(_Writer) { IsBackground = true, Name = "StreamChanneller writer." };
    writer.Start();

    var mainChannelWorker = new Thread(_MainChannel) { IsBackground = true, Name = "StreamChanneller main channel." };
    mainChannelWorker.Start();

    GCUtils.RegisterForCollectedNotification(this);
}
#endregion
#region Dispose
/// <summary>
/// Disposes the channeller and the stream. When called with
/// <paramref name="disposing"/> true it also disposes every registered
/// channel and every pending channel-association event, and raises
/// <see cref="Disposed"/> after the base class has finished disposing.
/// </summary>
/// <param name="disposing">true if called from Dispose() and false if called from destructor.</param>
// NOTE(review): the suppression names "_writerEvent", but no such field is visible
// in this file - possibly stale; confirm before removing.
[SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", MessageId = "_writerEvent")]
protected override void Dispose(bool disposing)
{
if (disposing)
{
// Stop receiving IGarbageCollectionAware.OnCollected callbacks.
GCUtils.UnregisterFromCollectedNotification(this);
// The field is passed by ref; presumably Disposer.Dispose also sets it to null
// (both _Read and _Writer check _stream for null) - confirm against Disposer.
Disposer.Dispose(ref _stream);
// Wake the writer thread blocked in Monitor.Wait(DisposeLock) so it can see
// that the channeller was disposed and exit.
lock(DisposeLock)
Monitor.Pulse(DisposeLock);
// Detach the dictionary first, then dispose every channel with the exception
// (if any) that caused this disposal.
var channels = _channels;
if (channels != null)
{
_channels = null;
foreach(StreamChannel channel in channels.Values)
channel.Dispose(DisposeException);
}
// Dispose the events CreateChannel may still be waiting on.
// NOTE(review): how a blocked WaitOne reacts to the event being disposed
// depends on ManagedManualResetEvent, which is not visible here.
var awaitingChannels = _awaitingChannels;
if (awaitingChannels != null)
{
_awaitingChannels = null;
foreach(var mre in awaitingChannels.Values)
mre.Dispose();
}
}
base.Dispose(disposing);
if (disposing)
{
// Copy the delegate to a local before the null check and invocation.
var disposedHandler = Disposed;
if (disposedHandler != null)
disposedHandler(this, EventArgs.Empty);
}
}
#endregion
#region _Collected
/// <summary>
/// Garbage-collection callback. While the channeller is alive it copies
/// both channel dictionaries into fresh instances and trims the send
/// queue (presumably to release unused capacity); once the channeller
/// was disposed it only unregisters itself from further notifications.
/// </summary>
void IGarbageCollectionAware.OnCollected()
{
    try
    {
        lock (DisposeLock)
        {
            if (!WasDisposed)
            {
                var freshAwaiting = new Dictionary<int, ManagedManualResetEvent>(_awaitingChannels);
                _awaitingChannels = freshAwaiting;

                var freshChannels = new Dictionary<int, StreamChannel>(_channels);
                _channels = freshChannels;

                _buffersToSend.TrimExcess();
            }
            else
            {
                GCUtils.UnregisterFromCollectedNotification(this);
            }
        }
    }
    catch (OutOfMemoryException)
    {
        // Deliberately ignored: a field is only replaced after its copy was
        // built successfully, so the existing collections stay intact when
        // there is no memory for the copy.
    }
}
#endregion
#region Properties
/// <summary>
/// Gets the LocalEndpoint. This is the localEndpoint value given to the
/// constructor; it is not changed afterwards by this class.
/// </summary>
public string LocalEndpoint { get; private set; }
/// <summary>
/// Gets the RemoteEndpoint. This is the remoteEndpoint value given to the
/// constructor; it is not changed afterwards by this class.
/// </summary>
public string RemoteEndpoint { get; private set; }
#endregion
#region Methods
#region CreateChannel
/// <summary>
/// Creates a channel, sending the serializableData parameter to the
/// other side, so it can decide what to do with this channel before it
/// gets used (this avoids an extra tcp/ip packet for small information).
/// The call blocks until the event registered for the new channel is
/// set, which ChannelAssociated.Run does when the other side answers.
/// If anything throws, the whole channeller is disposed with that
/// exception and the exception is rethrown.
/// </summary>
/// <param name="serializableData">Data to send to the other side.</param>
/// <returns>A new channel.</returns>
public StreamChannel CreateChannel(object serializableData = null)
{
try
{
// Allocate the local id and build both the channel and the control message
// before taking the lock.
int channelId = Interlocked.Increment(ref _nextChannelId);
StreamChannel channel = new StreamChannel(this);
channel._id = channelId;
ClassChannelCreated channelCreated = new ClassChannelCreated();
channelCreated.SenderChannelId = channelId;
channelCreated.Data = serializableData;
using(var manualResetEvent = new ManagedManualResetEvent())
{
// Tracks whether this thread currently owns DisposeLock, because the lock
// is released and re-acquired by hand below.
bool lockTaken = false;
try
{
Monitor.Enter(DisposeLock, ref lockTaken);
CheckUndisposed();
// Register the channel and its wait event while holding the lock.
_channels.Add(channelId, channel);
_awaitingChannels.Add(channelId, manualResetEvent);
try
{
// Announce the new channel to the other side through the main channel.
_mainChannelSerializer.Serialize(_mainChannel, channelCreated);
_mainChannel.Flush();
// NOTE(review): the empty try with the work in finally presumably keeps
// Monitor.Exit and the lockTaken reset together so that an asynchronous
// abort cannot separate them - confirm.
try
{
}
finally
{
Monitor.Exit(DisposeLock);
lockTaken = false;
}
// Wait without holding the lock: ChannelAssociated.Run needs DisposeLock
// to set this event.
manualResetEvent.WaitOne();
Monitor.Enter(DisposeLock, ref lockTaken);
CheckUndisposed();
}
finally
{
// Make sure the lock is held before touching _awaitingChannels, whichever
// path led here.
if (!lockTaken)
{
Monitor.Enter(DisposeLock);
lockTaken = true;
}
// Dispose(bool) sets _awaitingChannels to null, so the entry is only
// removed while the channeller is still alive.
if (!WasDisposed)
_awaitingChannels.Remove(channelId);
}
}
finally
{
if (lockTaken)
Monitor.Exit(DisposeLock);
}
}
return channel;
}
catch(Exception exception)
{
// Any failure tears down the whole channeller before being rethrown.
Dispose(exception);
throw;
}
}
#endregion
#region _RemoveChannel
/// <summary>
/// Unregisters the local channel <paramref name="id"/> and then tells the
/// other side, on a best-effort basis, that its channel
/// <paramref name="remoteId"/> was removed.
/// </summary>
/// <param name="id">Local id of the channel to unregister.</param>
/// <param name="remoteId">Id of the matching channel on the other side.</param>
internal void _RemoveChannel(int id, int remoteId)
{
    // _channels is read by the reader thread, replaced by OnCollected and set
    // to null by Dispose; the first two happen under DisposeLock, so the
    // removal must use the same lock instead of mutating the dictionary
    // unsynchronized.
    // NOTE(review): callers are assumed not to rely on this method being
    // lock-free - confirm against StreamChannel's dispose path.
    lock (DisposeLock)
    {
        var channels = _channels;
        if (channels == null)
        {
            // The channeller was already torn down: nothing is registered
            // any more and there is no main channel left to notify through.
            return;
        }

        channels.Remove(id);
    }

    ChannelRemoved channelRemoved = new ChannelRemoved();
    channelRemoved.ReceiverChannelId = remoteId;

    try
    {
        _mainChannelSerializer.Serialize(_mainChannel, channelRemoved);
        _mainChannel.Flush();
    }
    catch
    {
        // Deliberate best effort: the local channel is already unregistered,
        // and failing to notify the other side must not fail the caller.
    }
}
#endregion
#region _Reader
/// <summary>
/// Body of the reader thread. Repeatedly reads an 8-byte header (channel id
/// followed by message size, both 32-bit integers), then reads the message
/// in chunks of at most _channelBufferSize bytes and enqueues each chunk on
/// the target channel. Messages for unknown channels are discarded. Any
/// exception disposes the whole channeller.
/// </summary>
private void _Reader()
{
    try
    {
        byte[] headerBuffer = new byte[8];
        while (true)
        {
            _Read(headerBuffer, 8);
            int channelId = BitConverter.ToInt32(headerBuffer, 0);
            int messageSize = BitConverter.ToInt32(headerBuffer, 4);

            StreamChannel channel = null;
            lock (DisposeLock)
            {
                if (WasDisposed)
                    return;

                _channels.TryGetValue(channelId, out channel);
            }

            if (channel == null)
            {
                // Unknown channel: skip its payload to stay aligned with the
                // next header.
                _Discard(messageSize);
                continue;
            }

            int bytesLeft = messageSize;
            while (bytesLeft > 0)
            {
                // Once disposed there is nothing left to deliver to, so stop
                // the thread instead of going back for another header.
                if (WasDisposed)
                    return;

                int count = Math.Min(bytesLeft, _channelBufferSize);

                byte[] messageBuffer;
                try
                {
                    messageBuffer = new byte[count];
                }
                catch (Exception exception)
                {
                    channel.Dispose(exception);

                    // The rest of this message must still be consumed or the
                    // next header would be read from the middle of it.
                    // Retrying the allocation (as a plain "continue" would)
                    // can spin forever, so drain through the already
                    // allocated header buffer, which needs no new memory and
                    // is overwritten by the next header read anyway.
                    while (bytesLeft > 0)
                    {
                        int drain = Math.Min(bytesLeft, headerBuffer.Length);
                        _Read(headerBuffer, drain);
                        bytesLeft -= drain;
                    }

                    break;
                }

                _Read(messageBuffer, count);
                bytesLeft -= count;

                lock (channel.DisposeLock)
                {
                    // Chunks for a channel disposed in the meantime are
                    // still read from the stream, but dropped.
                    if (channel.WasDisposed)
                        continue;

                    try
                    {
                        channel._inMessages.Enqueue(messageBuffer);
                    }
                    catch (Exception exception)
                    {
                        channel.Dispose(exception);
                        continue;
                    }

                    // Wake a thread waiting on the channel lock.
                    Monitor.Pulse(channel.DisposeLock);
                }
            }
        }
    }
    catch (Exception exception)
    {
        Dispose(exception);
    }
}
#endregion
#region _Read
/// <summary>
/// Reads exactly <paramref name="count"/> bytes from the underlying stream
/// into the start of <paramref name="buffer"/>, looping over partial reads.
/// </summary>
/// <param name="buffer">Destination buffer; must hold at least count bytes.</param>
/// <param name="count">Number of bytes to read.</param>
/// <exception cref="IOException">
/// The stream field is null, or the stream ended before count bytes were
/// read. In both cases the channeller is disposed with the same exception
/// before it is thrown.
/// </exception>
private void _Read(byte[] buffer, int count)
{
    // Take one snapshot of the field and use only the snapshot below.
    // Dispose passes _stream by ref to Disposer.Dispose (which presumably
    // clears it) from another thread, so reading the field again after the
    // null check could hit a null reference.
    var stream = _stream;
    if (stream == null)
    {
        var exception = new IOException("Stream closed.");
        Dispose(exception);
        throw exception;
    }

    int totalRead = 0;
    while (totalRead < count)
    {
        int read = stream.Read(buffer, totalRead, count - totalRead);
        if (read == 0)
        {
            // Zero bytes means the end of the stream was reached.
            var exception = new IOException("Stream closed.");
            Dispose(exception);
            throw exception;
        }

        totalRead += read;
    }
}
#endregion
#region _Discard
/// <summary>
/// Reads and throws away <paramref name="bytesToDiscard"/> bytes from the
/// stream, using a scratch buffer no larger than _channelBufferSize.
/// </summary>
/// <param name="bytesToDiscard">Number of bytes to skip.</param>
private void _Discard(int bytesToDiscard)
{
    int chunkSize = Math.Min(bytesToDiscard, _channelBufferSize);
    byte[] scratch = new byte[chunkSize];

    // Full chunks first; the final iteration reads whatever is left.
    for (int remaining = bytesToDiscard; remaining > 0; remaining -= chunkSize)
    {
        int toRead = remaining < chunkSize ? remaining : chunkSize;
        _Read(scratch, toRead);
    }
}
#endregion
#region _Writer
/// <summary>
/// Body of the writer thread. While holding DisposeLock it writes every
/// queued buffer to the stream; when the queue is empty it flushes the
/// stream and waits on DisposeLock until pulsed. It exits when the
/// channeller was disposed, and any exception disposes the channeller.
/// </summary>
private void _Writer()
{
    var output = _stream;
    if (output == null)
        return;

    try
    {
        lock (DisposeLock)
        {
            if (WasDisposed)
                return;

            for (;;)
            {
                if (_buffersToSend.Count != 0)
                {
                    byte[] pending = _buffersToSend.Dequeue();
                    output.Write(pending, 0, pending.Length);
                    continue;
                }

                // Nothing queued: push out what was written so far, then
                // sleep (Monitor.Wait releases the lock) until pulsed.
                output.Flush();
                Monitor.Wait(DisposeLock);

                if (WasDisposed)
                    return;
            }
        }
    }
    catch (Exception exception)
    {
        Dispose(exception);
    }
}
#endregion
#region _MainChannel
/// <summary>
/// Body of the main-channel thread. Deserializes control messages from the
/// main channel with its own serializer instance and runs each one against
/// this channeller. Any exception disposes the channeller, which also ends
/// the loop.
/// </summary>
private void _MainChannel()
{
    var controlChannel = _mainChannel;
    if (controlChannel != null)
    {
        try
        {
            var deserializer = _CreateSerializer();
            for (;;)
            {
                var action = (IChannelAction)deserializer.Deserialize(controlChannel);
                action.Run(this);
            }
        }
        catch (Exception exception)
        {
            Dispose(exception);
        }
    }
}
#endregion
#region _CreateSerializer
/// <summary>
/// Builds a serializer that knows the three control messages exchanged on
/// the main channel. The item serializers are registered in a fixed order:
/// channel created, channel associated, channel removed.
/// </summary>
/// <returns>A new, fully registered serializer.</returns>
private static BinarySerializer _CreateSerializer()
{
    BinarySerializer result = new BinarySerializer();

    result.Register(ClassChannelCreatedSerializer.Instance);
    result.Register(ChannelAssociatedSerializer.Instance);
    result.Register(ChannelRemovedSerializer.Instance);

    return result;
}
#endregion
#endregion
#region Events
/// <summary>
/// Event called when Dispose() has just finished. It is raised from
/// Dispose(bool) after the base class disposal, with this channeller as
/// sender and EventArgs.Empty as arguments.
/// </summary>
public event EventHandler Disposed;
/// <summary>
/// Event that is invoked when the remote side creates a new channel.
/// It is raised through UnlimitedThreadPool, and the channel carried by
/// the event arguments is disposed as soon as the handler returns.
/// </summary>
public event EventHandler<ChannelCreatedEventArgs> ChannelCreated;
#endregion
#region Nested classes
// A control message received on the main channel. The main-channel thread
// casts every deserialized object to this interface and calls Run on it.
private interface IChannelAction
{
// Applies the message to the channeller that received it.
void Run(StreamChanneller channeller);
}
// Control message sent by CreateChannel: the other side opened a channel and
// asks this side to create the matching one.
private sealed class ClassChannelCreated:
    IChannelAction
{
    // Id the sending side assigned to its channel.
    internal int SenderChannelId;
    // Optional user data that accompanies the channel creation.
    internal object Data;

    /// <summary>
    /// Creates the local counterpart of the remote channel, registers it,
    /// answers with a ChannelAssociated message and then raises
    /// ChannelCreated through UnlimitedThreadPool. The channel is disposed
    /// when the handler returns.
    /// </summary>
    /// <param name="channeller">The channeller that received this message.</param>
    public void Run(StreamChanneller channeller)
    {
        int localChannelId = Interlocked.Increment(ref channeller._nextChannelId);
        StreamChannel channel = new StreamChannel(channeller);
        channel._id = localChannelId;
        channel._remoteId = SenderChannelId;

        ChannelAssociated associated = new ChannelAssociated();
        associated.SenderChannelId = localChannelId;
        associated.ReceiverChannelId = SenderChannelId;

        lock(channeller.DisposeLock)
        {
            if (channeller.WasDisposed)
                return;

            channeller._channels.Add(localChannelId, channel);
            channeller._mainChannelSerializer.Serialize(channeller._mainChannel, associated);
            channeller._mainChannel.Flush();
        }

        ChannelCreatedEventArgs args = new ChannelCreatedEventArgs();
        args.Channel = channel;
        args.Data = Data;

        UnlimitedThreadPool.Run
        (
            (args2) =>
            {
                using(args2.Channel)
                {
                    // Copy the delegate first: invoking the event directly
                    // throws NullReferenceException on the pool thread when
                    // nobody subscribed. Without a handler the channel is
                    // simply disposed.
                    var handler = channeller.ChannelCreated;
                    if (handler != null)
                    {
                        // The sender is the channeller that owns the event,
                        // not this private message object.
                        handler(channeller, args2);
                    }
                }
            },
            args
        );
    }
}
// Wire format of ClassChannelCreated: the user data first, followed by the
// sender channel id as a compressed 32-bit integer.
private sealed class ClassChannelCreatedSerializer:
    ItemSerializer<ClassChannelCreated>
{
    // Single shared instance registered by _CreateSerializer.
    internal static readonly ClassChannelCreatedSerializer Instance = new ClassChannelCreatedSerializer();

    // Writes Data, then SenderChannelId.
    public override void Serialize(ConfigurableSerializerBase serializer, ClassChannelCreated item)
    {
        serializer.InnerSerialize(item.Data);

        var output = serializer.Stream;
        output.WriteCompressedInt32(item.SenderChannelId);
    }

    // Reads the fields back in the order Serialize wrote them.
    public override ClassChannelCreated Deserialize(ConfigurableSerializerBase deserializer)
    {
        return new ClassChannelCreated
        {
            Data = deserializer.InnerDeserialize(),
            SenderChannelId = deserializer.Stream.ReadCompressedInt32()
        };
    }
}
// Control message sent by _RemoveChannel: the other side removed a channel
// and names the id of the matching channel on this side.
private sealed class ChannelRemoved:
    IChannelAction
{
    // Local id (on the receiving side) of the channel to dispose.
    internal int ReceiverChannelId;

    /// <summary>
    /// Looks up the named channel under the channeller lock and, if it is
    /// still registered, begins disposing it outside the lock. Does nothing
    /// when the channeller was already disposed.
    /// </summary>
    /// <param name="channeller">The channeller that received this message.</param>
    public void Run(StreamChanneller channeller)
    {
        StreamChannel target = null;

        lock (channeller.DisposeLock)
        {
            if (!channeller.WasDisposed)
                channeller._channels.TryGetValue(ReceiverChannelId, out target);
        }

        if (target == null)
            return;

        target._BeginDispose();
    }
}
// Wire format of ChannelRemoved: the receiver channel id as a compressed
// 32-bit integer.
private sealed class ChannelRemovedSerializer:
    ItemSerializer<ChannelRemoved>
{
    // Single shared instance registered by _CreateSerializer.
    internal static readonly ChannelRemovedSerializer Instance = new ChannelRemovedSerializer();

    // Writes ReceiverChannelId.
    public override void Serialize(ConfigurableSerializerBase serializer, ChannelRemoved item)
    {
        var output = serializer.Stream;
        output.WriteCompressedInt32(item.ReceiverChannelId);
    }

    // Reads ReceiverChannelId back.
    public override ChannelRemoved Deserialize(ConfigurableSerializerBase deserializer)
    {
        return new ChannelRemoved
        {
            ReceiverChannelId = deserializer.Stream.ReadCompressedInt32()
        };
    }
}
// Control message sent in answer to ClassChannelCreated: tells the side that
// opened a channel which id the other side gave to the matching channel.
private sealed class ChannelAssociated:
    IChannelAction
{
    // Local id (on the receiving side) of the channel awaiting association.
    internal int ReceiverChannelId;
    // Id the answering side assigned to its channel.
    internal int SenderChannelId;

    /// <summary>
    /// Stores the remote id on the local channel and sets the event that
    /// CreateChannel is waiting on. Does nothing when the channeller was
    /// already disposed.
    /// </summary>
    /// <param name="channeller">The channeller that received this message.</param>
    public void Run(StreamChanneller channeller)
    {
        lock (channeller.DisposeLock)
        {
            if (!channeller.WasDisposed)
            {
                StreamChannel local = channeller._channels[ReceiverChannelId];
                local._remoteId = SenderChannelId;

                ManagedManualResetEvent waiter = channeller._awaitingChannels[local._id];
                waiter.Set();
            }
        }
    }
}
// Wire format of ChannelAssociated: receiver channel id followed by sender
// channel id, both as compressed 32-bit integers.
private sealed class ChannelAssociatedSerializer:
    ItemSerializer<ChannelAssociated>
{
    // Single shared instance registered by _CreateSerializer.
    internal static readonly ChannelAssociatedSerializer Instance = new ChannelAssociatedSerializer();

    // Writes ReceiverChannelId, then SenderChannelId.
    public override void Serialize(ConfigurableSerializerBase serializer, ChannelAssociated item)
    {
        var output = serializer.Stream;

        output.WriteCompressedInt32(item.ReceiverChannelId);
        output.WriteCompressedInt32(item.SenderChannelId);
    }

    // Reads the two ids back in the order Serialize wrote them.
    public override ChannelAssociated Deserialize(ConfigurableSerializerBase deserializer)
    {
        var input = deserializer.Stream;

        return new ChannelAssociated
        {
            ReceiverChannelId = input.ReadCompressedInt32(),
            SenderChannelId = input.ReadCompressedInt32()
        };
    }
}
#endregion
#region IChanneller Members
// Explicit IChanneller implementation: forwards to the public CreateChannel
// and returns its StreamChannel result as an IChannel.
IChannel IChanneller.CreateChannel(object createData)
{
return CreateChannel(createData);
}
#endregion
}
}