using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;
using System.Threading;
using Pfz.Caching;
using Pfz.Extensions.MonitorLockExtensions;
using Pfz.Threading;
namespace Pfz.Remoting
{
/// <summary>
/// Class responsible for creating many channels inside another stream.
/// This is used by the remoting framework, so each thread has its own
/// channel inside a single tcp/ip connection.
/// </summary>
/// <remarks>
/// Wire format, as parsed by the reader thread: each packet is an 8-byte
/// header made of two Int32 values produced by <see cref="BitConverter"/>
/// (channel id, payload size), followed by the payload bytes.
/// Channel 0 is the main (control) channel. Over it, ChannelCreated,
/// ChannelAssociated and ChannelRemoved messages are exchanged using a
/// <see cref="BinaryFormatter"/>.
/// </remarks>
public sealed class StreamChanneller:
    ThreadSafeDisposable
{
    #region Private and internal fields
    private Stream fStream;

    // Guards fChannels. Every access to that dictionary (including its
    // detachment in Dispose) happens while holding this lock.
    private readonly object fChannelsLock = new object();

    // Local channel id -> channel. Set to null by Dispose.
    private volatile Dictionary<int, Channel> fChannels = new Dictionary<int, Channel>();

    // Guards fAwaitingChannels.
    private readonly object fAwaitingChannelsLock = new object();

    // Local channel id -> event signalled when the remote side answers a
    // ChannelCreated request with ChannelAssociated. Set to null by Dispose.
    private volatile Dictionary<int, ManualResetEvent> fAwaitingChannels = new Dictionary<int, ManualResetEvent>();

    private readonly Channel fMainChannel;
    internal int fChannelBufferSize;

    // Outgoing packets. Drained (under a lock on the queue itself) by p_Writer.
    internal Queue<byte[]> fBuffersToSend = new Queue<byte[]>();

    private int fNextChannelId;
    private readonly EventHandler<ChannelCreatedEventArgs> fRemoteChannelCreated;
    private readonly bool fCanThrow;

    // Runner used for i_RemoveChannel requests (presumably scheduled by
    // Channel - TODO confirm). Disposed and nulled by Dispose.
    internal ActionRunner<KeyValuePair<int, int>> fRunRemoveChannel = new ActionRunner<KeyValuePair<int, int>>();
    #endregion

    #region Constructors
    /// <summary>
    /// Creates the channeller for the specified stream, using a 32 KB buffer
    /// per channel and no time-out parameters.
    /// </summary>
    /// <param name="stream">The stream to channel.</param>
    /// <param name="remoteChannelCreated">
    /// Handler to invoke when a channel is created as a request from the other side.
    /// </param>
    public StreamChanneller(Stream stream, EventHandler<ChannelCreatedEventArgs> remoteChannelCreated):
        this(stream, remoteChannelCreated, 32 * 1024, null)
    {
    }

    /// <summary>
    /// Creates the channeller for the specified stream and allows you to
    /// specify the buffer size. For a tcp/ip stream, use the bigger value
    /// between the receive and send buffer sizes.
    /// </summary>
    /// <param name="stream">The stream to channel.</param>
    /// <param name="remoteChannelCreated">
    /// Handler to invoke when a channel is created as a request from the other side.
    /// </param>
    /// <param name="bufferSizePerChannel">The buffer size used when receiving and sending to each channel.</param>
    /// <param name="timeOutParameters">The object used to set the time-outs for the stream.</param>
    public StreamChanneller(Stream stream, EventHandler<ChannelCreatedEventArgs> remoteChannelCreated, int bufferSizePerChannel, ITimeOutParameters timeOutParameters):
        this(stream, remoteChannelCreated, bufferSizePerChannel, true, timeOutParameters)
    {
    }

    /// <summary>
    /// Creates the channeller for the specified stream and allows you to
    /// specify the buffer size. For a tcp/ip stream, use the bigger value
    /// between the receive and send buffer sizes.
    /// Starts three background threads: the reader, the writer and the
    /// main-channel (control message) processor.
    /// </summary>
    /// <param name="stream">The stream to channel.</param>
    /// <param name="remoteChannelCreated">
    /// Handler to invoke when a channel is created as a request from the other side.
    /// It is invoked in a separate thread. You don't need to create one.
    /// </param>
    /// <param name="bufferSizePerChannel">
    /// The buffer size used when receiving and sending to each channel. Must be at least 256.
    /// </param>
    /// <param name="canThrow">
    /// If true (the value used by the other constructors), an exception that
    /// stops the writer thread or the main-channel thread is rethrown on that
    /// thread after the channeller is disposed. An unhandled exception on a
    /// thread normally terminates the process.
    /// If false, those failures only dispose the channeller.
    /// The reader thread never rethrows.
    /// </param>
    /// <param name="timeOutParameters">The object used to set the time-outs of the stream. May be null.</param>
    /// <exception cref="ArgumentNullException">stream or remoteChannelCreated is null.</exception>
    /// <exception cref="ArgumentException">bufferSizePerChannel is less than 256.</exception>
    public StreamChanneller(Stream stream, EventHandler<ChannelCreatedEventArgs> remoteChannelCreated, int bufferSizePerChannel, bool canThrow, ITimeOutParameters timeOutParameters)
    {
        fCanThrow = canThrow;

        if (stream == null)
            throw new ArgumentNullException("stream");

        if (remoteChannelCreated == null)
            throw new ArgumentNullException("remoteChannelCreated");

        if (bufferSizePerChannel < 256)
            throw new ArgumentException("bufferSizePerChannel can't be less than 256 bytes", "bufferSizePerChannel");

        TimeOutParameters = timeOutParameters;
        fChannelBufferSize = bufferSizePerChannel;
        fRemoteChannelCreated = remoteChannelCreated;

        // Channel id 0 is reserved for the main (control) channel.
        Channel mainChannel = new Channel(this);
        fMainChannel = mainChannel;
        fChannels.Add(0, mainChannel);

        // The stream must be assigned before the worker threads start using it.
        fStream = stream;

        Thread threadReader = new Thread(p_Reader);
        threadReader.IsBackground = true;
        threadReader.Name = "StreamChanneller reader.";
        threadReader.Start();

        Thread threadWriter = new Thread(p_Writer);
        threadWriter.IsBackground = true;
        threadWriter.Name = "StreamChanneller writer.";
        threadWriter.Start();

        Thread threadMainChannel = new Thread(p_MainChannel);
        threadMainChannel.IsBackground = true;
        threadMainChannel.Name = "StreamChanneller main channel.";
        threadMainChannel.Start(mainChannel);

        GCUtils.Collected += p_Collected;
    }
    #endregion

    #region Dispose
    /// <summary>
    /// Disposes the channeller and the stream.
    /// It also disposes every channel, wakes any CreateChannel call that is
    /// still waiting, and wakes the writer thread so it can exit.
    /// Raises <see cref="Disposed"/> afterwards.
    /// </summary>
    /// <param name="disposing">true if called from Dispose() and false if called from destructor.</param>
    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            GCUtils.Collected -= p_Collected;

            // Disposing the stream unblocks the reader and writer threads.
            var stream = fStream;
            if (stream != null)
            {
                fStream = null;
                stream.Dispose();
            }

            var runRemoveChannel = fRunRemoveChannel;
            if (runRemoveChannel != null)
            {
                fRunRemoveChannel = null;
                runRemoveChannel.Dispose();
            }

            // Detach the dictionary under the same lock every other accessor
            // uses. After this nobody else can reach it, so it can be
            // enumerated safely without holding the lock.
            Dictionary<int, Channel> channels;
            using(fChannelsLock.LockWithTimeout())
            {
                channels = fChannels;
                fChannels = null;
            }

            if (channels != null)
                foreach(Channel channel in channels.Values)
                    channel.Dispose();

            Dictionary<int, ManualResetEvent> awaitingChannels;
            using(fAwaitingChannelsLock.LockWithTimeout())
            {
                awaitingChannels = fAwaitingChannels;
                fAwaitingChannels = null;
            }

            if (awaitingChannels != null)
            {
                foreach(ManualResetEvent mre in awaitingChannels.Values)
                {
                    try
                    {
                        mre.Set();
                    }
                    catch(ObjectDisposedException)
                    {
                        // CreateChannel disposes its event when it stops
                        // waiting, which may happen concurrently with this loop.
                    }
                }
            }

            var writerEvent = fWriterEvent;
            if (writerEvent != null)
            {
                try
                {
                    writerEvent.Set();
                }
                catch(ObjectDisposedException)
                {
                    // The writer thread closes the event in its finally block.
                }
            }
        }

        base.Dispose(disposing);

        if (disposing)
        {
            // Copy to a local so a concurrent unsubscribe can't null it between check and call.
            EventHandler handler = Disposed;
            if (handler != null)
                handler(this, EventArgs.Empty);
        }
    }
    #endregion

    #region p_Collected
    // GCUtils.Collected callback.
    // Recreates the internal collections so their unused capacity can be reclaimed.
    private void p_Collected()
    {
        try
        {
            if (WasDisposed)
            {
                GCUtils.Collected -= p_Collected;
                return;
            }

            p_CollectAwaitingChannels();
            p_CollectChannels();
            p_CollectBuffersToSend();
        }
        catch
        {
            // Best effort: ignore any exceptions (e.g. out of memory), as the
            // lists are kept intact if the copy can't be made.
        }
    }

    private void p_CollectAwaitingChannels()
    {
        lock (fAwaitingChannelsLock)
        {
            var awaitingChannels = fAwaitingChannels;
            if (awaitingChannels != null)
                fAwaitingChannels = new Dictionary<int, ManualResetEvent>(awaitingChannels);
        }
    }

    private void p_CollectChannels()
    {
        lock (fChannelsLock)
        {
            var channels = fChannels;
            if (channels != null)
                fChannels = new Dictionary<int, Channel>(channels);
        }
    }

    private void p_CollectBuffersToSend()
    {
        var buffersToSend = fBuffersToSend;
        lock(buffersToSend)
            buffersToSend.TrimExcess();
    }
    #endregion

    #region Properties
    /// <summary>
    /// Gets the TimeOutParameters object utilized by this Channeller.
    /// May be null, in which case no time-outs are set.
    /// </summary>
    public ITimeOutParameters TimeOutParameters { get; private set; }
    #endregion

    #region Methods
    #region CreateChannel
    /// <summary>
    /// Creates a new channel.
    /// </summary>
    /// <returns>Returns a new channel inside the stream.</returns>
    public Channel CreateChannel()
    {
        return CreateChannel(null);
    }

    /// <summary>
    /// Creates a channel, sending the serializableData parameter to the
    /// other side, so it can decide what to do with this channel before it
    /// gets used (this avoids an extra tcp/ip packet for small information).
    /// Blocks until the other side answers with its own id for the channel,
    /// or until this channeller is disposed.
    /// </summary>
    /// <param name="serializableData">Data to send to the other side. Must be serializable by BinaryFormatter.</param>
    /// <returns>A new channel.</returns>
    /// <remarks>
    /// Any failure disposes the whole channeller. A partially written
    /// control message would leave the main channel in an unknown state.
    /// </remarks>
    public Channel CreateChannel(object serializableData)
    {
        CheckUndisposed();

        try
        {
            int channelId = Interlocked.Increment(ref fNextChannelId);
            Channel channel = new Channel(this);
            channel.fId = channelId;

            using(fChannelsLock.LockWithTimeout())
                fChannels.Add(channelId, channel);

            ChannelCreated channelCreated = new ChannelCreated();
            channelCreated.SenderChannelId = channelId;
            channelCreated.Data = serializableData;

            using (ManualResetEvent manualResetEvent = new ManualResetEvent(false))
            {
                // Register before sending, so the ChannelAssociated reply
                // (handled by p_MainChannel) always finds the event.
                using(fAwaitingChannelsLock.LockWithTimeout())
                    fAwaitingChannels.Add(channelId, manualResetEvent);

                try
                {
                    BinaryFormatter binaryFormatter = new BinaryFormatter();
                    var mainChannel = fMainChannel;
                    using(mainChannel.LockWithTimeout())
                        binaryFormatter.Serialize(mainChannel, channelCreated);

                    // Set by p_MainChannel on ChannelAssociated, or by Dispose.
                    manualResetEvent.WaitOne();
                    CheckUndisposed();
                }
                finally
                {
                    using(fAwaitingChannelsLock.LockWithTimeout())
                    {
                        // Dispose nulls the dictionary; don't mask the real
                        // exception with a NullReferenceException.
                        var awaitingChannels = fAwaitingChannels;
                        if (awaitingChannels != null)
                            awaitingChannels.Remove(channelId);
                    }
                }
            }

            return channel;
        }
        catch
        {
            if (!WasDisposed)
                Dispose();

            throw;
        }
    }
    #endregion

    #region i_RemoveChannel
    /// <summary>
    /// Removes a local channel and notifies the other side with a ChannelRemoved message.
    /// </summary>
    /// <param name="pair">Key is the local channel id; Value is the remote channel id.</param>
    internal void i_RemoveChannel(KeyValuePair<int, int> pair)
    {
        try
        {
            int id = pair.Key;
            int remoteId = pair.Value;

            using(fChannelsLock.LockWithTimeout())
            {
                var channels = fChannels;
                if (channels == null)
                    return;

                channels.Remove(id);
            }

            BinaryFormatter binaryFormatter = new BinaryFormatter();
            ChannelRemoved channelRemoved = new ChannelRemoved();
            channelRemoved.ReceiverChannelId = remoteId;

            var mainChannel = fMainChannel;
            using(mainChannel.LockWithTimeout())
                binaryFormatter.Serialize(mainChannel, channelRemoved);
        }
        catch
        {
            // Best effort: the notification can fail if the channeller is
            // being disposed concurrently, and there is nothing to recover.
        }
    }
    #endregion

    #region p_Reader
    // Reader thread: reads packet headers and routes each payload to its
    // channel's inbound queue.
    // Any failure (other than per-channel out-of-memory) disposes the channeller.
    private void p_Reader()
    {
        try
        {
            byte[] headerBuffer = new byte[8];
            while(true)
            {
                // 0 while idly waiting for the next header, 60000 while reading
                // a payload. Presumably 0 means "no time-out" for
                // ITimeOutParameters - TODO confirm.
                p_SetReadTimeOut(0);
                p_Read(headerBuffer, 8);
                int channelId = BitConverter.ToInt32(headerBuffer, 0);
                int messageSize = BitConverter.ToInt32(headerBuffer, 4);

                // A negative size means the stream is corrupt. Continuing would
                // desynchronise every following header.
                if (messageSize < 0)
                    throw new RemotingException("Invalid message size received.");

                Channel channel;
                using(fChannelsLock.LockWithTimeout())
                    fChannels.TryGetValue(channelId, out channel);

                p_SetReadTimeOut(60000);

                if (channel == null)
                {
                    // Unknown or already removed channel: skip its payload.
                    p_Discard(messageSize);
                    continue;
                }

                if (!p_DeliverMessage(channel, messageSize))
                    return;
            }
        }
        catch
        {
            if (!WasDisposed)
                Dispose();
        }
    }

    // Reads messageSize payload bytes from the stream.
    // The bytes are enqueued to the channel in chunks of at most
    // fChannelBufferSize. If the channel can't accept the data (out of memory,
    // or the channel was already abandoned), the rest of the payload is read
    // and discarded, so the stream stays aligned on the next header.
    // Returns false if the channeller was disposed while delivering.
    private bool p_DeliverMessage(Channel channel, int messageSize)
    {
        int bytesLeft = messageSize;
        while (bytesLeft > 0)
        {
            if (WasDisposed)
                return false;

            int count = Math.Min(bytesLeft, fChannelBufferSize);

            byte[] messageBuffer;
            try
            {
                messageBuffer = new byte[count];
            }
            catch(OutOfMemoryException)
            {
                p_AbandonChannel(channel);
                p_Discard(bytesLeft);
                return true;
            }

            p_Read(messageBuffer, count);
            bytesLeft -= count;

            var channelMessages = channel.fInMessages;
            if (channelMessages == null)
            {
                // The channel was abandoned earlier; drop the rest of this message.
                p_Discard(bytesLeft);
                return true;
            }

            bool enqueued;
            using(channelMessages.LockWithTimeout())
            {
                try
                {
                    channelMessages.Enqueue(messageBuffer);
                    enqueued = true;
                }
                catch(OutOfMemoryException)
                {
                    enqueued = false;
                }

                if (enqueued)
                {
                    var waitEvent = channel.fWaitEvent;
                    if (waitEvent != null)
                        waitEvent.Set();
                }
            }

            if (!enqueued)
            {
                p_AbandonChannel(channel);
                p_Discard(bytesLeft);
                return true;
            }
        }

        return true;
    }

    // Disposes a channel that can no longer receive data and detaches its
    // inbound queue, so later packets for it are discarded.
    private static void p_AbandonChannel(Channel channel)
    {
        channel.Dispose();
        channel.fInMessages = null;
    }
    #endregion

    #region p_Read
    // Reads exactly count bytes into buffer.
    // If the remote side closes the stream (Read returns 0), disposes the
    // channeller and throws RemotingException.
    private void p_Read(byte[] buffer, int count)
    {
        var stream = fStream;
        if (stream == null)
            throw new ObjectDisposedException(GetType().FullName);

        int totalRead = 0;
        while(totalRead < count)
        {
            int read = stream.Read(buffer, totalRead, count - totalRead);
            if (read == 0)
            {
                Dispose();
                throw new RemotingException("Stream closed.");
            }

            totalRead += read;
        }
    }
    #endregion

    #region p_Discard
    // Reads and throws away bytesToDiscard bytes from the stream, using a
    // scratch buffer of at most fChannelBufferSize bytes.
    private void p_Discard(int bytesToDiscard)
    {
        if (bytesToDiscard <= 0)
            return;

        byte[] discardBuffer = new byte[Math.Min(bytesToDiscard, fChannelBufferSize)];
        int bytesLeft = bytesToDiscard;
        while(bytesLeft > 0)
        {
            int count = Math.Min(bytesLeft, discardBuffer.Length);
            p_Read(discardBuffer, count);
            bytesLeft -= count;
        }
    }
    #endregion

    #region p_Writer
    // Signalled to wake the writer thread (by Dispose here; presumably also
    // by Channel after enqueuing to fBuffersToSend - TODO confirm).
    internal ManualResetEvent fWriterEvent = new ManualResetEvent(false);

    // Writer thread: waits on fWriterEvent, then drains fBuffersToSend to
    // the stream and flushes it.
    private void p_Writer()
    {
        var writerEvent = fWriterEvent;
        try
        {
            try
            {
                var buffersToSend = fBuffersToSend;
                while(true)
                {
                    p_SetWriteTimeOut(0);
                    writerEvent.WaitOne();

                    if (WasDisposed)
                        return;

                    // Reset before draining: a Set that races with this Reset
                    // was preceded by its enqueue, which the loop below picks up.
                    writerEvent.Reset();
                    p_SetWriteTimeOut(60000);

                    while(true)
                    {
                        byte[] buffer;
                        lock(buffersToSend)
                        {
                            if (buffersToSend.Count == 0)
                                break;

                            if (WasDisposed)
                                return;

                            buffer = buffersToSend.Dequeue();
                        }

                        fStream.Write(buffer, 0, buffer.Length);
                    }

                    fStream.Flush();
                }
            }
            catch
            {
                if (!WasDisposed)
                {
                    Dispose();

                    // NOTE: this rethrows on a background thread with no handler.
                    if (fCanThrow)
                        throw;
                }
            }
        }
        finally
        {
            fWriterEvent = null;
            writerEvent.Close();
        }
    }
    #endregion

    #region p_MainChannel
    // Main-channel thread: processes control messages from the other side.
    // - ChannelCreated: creates a local channel, answers with ChannelAssociated
    //   and runs the remoteChannelCreated handler on UnlimitedThreadPool.
    // - ChannelRemoved: disposes the referenced local channel.
    // - ChannelAssociated: records the remote id and wakes the waiting CreateChannel.
    // Any other message type, or any failure, disposes the channeller.
    //
    // SECURITY NOTE(review): BinaryFormatter.Deserialize on bytes received
    // from the peer allows arbitrary type instantiation. It is only safe if
    // the remote side is fully trusted. BinaryFormatter is also obsolete and
    // removed in recent .NET versions; consider a restricted
    // SerializationBinder or a different serializer.
    private void p_MainChannel(object mainChannelAsObject)
    {
        Channel mainChannel = (Channel)mainChannelAsObject;
        try
        {
            BinaryFormatter binaryFormatter = new BinaryFormatter();
            while(true)
            {
                object obj = binaryFormatter.Deserialize(mainChannel);

                ChannelCreated channelCreated = obj as ChannelCreated;
                if (channelCreated != null)
                {
                    int localChannelId = Interlocked.Increment(ref fNextChannelId);
                    Channel channel = new Channel(this);
                    channel.fId = localChannelId;
                    channel.fRemoteId = channelCreated.SenderChannelId;

                    using(fChannelsLock.LockWithTimeout())
                        fChannels.Add(localChannelId, channel);

                    ChannelAssociated associated = new ChannelAssociated();
                    associated.SenderChannelId = localChannelId;
                    associated.ReceiverChannelId = channelCreated.SenderChannelId;
                    using(mainChannel.LockWithTimeout())
                        binaryFormatter.Serialize(mainChannel, associated);

                    ChannelCreatedEventArgs args = new ChannelCreatedEventArgs();
                    args.Channel = channel;
                    args.Data = channelCreated.Data;
                    args.CanDisposeChannel = true;

                    // The handler may set CanDisposeChannel to false to keep
                    // the channel alive after it returns.
                    UnlimitedThreadPool.Run
                    (
                        () =>
                        {
                            try
                            {
                                fRemoteChannelCreated(this, args);
                            }
                            finally
                            {
                                if (args.CanDisposeChannel)
                                    args.Channel.Dispose();
                            }
                        }
                    );
                    continue;
                }

                ChannelRemoved channelRemoved = obj as ChannelRemoved;
                if (channelRemoved != null)
                {
                    Channel channel;
                    using(fChannelsLock.LockWithTimeout())
                        fChannels.TryGetValue(channelRemoved.ReceiverChannelId, out channel);

                    if (channel != null)
                        channel.Dispose();

                    continue;
                }

                // Anything else must be ChannelAssociated; a bad cast is a
                // protocol error and disposes the channeller via the catch below.
                ChannelAssociated channelAssociated = (ChannelAssociated)obj;
                Channel associatedChannel;
                using(fChannelsLock.LockWithTimeout())
                    associatedChannel = fChannels[channelAssociated.ReceiverChannelId];

                associatedChannel.fRemoteId = channelAssociated.SenderChannelId;

                using(fAwaitingChannelsLock.LockWithTimeout())
                    fAwaitingChannels[associatedChannel.fId].Set();
            }
        }
        catch
        {
            if (!WasDisposed)
            {
                Dispose();

                // NOTE: this rethrows on a background thread with no handler.
                if (fCanThrow)
                    throw;
            }
        }
    }
    #endregion

    #region p_SetReadTimeOut
    private void p_SetReadTimeOut(int timeout)
    {
        var timeOutParameters = TimeOutParameters;
        if (timeOutParameters != null)
            timeOutParameters.ReadTimeOut = timeout;
    }
    #endregion

    #region p_SetWriteTimeOut
    private void p_SetWriteTimeOut(int timeout)
    {
        var timeOutParameters = TimeOutParameters;
        if (timeOutParameters != null)
            timeOutParameters.WriteTimeOut = timeout;
    }
    #endregion
    #endregion

    #region Events
    /// <summary>
    /// Event called when Dispose() has just finished.
    /// </summary>
    public event EventHandler Disposed;
    #endregion

    #region Nested classes
    // Control message: the sender created a channel with id SenderChannelId.
    [Serializable]
    private sealed class ChannelCreated
    {
        internal int SenderChannelId;
        internal object Data;
    }

    // Control message: the receiver should dispose its channel ReceiverChannelId.
    [Serializable]
    private sealed class ChannelRemoved
    {
        internal int ReceiverChannelId;
    }

    // Control message: reply to ChannelCreated, carrying the answering side's id.
    [Serializable]
    private sealed class ChannelAssociated
    {
        internal int ReceiverChannelId;
        internal int SenderChannelId;
    }
    #endregion
}
}