using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;
using System.Threading;
using Pfz.Caching;
using Pfz.Extensions.DisposeExtensions;
using Pfz.Extensions.MonitorLockExtensions;
using Pfz.Threading;
namespace Pfz.Remoting
{
    /// <summary>
    /// Class responsible for creating many channels inside another stream.
    /// This is used by the remoting framework, so each thread has its own
    /// channel inside a single tcp/ip connection.
    /// </summary>
    /// <remarks>
    /// Channel 0 is the "main" channel. It carries the control messages
    /// (<see cref="ChannelCreated"/>, <see cref="ChannelAssociated"/> and
    /// <see cref="ChannelRemoved"/>), serialized with <see cref="BinaryFormatter"/>.
    /// Every message read from the stream starts with an 8-byte header: the
    /// local channel id (Int32) followed by the payload size (Int32), both
    /// decoded with <see cref="BitConverter"/>.
    /// SECURITY NOTE(review): BinaryFormatter deserializes whatever the remote
    /// side sends, including the arbitrary <c>Data</c> object of a channel
    /// creation request. That is only acceptable when the peer is fully
    /// trusted. BinaryFormatter is obsolete (removed in .NET 9); replacing it
    /// changes the wire protocol, so it must be done on both sides at once.
    /// </remarks>
    public sealed class StreamChanneller:
        ThreadSafeExceptionAwareDisposable,
        IChanneller
    {
        #region Private and internal fields
        // The multiplexed stream. Set to null by Dispose.
        private Stream fStream;

        // Guards fChannels (including replacing or nulling the reference).
        private object fChannelsLock = new object();

        // Local channel id -> channel. Null after Dispose.
        private volatile Dictionary<int, Channel> fChannels = new Dictionary<int, Channel>();

        // Guards fAwaitingChannels (including replacing or nulling the reference).
        private object fAwaitingChannelsLock = new object();

        // Local channel id -> event signalled when the remote side answers a
        // CreateChannel request with ChannelAssociated. Null after Dispose.
        private volatile Dictionary<int, ManualResetEvent> fAwaitingChannels = new Dictionary<int, ManualResetEvent>();

        // Channel 0, used for the control messages.
        private Channel fMainChannel;

        internal int fChannelBufferSize;

        // Buffers drained and written to the stream by p_Writer; locked on itself.
        // NOTE(review): presumably filled by Channel — not visible in this file.
        internal Queue<byte[]> fBuffersToSend = new Queue<byte[]>();

        // Last local channel id handed out (shared by local and remote requests).
        private int fNextChannelId;

        private bool fCanThrow;

        // NOTE(review): presumably used by Channel to schedule i_RemoveChannel
        // (local id, remote id) when a channel is disposed — confirm in Channel.
        internal ActionRunner<KeyValuePair<int, int>> fRunRemoveChannel = new ActionRunner<KeyValuePair<int, int>>();
        #endregion

        #region Constructors
        /// <summary>
        /// Creates the channeller for the specified stream, using a buffer of
        /// 8 KB per channel and allowing background exceptions to be rethrown.
        /// </summary>
        /// <param name="stream">The stream to channel.</param>
        /// <param name="remoteChannelCreated">
        /// Handler to invoke when a channel is created as a request from the other side.
        /// </param>
        public StreamChanneller(Stream stream, EventHandler<ChannelCreatedEventArgs> remoteChannelCreated):
            this(stream, remoteChannelCreated, 8 * 1024)
        {
        }

        /// <summary>
        /// Creates the channeller for the specified stream and allows you to
        /// specify the buffer size. For tcp/ip streams, use the larger of the
        /// receive and send buffer sizes.
        /// </summary>
        /// <param name="stream">The stream to channel.</param>
        /// <param name="remoteChannelCreated">
        /// Handler to invoke when a channel is created as a request from the other side.
        /// </param>
        /// <param name="bufferSizePerChannel">The buffer size used when receiving and sending to each channel.</param>
        public StreamChanneller(Stream stream, EventHandler<ChannelCreatedEventArgs> remoteChannelCreated, int bufferSizePerChannel):
            this(stream, remoteChannelCreated, bufferSizePerChannel, true)
        {
        }

        /// <summary>
        /// Creates the channeller for the specified stream and allows you to
        /// specify the buffer size. For tcp/ip streams, use the larger of the
        /// receive and send buffer sizes.
        /// Starts the reader, writer and main-channel background threads.
        /// </summary>
        /// <param name="stream">The stream to channel.</param>
        /// <param name="remoteChannelCreated">
        /// Handler to invoke when a channel is created as a request from the other side.
        /// It is run through UnlimitedThreadPool, so the handler does not need to
        /// create its own thread.
        /// </param>
        /// <param name="bufferSizePerChannel">
        /// The buffer size used when receiving and sending to each channel.
        /// Must be at least 256 bytes.
        /// </param>
        /// <param name="canThrow">
        /// If true (the value used by the other constructors), an exception on the
        /// writer or main-channel thread is rethrown after this object is disposed.
        /// If false, the object is only disposed.
        /// NOTE(review): those are background threads, so a rethrown exception is
        /// unhandled on its thread — confirm that is the intended behavior.
        /// </param>
        /// <exception cref="ArgumentNullException">stream or remoteChannelCreated is null.</exception>
        /// <exception cref="ArgumentException">bufferSizePerChannel is less than 256.</exception>
        public StreamChanneller(Stream stream, EventHandler<ChannelCreatedEventArgs> remoteChannelCreated, int bufferSizePerChannel, bool canThrow)
        {
            if (stream == null)
                throw new ArgumentNullException("stream");

            if (remoteChannelCreated == null)
                throw new ArgumentNullException("remoteChannelCreated");

            if (bufferSizePerChannel < 256)
                throw new ArgumentException("bufferSizePerChannel can't be less than 256 bytes", "bufferSizePerChannel");

            fCanThrow = canThrow;
            fChannelBufferSize = bufferSizePerChannel;
            RemoteChannelCreated = remoteChannelCreated;

            Channel mainChannel = new Channel(this);
            fMainChannel = mainChannel;
            fChannels.Add(0, mainChannel);

            // The stream must be assigned before any worker thread starts.
            fStream = stream;

            Thread threadReader = new Thread(p_Reader);
            threadReader.IsBackground = true;
            threadReader.Name = "StreamChanneller reader.";
            threadReader.Start();

            Thread threadWriter = new Thread(p_Writer);
            threadWriter.IsBackground = true;
            threadWriter.Name = "StreamChanneller writer.";
            threadWriter.Start();

            Thread threadMainChannel = new Thread(p_MainChannel);
            threadMainChannel.IsBackground = true;
            threadMainChannel.Name = "StreamChanneller main channel.";
            threadMainChannel.Start(mainChannel);

            GCUtils.Collected += p_Collected;
        }
        #endregion

        #region Dispose
        /// <summary>
        /// Disposes the channeller, the stream and every channel, wakes any
        /// CreateChannel call still waiting for an answer and then raises
        /// <see cref="Disposed"/>.
        /// </summary>
        /// <param name="disposing">true if called from Dispose() and false if called from destructor.</param>
        protected override void Dispose(bool disposing)
        {
            if (disposing)
            {
                GCUtils.Collected -= p_Collected;

                var stream = fStream;
                if (stream != null)
                {
                    fStream = null;
                    stream.Dispose();
                }

                var runRemoveChannel = fRunRemoveChannel;
                if (runRemoveChannel != null)
                {
                    fRunRemoveChannel = null;
                    runRemoveChannel.Dispose();
                }

                p_DisposeChannels();
                p_ReleaseAwaitingChannels();

                var writerEvent = fWriterEvent;
                if (writerEvent != null)
                {
                    try
                    {
                        // Wakes p_Writer so it can see WasDisposed and exit.
                        writerEvent.Set();
                    }
                    catch
                    {
                        // p_Writer closes this event when it exits, which may
                        // happen concurrently; there is nothing left to wake then.
                    }
                }
            }

            base.Dispose(disposing);

            if (disposing)
            {
                // Copy to a local so a concurrent unsubscribe cannot null the
                // field between the check and the invocation.
                var disposed = Disposed;
                if (disposed != null)
                    disposed(this, EventArgs.Empty);
            }
        }

        /// <summary>
        /// Detaches the channel table and disposes every channel in it.
        /// </summary>
        private void p_DisposeChannels()
        {
            // The table is detached under fChannelsLock, the same lock every other
            // member uses, so no Add/Remove/re-creation can race with the snapshot.
            // The channels are disposed outside the lock.
            Channel[] channelsToDispose = null;
            AbortSafe.UnabortableLock
            (
                fChannelsLock,
                delegate
                {
                    var channels = fChannels;
                    if (channels == null)
                        return;

                    fChannels = null;
                    channelsToDispose = new Channel[channels.Count];
                    channels.Values.CopyTo(channelsToDispose, 0);
                }
            );

            if (channelsToDispose != null)
                foreach(Channel channel in channelsToDispose)
                    channel.Dispose(DisposeException);
        }

        /// <summary>
        /// Detaches the awaiting-channel table and signals every event in it, so
        /// pending CreateChannel calls return and reach their CheckUndisposed() call.
        /// </summary>
        private void p_ReleaseAwaitingChannels()
        {
            Dictionary<int, ManualResetEvent> awaitingChannels = null;
            AbortSafe.UnabortableLock
            (
                fAwaitingChannelsLock,
                delegate
                {
                    awaitingChannels = fAwaitingChannels;
                    fAwaitingChannels = null;
                }
            );

            if (awaitingChannels == null)
                return;

            foreach(ManualResetEvent awaitingEvent in awaitingChannels.Values)
            {
                try
                {
                    awaitingEvent.Set();
                }
                catch(ObjectDisposedException)
                {
                    // The CreateChannel call owning this event already finished and
                    // disposed it (it could not remove its entry because the table
                    // had just been detached above).
                }
            }
        }
        #endregion

        #region p_Collected
        /// <summary>
        /// Handler for GCUtils.Collected. Re-creates the internal dictionaries
        /// and trims the send queue, so capacity left from past growth can be
        /// released. Unsubscribes itself once the channeller is disposed.
        /// </summary>
        private void p_Collected()
        {
            try
            {
                if (WasDisposed)
                {
                    GCUtils.Collected -= p_Collected;
                    return;
                }

                p_CollectAwaitingChannels();
                p_CollectChannels();
                p_CollectBuffersToSend();
            }
            catch
            {
                // ignore any exceptions, as the lists are kept intact if there
                // is no memory.
            }
        }

        private void p_CollectAwaitingChannels()
        {
            AbortSafe.Lock
            (
                fAwaitingChannelsLock,
                delegate
                {
                    // Null once disposed; never re-create the table then.
                    var awaitingChannels = fAwaitingChannels;
                    if (awaitingChannels != null)
                        fAwaitingChannels = new Dictionary<int, ManualResetEvent>(awaitingChannels);
                }
            );
        }

        private void p_CollectChannels()
        {
            AbortSafe.Lock
            (
                fChannelsLock,
                delegate
                {
                    // Null once disposed; never re-create the table then.
                    var channels = fChannels;
                    if (channels != null)
                        fChannels = new Dictionary<int, Channel>(channels);
                }
            );
        }

        private void p_CollectBuffersToSend()
        {
            var buffersToSend = fBuffersToSend;
            AbortSafe.UnabortableLock
            (
                buffersToSend,
                () => buffersToSend.TrimExcess()
            );
        }
        #endregion

        #region Methods
        #region CreateChannel
        /// <summary>
        /// Creates a new channel.
        /// </summary>
        /// <returns>Returns a new channel inside the stream.</returns>
        public Channel CreateChannel()
        {
            return CreateChannel(null);
        }

        /// <summary>
        /// Creates a channel, sending the serializableData parameter to the
        /// other side, so it can decide what to do with this channel before it
        /// gets used (this avoids an extra tcp/ip packet for small information).
        /// Blocks until the other side answers with its own id for the channel,
        /// or until this channeller is disposed.
        /// </summary>
        /// <param name="serializableData">Data to send to the other side.</param>
        /// <returns>A new channel.</returns>
        /// <remarks>
        /// Any failure disposes the whole channeller (if it was not already
        /// disposed) and the exception is rethrown.
        /// </remarks>
        public Channel CreateChannel(object serializableData)
        {
            try
            {
                int channelId = Interlocked.Increment(ref fNextChannelId);
                Channel channel = new Channel(this);
                channel.fId = channelId;
                fChannelsLock.LockWithTimeout
                (
                    () => fChannels.Add(channelId, channel)
                );

                ChannelCreated channelCreated = new ChannelCreated();
                channelCreated.SenderChannelId = channelId;
                channelCreated.Data = serializableData;

                ManualResetEvent manualResetEvent = null;
                try
                {
                    AbortSafe.Run(()=> manualResetEvent = new ManualResetEvent(false));
                    fAwaitingChannelsLock.LockWithTimeout
                    (
                        () => fAwaitingChannels.Add(channelId, manualResetEvent)
                    );

                    try
                    {
                        BinaryFormatter binaryFormatter = new BinaryFormatter();
                        var mainChannel = fMainChannel;
                        mainChannel.LockWithTimeout
                        (
                            () => binaryFormatter.Serialize(mainChannel, channelCreated)
                        );

                        // Signalled by p_MainChannel when ChannelAssociated arrives,
                        // or by Dispose.
                        manualResetEvent.WaitOne();
                        CheckUndisposed();
                    }
                    finally
                    {
                        fAwaitingChannelsLock.LockWithTimeout
                        (
                            delegate
                            {
                                // Dispose nulls the table; dereferencing it then would
                                // replace the real exception with a NullReferenceException.
                                var awaitingChannels = fAwaitingChannels;
                                if (awaitingChannels != null)
                                    awaitingChannels.Remove(channelId);
                            }
                        );
                    }
                }
                finally
                {
                    manualResetEvent.CheckedDispose();
                }

                return channel;
            }
            catch(Exception exception)
            {
                if (!WasDisposed)
                    Dispose(exception);

                throw;
            }
        }
        #endregion

        #region i_RemoveChannel
        /// <summary>
        /// Removes a local channel from the channel table and tells the other
        /// side, through the main channel, to dispose its counterpart.
        /// Best-effort: does nothing if already disposed and ignores any exception.
        /// </summary>
        /// <param name="pair">Key: local channel id. Value: remote channel id.</param>
        internal void i_RemoveChannel(KeyValuePair<int, int> pair)
        {
            try
            {
                int id = pair.Key;
                int remoteId = pair.Value;

                bool mustReturn = true;
                fChannelsLock.LockWithTimeout
                (
                    delegate
                    {
                        var channels = fChannels;
                        if (channels == null)
                            return;

                        channels.Remove(id);
                        mustReturn = false;
                    }
                );

                if (mustReturn)
                    return;

                BinaryFormatter binaryFormatter = new BinaryFormatter();
                ChannelRemoved channelRemoved = new ChannelRemoved();
                channelRemoved.ReceiverChannelId = remoteId;
                var mainChannel = fMainChannel;
                mainChannel.LockWithTimeout
                (
                    () => binaryFormatter.Serialize(mainChannel, channelRemoved)
                );
            }
            catch
            {
                // Deliberately ignored: the channel is going away and a failure
                // to notify the other side must not propagate to the caller.
            }
        }
        #endregion

        #region p_Reader
        /// <summary>
        /// Reader thread: reads header + payload pairs from the stream and
        /// routes each payload, in chunks of at most fChannelBufferSize bytes,
        /// to the target channel's incoming queue. Payloads for unknown channels
        /// are discarded. Any failure disposes the channeller.
        /// </summary>
        private void p_Reader()
        {
            try
            {
                byte[] headerBuffer = new byte[8];
                while(true)
                {
                    // Waiting for the next message may take arbitrarily long.
                    p_SetReadTimeOut(Timeout.Infinite);
                    p_Read(headerBuffer, 8);
                    int channelId = BitConverter.ToInt32(headerBuffer, 0);
                    int messageSize = BitConverter.ToInt32(headerBuffer, 4);

                    Channel channel = null;
                    fChannelsLock.LockWithTimeout
                    (
                        () => fChannels.TryGetValue(channelId, out channel)
                    );

                    // Once a header arrived, its payload must arrive within 60 seconds.
                    p_SetReadTimeOut(60000);

                    if (channel == null)
                    {
                        p_Discard(messageSize);
                        continue;
                    }

                    int bytesLeft = messageSize;
                    while (bytesLeft > 0)
                    {
                        if (WasDisposed)
                            break;

                        var channelMessages = channel.fInMessages;
                        if (channelMessages == null)
                        {
                            // The channel already failed and can't receive data any
                            // more, but the rest of its message must still be consumed
                            // so the next header is read from the right position.
                            p_Discard(bytesLeft);
                            break;
                        }

                        int count = Math.Min(bytesLeft, fChannelBufferSize);

                        byte[] messageBuffer;
                        try
                        {
                            messageBuffer = new byte[count];
                        }
                        catch(Exception exception)
                        {
                            // Fail only this channel and skip the rest of its message.
                            channel.Dispose(exception);
                            channel.fInMessages = null;
                            p_Discard(bytesLeft);
                            break;
                        }

                        p_Read(messageBuffer, count);
                        bytesLeft -= count;

                        channelMessages.LockWithTimeout
                        (
                            delegate
                            {
                                try
                                {
                                    channelMessages.Enqueue(messageBuffer);
                                }
                                catch(Exception exception)
                                {
                                    // Fail only this channel; the remaining chunks are
                                    // discarded by the null check above.
                                    channel.Dispose(exception);
                                    channel.fInMessages = null;
                                    return;
                                }

                                var waitEvent = channel.fWaitEvent;
                                if (waitEvent != null)
                                    waitEvent.Set();
                            }
                        );
                    }
                }
            }
            catch(Exception exception)
            {
                if (!WasDisposed)
                    Dispose(exception);
            }
        }
        #endregion

        #region p_Read
        /// <summary>
        /// Reads exactly count bytes into buffer. If the stream ends first, the
        /// channeller is disposed and a RemotingException is thrown.
        /// </summary>
        private void p_Read(byte[] buffer, int count)
        {
            int totalRead = 0;
            while(totalRead < count)
            {
                int read = fStream.Read(buffer, totalRead, count-totalRead);
                if (read == 0)
                {
                    var exception = new RemotingException("Stream closed.");
                    Dispose(exception);
                    throw exception;
                }

                totalRead += read;
            }
        }
        #endregion

        #region p_Discard
        /// <summary>
        /// Reads and drops bytesToDiscard bytes, using a buffer of at most
        /// fChannelBufferSize bytes.
        /// </summary>
        private void p_Discard(int bytesToDiscard)
        {
            int bufferSize = Math.Min(bytesToDiscard, fChannelBufferSize);
            byte[] discardBuffer = new byte[bufferSize];
            int bytesLeft = bytesToDiscard;
            while(bytesLeft > 0)
            {
                if (bytesLeft < bufferSize)
                {
                    p_Read(discardBuffer, bytesLeft);
                    break;
                }

                p_Read(discardBuffer, bufferSize);
                bytesLeft -= bufferSize;
            }
        }
        #endregion

        #region p_Writer
        // Signalled when there may be buffers to send (and by Dispose).
        // Closed and set to null when p_Writer exits.
        internal ManualResetEvent fWriterEvent = new ManualResetEvent(false);

        /// <summary>
        /// Writer thread: waits on fWriterEvent, then writes every queued buffer
        /// from fBuffersToSend to the stream and flushes. Exits when disposed.
        /// </summary>
        private void p_Writer()
        {
            var writerEvent = fWriterEvent;
            try
            {
                try
                {
                    var buffersToSend = fBuffersToSend;
                    while(true)
                    {
                        // Waiting for data may take arbitrarily long.
                        p_SetWriteTimeOut(Timeout.Infinite);
                        writerEvent.WaitOne();

                        if (WasDisposed)
                        {
                            fWriterEvent = null;
                            return;
                        }

                        writerEvent.Reset();

                        // Once there is data, each write must complete within 60 seconds.
                        p_SetWriteTimeOut(60000);

                        while(true)
                        {
                            bool mustBreak = false;
                            bool mustReturn = false;
                            byte[] buffer = null;
                            AbortSafe.UnabortableLock
                            (
                                buffersToSend,
                                delegate
                                {
                                    if (buffersToSend.Count == 0)
                                    {
                                        mustBreak = true;
                                        return;
                                    }

                                    if (WasDisposed)
                                    {
                                        mustReturn = true;
                                        return;
                                    }

                                    buffer = buffersToSend.Dequeue();
                                }
                            );

                            if (mustBreak)
                                break;

                            if (mustReturn)
                                return;

                            // Written outside the lock so producers are not blocked by I/O.
                            fStream.Write(buffer, 0, buffer.Length);
                        }

                        fStream.Flush();
                    }
                }
                catch(Exception exception)
                {
                    if (!WasDisposed)
                    {
                        Dispose(exception);

                        if (fCanThrow)
                            throw;
                    }
                }
            }
            finally
            {
                fWriterEvent = null;
                writerEvent.Close();
            }
        }
        #endregion

        #region p_MainChannel
        /// <summary>
        /// Main-channel thread: deserializes control messages and handles them.
        /// <list type="bullet">
        /// <item>ChannelCreated: registers a new local channel, answers with
        /// ChannelAssociated and runs RemoteChannelCreated on the thread pool.</item>
        /// <item>ChannelRemoved: disposes the referenced local channel, if any.</item>
        /// <item>ChannelAssociated: stores the remote id and wakes the waiting
        /// CreateChannel call.</item>
        /// </list>
        /// Any other object (or any failure) disposes the channeller.
        /// </summary>
        private void p_MainChannel(object mainChannelAsObject)
        {
            Channel mainChannel = (Channel)mainChannelAsObject;
            try
            {
                // SECURITY NOTE(review): deserializes data sent by the remote peer.
                BinaryFormatter binaryFormatter = new BinaryFormatter();
                while(true)
                {
                    object obj = binaryFormatter.Deserialize(mainChannel);

                    ChannelCreated channelCreated = obj as ChannelCreated;
                    if (channelCreated != null)
                    {
                        int localChannelId = Interlocked.Increment(ref fNextChannelId);
                        Channel channel = new Channel(this);
                        channel.fId = localChannelId;
                        channel.fRemoteId = channelCreated.SenderChannelId;
                        fChannelsLock.LockWithTimeout
                        (
                            () => fChannels.Add(localChannelId, channel)
                        );

                        ChannelAssociated associated = new ChannelAssociated();
                        associated.SenderChannelId = localChannelId;
                        associated.ReceiverChannelId = channelCreated.SenderChannelId;
                        mainChannel.LockWithTimeout
                        (
                            () => binaryFormatter.Serialize(mainChannel, associated)
                        );

                        ChannelCreatedEventArgs args = new ChannelCreatedEventArgs();
                        args.Channel = channel;
                        args.Data = channelCreated.Data;
                        UnlimitedThreadPool.Run
                        (
                            () =>
                            {
                                Exception exception = null;
                                try
                                {
                                    RemoteChannelCreated(this, args);
                                }
                                catch(Exception caughtException)
                                {
                                    exception = caughtException;
                                }
                                finally
                                {
                                    if (args.CanDisposeChannel)
                                        args.Channel.Dispose(exception);
                                }
                            }
                        );
                    }
                    else
                    {
                        ChannelRemoved channelRemoved = obj as ChannelRemoved;
                        if (channelRemoved != null)
                        {
                            Channel channel = null;
                            fChannelsLock.LockWithTimeout
                            (
                                () => fChannels.TryGetValue(channelRemoved.ReceiverChannelId, out channel)
                            );

                            if (channel != null)
                                channel.Dispose();
                        }
                        else
                        {
                            ChannelAssociated channelAssociated = (ChannelAssociated)obj;
                            Channel channel = null;
                            fChannelsLock.LockWithTimeout
                            (
                                () => channel = fChannels[channelAssociated.ReceiverChannelId]
                            );

                            channel.fRemoteId = channelAssociated.SenderChannelId;
                            fAwaitingChannelsLock.LockWithTimeout
                            (
                                () => fAwaitingChannels[channel.fId].Set()
                            );
                        }
                    }
                }
            }
            catch(Exception exception)
            {
                if (!WasDisposed)
                {
                    Dispose(exception);

                    if (fCanThrow)
                        throw;
                }
            }
        }
        #endregion

        #region p_SetReadTimeOut
        private void p_SetReadTimeOut(int timeout)
        {
            if (fStream.CanTimeout)
                fStream.ReadTimeout = timeout;
        }
        #endregion

        #region p_SetWriteTimeOut
        private void p_SetWriteTimeOut(int timeout)
        {
            if (fStream.CanTimeout)
                fStream.WriteTimeout = timeout;
        }
        #endregion
        #endregion

        #region Events
        /// <summary>
        /// Event called when Dispose() has just finished.
        /// </summary>
        public event EventHandler Disposed;

        /// <summary>
        /// Event that is invoked when the remote side creates a new channel.
        /// </summary>
        public event EventHandler<ChannelCreatedEventArgs> RemoteChannelCreated;
        #endregion

        #region Nested classes
        // Control message: the sender created channel SenderChannelId and asks
        // the receiver to create its counterpart.
        [Serializable]
        private sealed class ChannelCreated
        {
            internal int SenderChannelId;
            internal object Data;
        }

        // Control message: the receiver must dispose its channel ReceiverChannelId.
        [Serializable]
        private sealed class ChannelRemoved
        {
            internal int ReceiverChannelId;
        }

        // Control message: answer to ChannelCreated, pairing the receiver's
        // channel ReceiverChannelId with the sender's channel SenderChannelId.
        [Serializable]
        private sealed class ChannelAssociated
        {
            internal int ReceiverChannelId;
            internal int SenderChannelId;
        }
        #endregion

        #region IChanneller Members
        ExceptionAwareStream IChanneller.CreateChannel()
        {
            return CreateChannel();
        }

        ExceptionAwareStream IChanneller.CreateChannel(object createData)
        {
            return CreateChannel(createData);
        }
        #endregion
    }
}