using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using System.Threading;
using Xunit;
namespace Fadd.Commands.Net
{
/// <summary>
/// Transports serialized objects over a TCP socket in binary format.
/// </summary>
/// <remarks>
/// <para>
/// Every object is framed as a 4-byte length header (created with <see cref="BitConverter"/>, i.e. in
/// the machine's native byte order) followed by the <see cref="BinaryFormatter"/> payload. Outgoing
/// packets are queued and written one at a time; incoming bytes are reassembled into packets, deserialized
/// and raised through <see cref="ObjectReceived"/>.
/// </para>
/// <para>
/// The class owns no unmanaged resources itself; call <see cref="Close"/> or <see cref="Dispose"/> to
/// release the socket and the reconnect timer.
/// </para>
/// <para>
/// NOTE(review): <see cref="BinaryFormatter"/> deserialization of bytes read from the network is unsafe
/// unless the remote end is fully trusted, and the formatter is obsolete/disabled on recent .NET versions.
/// Consider switching to a safer serializer.
/// </para>
/// </remarks>
public class BinaryChannel : IDisposable
{
    private const int BufferSize = 4192;
    private const int HeaderSize = sizeof(int);
    private const int ReconnectInterval = 15000;

    private readonly BinaryFormatter _formatter = new BinaryFormatter();
    private readonly byte[] _inbuffer = new byte[BufferSize];
    private readonly byte[] _inHeader = new byte[HeaderSize];
    private readonly Packet _inpacket = new Packet();
    private readonly Packet _outPacket = new Packet();
    private readonly Queue<byte[]> _sendQueue = new Queue<byte[]>();
    private readonly object _reconnectLock = new object();
    private readonly object _connectLock = new object();
    private int _inHeaderCount;
    private Timer _reconnectTimer;
    private IPEndPoint _remoteEndPoint;
    private volatile bool _shouldReconnect;
    private Socket _socket;

    /// <summary>
    /// Invoked when a complete object has been received and deserialized from the remote end.
    /// </summary>
    public ObjectReceivedHandler ObjectReceived = delegate { };

    /// <summary>
    /// Initializes a new instance of the <see cref="BinaryChannel"/> class around an already connected
    /// socket and starts receiving from it immediately.
    /// </summary>
    /// <param name="socket">A connected socket; its remote end point is remembered for reconnection.</param>
    /// <exception cref="ArgumentNullException"><paramref name="socket"/> is <c>null</c>.</exception>
    public BinaryChannel(Socket socket)
    {
        if (socket == null)
            throw new ArgumentNullException("socket");

        _socket = socket;
        _remoteEndPoint = (IPEndPoint) socket.RemoteEndPoint;
        StartReceiving(socket);
    }

    /// <summary>
    /// Initializes a new, unconnected instance of the <see cref="BinaryChannel"/> class.
    /// Call <see cref="Open"/> to connect it.
    /// </summary>
    public BinaryChannel()
    {
    }

    /// <summary>
    /// Gets or sets whether the channel tries to reconnect (every 15 seconds) after an unexpected disconnect.
    /// </summary>
    public bool ShouldReconnect
    {
        get { return _shouldReconnect; }
        set { _shouldReconnect = value; }
    }

    /// <summary>
    /// Invoked when the channel is disconnected, except when <see cref="Close"/> or <see cref="Dispose"/> is called.
    /// </summary>
    public event DisconnectedHandler Disconnected = delegate { };

    /// <summary>
    /// Serializes <paramref name="value"/> with the binary formatter and queues it for sending.
    /// </summary>
    /// <param name="value">object to serialize and send.</param>
    public void Send(object value)
    {
        byte[] buffer;
        using (MemoryStream ms = new MemoryStream())
        {
            _formatter.Serialize(ms, value);
            buffer = ms.ToArray();
        }
        Send(buffer);
    }

    /// <summary>
    /// Queues a raw payload for sending. The payload is prefixed with its length when transmitted.
    /// Packets queued while the channel is not connected are sent once it (re)connects.
    /// </summary>
    /// <param name="bytes">payload to send.</param>
    /// <exception cref="ArgumentNullException"><paramref name="bytes"/> is <c>null</c>.</exception>
    public void Send(byte[] bytes)
    {
        if (bytes == null)
            throw new ArgumentNullException("bytes");

        lock (_sendQueue)
        {
            _sendQueue.Enqueue(bytes);
        }

        // Socket I/O happens outside the queue lock so other senders are not blocked by it.
        Send();
    }

    /// <summary>
    /// Starts transmitting the next queued packet, unless one is already in flight or the channel
    /// is not connected.
    /// </summary>
    private void Send()
    {
        Socket socket = _socket;
        if (socket == null || !socket.Connected)
            return;

        try
        {
            while (true)
            {
                byte[] payload;
                lock (_sendQueue)
                {
                    // Only one packet may be in flight; OnSendComplete calls back here when it is done.
                    if (_outPacket.buffer != null || _sendQueue.Count == 0)
                        return;

                    payload = _sendQueue.Dequeue();
                    _outPacket.buffer = payload;
                    _outPacket.index = 0;
                    _outPacket.size = payload.Length;
                }

                // The header is written synchronously, so it always precedes the payload on the wire.
                byte[] header = BitConverter.GetBytes(payload.Length);
                if (socket.Send(header) != header.Length)
                {
                    Console.WriteLine("Header was not sent properly.");
                    HandleDisconnect(socket, SocketError.Success);
                    return;
                }

                if (payload.Length > 0)
                {
                    socket.BeginSend(payload, 0, payload.Length, SocketFlags.None, OnSendComplete, socket);
                    return;
                }

                // An empty payload is complete once its header is out; continue with the next packet.
                lock (_sendQueue)
                {
                    _outPacket.Clear();
                }
            }
        }
        catch (SocketException err)
        {
            HandleDisconnect(socket, err.SocketErrorCode);
        }
        catch (ObjectDisposedException)
        {
            HandleDisconnect(socket, SocketError.ConnectionReset);
        }
    }

    /// <summary>
    /// Completion callback for <see cref="Socket.BeginSend(byte[],int,int,SocketFlags,AsyncCallback,object)"/>.
    /// Continues a partially sent packet or moves on to the next queued one.
    /// </summary>
    /// <param name="ar">async result whose state is the socket the send was started on.</param>
    private void OnSendComplete(IAsyncResult ar)
    {
        Socket socket = (Socket) ar.AsyncState;
        try
        {
            SocketError errorCode;
            int bytesSent = socket.EndSend(ar, out errorCode);
            if (socket != _socket)
                return; // the channel was closed or reconnected while the send was pending

            if (errorCode != SocketError.Success)
            {
                HandleDisconnect(socket, errorCode);
                return;
            }
            if (bytesSent == 0)
            {
                HandleDisconnect(socket, SocketError.ConnectionReset);
                return;
            }

            _outPacket.index += bytesSent;
            if (_outPacket.index < _outPacket.size)
            {
                // Only the remaining part of the packet may be passed on, not its full size.
                socket.BeginSend(_outPacket.buffer, _outPacket.index, _outPacket.size - _outPacket.index,
                                 SocketFlags.None, OnSendComplete, socket);
                return;
            }

            lock (_sendQueue)
            {
                _outPacket.Clear();
            }
            Send();
        }
        catch (SocketException err)
        {
            HandleDisconnect(socket, err.SocketErrorCode);
        }
        catch (ObjectDisposedException)
        {
            HandleDisconnect(socket, SocketError.ConnectionReset);
        }
    }

    /// <summary>
    /// Completion callback for the receive loop: feeds the received bytes into the packet assembler
    /// and starts the next receive.
    /// </summary>
    /// <param name="ar">async result whose state is the socket the receive was started on.</param>
    private void OnReceiveComplete(IAsyncResult ar)
    {
        Socket socket = (Socket) ar.AsyncState;
        try
        {
            SocketError errorCode;
            int bytesRead = socket.EndReceive(ar, out errorCode);
            if (socket != _socket)
                return; // the channel was closed or reconnected while the receive was pending

            if (errorCode != SocketError.Success)
            {
                HandleDisconnect(socket, errorCode);
                return;
            }
            if (bytesRead == 0)
            {
                // Zero bytes means the remote end closed the connection.
                HandleDisconnect(socket, SocketError.ConnectionReset);
                return;
            }

            // A single read may hold several packets or just a fragment of one. Processing stops early
            // if a handler or a protocol error closed the channel.
            int index = 0;
            while (index < bytesRead && socket == _socket)
                index += ProcessInBuffer(_inbuffer, index, bytesRead - index);

            if (socket == _socket)
                socket.BeginReceive(_inbuffer, 0, BufferSize, SocketFlags.None, OnReceiveComplete, socket);
        }
        catch (ObjectDisposedException)
        {
            HandleDisconnect(socket, SocketError.ConnectionReset);
        }
        catch (SocketException err)
        {
            HandleDisconnect(socket, err.SocketErrorCode);
        }
    }

    /// <summary>
    /// Feeds received bytes into the packet currently being assembled.
    /// </summary>
    /// <remarks>
    /// Processing stops right after a packet has been completed (and handed to
    /// <see cref="OnBufferReceived"/>), so a buffer holding several packets needs one call per packet.
    /// The 4-byte length header may itself be split over several calls.
    /// </remarks>
    /// <param name="inbuffer">buffer to process</param>
    /// <param name="index">offset of the first unprocessed byte in <paramref name="inbuffer"/>.</param>
    /// <param name="count">number of unprocessed bytes available starting at <paramref name="index"/>.</param>
    /// <returns>number of bytes consumed; always greater than zero when <paramref name="count"/> is.</returns>
    private int ProcessInBuffer(byte[] inbuffer, int index, int count)
    {
        int position = index;
        int end = index + count;

        if (_inpacket.buffer == null)
        {
            // New packet: collect the length header first (it may arrive in pieces).
            while (_inHeaderCount < HeaderSize && position < end)
                _inHeader[_inHeaderCount++] = inbuffer[position++];
            if (_inHeaderCount < HeaderSize)
                return position - index;

            int size = BitConverter.ToInt32(_inHeader, 0);
            _inHeaderCount = 0;
            if (size < 0)
            {
                // Corrupt stream; the remaining bytes cannot be framed, so drop them and disconnect.
                Console.WriteLine("Invalid packet size: " + size);
                HandleDisconnect(_socket, SocketError.Success);
                return count;
            }

            _inpacket.buffer = new byte[size];
            _inpacket.size = size;
            _inpacket.index = 0;
        }

        int chunk = Math.Min(_inpacket.size - _inpacket.index, end - position);
        Buffer.BlockCopy(inbuffer, position, _inpacket.buffer, _inpacket.index, chunk);
        _inpacket.index += chunk;
        position += chunk;

        if (_inpacket.index == _inpacket.size)
        {
            // Reset before dispatching so the assembler is consistent even if a handler throws or closes us.
            byte[] completed = _inpacket.buffer;
            _inpacket.Clear();
            OnBufferReceived(completed);
        }

        return position - index;
    }

    /// <summary>
    /// Called when an object buffer has been received completely. Deserializes it and raises
    /// <see cref="ObjectReceived"/>; a <see cref="SerializationException"/> is written to the console.
    /// </summary>
    /// <param name="buffer">The buffer.</param>
    protected virtual void OnBufferReceived(byte[] buffer)
    {
        try
        {
            object obj;
            using (MemoryStream ms = new MemoryStream(buffer))
            {
                obj = _formatter.Deserialize(ms);
            }

            // ObjectReceived is a public field and may have been set to null by a caller.
            ObjectReceivedHandler handler = ObjectReceived;
            if (handler != null)
                handler(this, new ObjectReceivedEventArgs(obj));
        }
        catch (SerializationException err)
        {
            Console.WriteLine(err);
        }
    }

    [Fact]
    private void Test2InPackets()
    {
        byte[] packet1 = TestCreatePacket("hello");
        byte[] packet2 = TestCreatePacket("world");
        byte[] packet = TestConcat(packet1, packet2);
        int objectCount = 0;
        ObjectReceived += delegate { ++objectCount; };

        int index = ProcessInBuffer(packet, 0, packet.Length);
        Assert.Equal(packet1.Length, index);
        int consumed = ProcessInBuffer(packet, index, packet.Length - index);
        Assert.Equal(packet2.Length, consumed);
        Assert.Equal(2, objectCount);
    }

    /// <summary>
    /// Builds a wire packet (length header + serialized <paramref name="text"/>) for the tests.
    /// </summary>
    private static byte[] TestCreatePacket(string text)
    {
        byte[] data;
        using (MemoryStream ms = new MemoryStream())
        {
            new BinaryFormatter().Serialize(ms, text);
            data = ms.ToArray();
        }

        byte[] packet = new byte[data.Length + HeaderSize];
        BitConverter.GetBytes(data.Length).CopyTo(packet, 0);
        Buffer.BlockCopy(data, 0, packet, HeaderSize, data.Length);
        return packet;
    }

    /// <summary>
    /// Concatenates two byte arrays for the tests.
    /// </summary>
    private static byte[] TestConcat(byte[] first, byte[] second)
    {
        byte[] all = new byte[first.Length + second.Length];
        Buffer.BlockCopy(first, 0, all, 0, first.Length);
        Buffer.BlockCopy(second, 0, all, first.Length, second.Length);
        return all;
    }

    [Fact]
    private void TestPartialInpacket()
    {
        byte[] all = TestCreatePacket("hello");
        int objectCount = 0;
        ObjectReceived += delegate { ++objectCount; };

        int index = ProcessInBuffer(all, 0, all.Length - 3);
        Assert.Equal(all.Length - 3, index);
        index = ProcessInBuffer(all, all.Length - 3, 3);
        Assert.Equal(3, index);
        Assert.Equal(1, objectCount);
    }

    [Fact]
    private void TestSecondPartialInpacket()
    {
        BinaryChannel channel = new BinaryChannel();
        byte[] first = TestCreatePacket("hello");
        byte[] second = TestCreatePacket("world");
        byte[] all = TestConcat(first, second);
        int objectCount = 0;
        channel.ObjectReceived += delegate { ++objectCount; };

        // The first read holds packet one plus the header of packet two; processing stops after packet one.
        Assert.Equal(first.Length, channel.ProcessInBuffer(all, 0, first.Length + HeaderSize));
        // The unconsumed bytes are fed again, and the rest of packet two arrives in two more reads.
        Assert.Equal(second.Length - 5, channel.ProcessInBuffer(all, first.Length, second.Length - 5));
        Assert.Equal(5, channel.ProcessInBuffer(all, all.Length - 5, 5));
        Assert.Equal(2, objectCount);
    }

    [Fact]
    private void TestSplitHeader()
    {
        byte[] all = TestCreatePacket("hello");
        int objectCount = 0;
        ObjectReceived += delegate { ++objectCount; };

        Assert.Equal(2, ProcessInBuffer(all, 0, 2));
        Assert.Equal(all.Length - 2, ProcessInBuffer(all, 2, all.Length - 2));
        Assert.Equal(1, objectCount);
    }

    [Fact]
    private void TestPacketsOfDifferentSize()
    {
        byte[] all = TestConcat(TestCreatePacket("hi"), TestCreatePacket("a considerably longer message"));
        List<object> received = new List<object>();
        ObjectReceived += delegate(object source, ObjectReceivedEventArgs args) { received.Add(args.Object); };

        int index = 0;
        while (index < all.Length)
            index += ProcessInBuffer(all, index, all.Length - index);

        Assert.Equal(2, received.Count);
        Assert.Equal("hi", (string) received[0]);
        Assert.Equal("a considerably longer message", (string) received[1]);
    }

    /// <summary>
    /// Tears down <paramref name="failedSocket"/> after an error or a remote close, schedules reconnection
    /// when <see cref="ShouldReconnect"/> is set, and raises <see cref="Disconnected"/>.
    /// </summary>
    /// <remarks>
    /// Does nothing unless <paramref name="failedSocket"/> is still the current socket, so late callbacks
    /// from a socket released by <see cref="Close"/> (or already handled by a concurrent callback) are ignored.
    /// </remarks>
    /// <param name="failedSocket">socket that failed; may be <c>null</c>.</param>
    /// <param name="code">error reported to <see cref="Disconnected"/> subscribers.</param>
    private void HandleDisconnect(Socket failedSocket, SocketError code)
    {
        if (failedSocket == null || Interlocked.CompareExchange(ref _socket, null, failedSocket) != failedSocket)
            return;

        failedSocket.Close();

        // A partially transmitted packet cannot be resumed on a new connection; drop it so the
        // send loop is not blocked forever.
        lock (_sendQueue)
        {
            _outPacket.Clear();
        }

        if (_shouldReconnect)
            StartReconnectTimer();

        DisconnectedHandler handler = Disconnected;
        if (handler != null)
            handler(this, new DisconnectedEventArgs(code));
    }

    /// <summary>
    /// Creates the periodic reconnect timer unless it is already running.
    /// </summary>
    private void StartReconnectTimer()
    {
        lock (_reconnectLock)
        {
            if (_reconnectTimer == null)
                _reconnectTimer = new Timer(TryConnect, null, ReconnectInterval, ReconnectInterval);
        }
    }

    /// <summary>
    /// Disposes the reconnect timer, if any.
    /// </summary>
    private void StopReconnectTimer()
    {
        lock (_reconnectLock)
        {
            if (_reconnectTimer != null)
            {
                _reconnectTimer.Dispose();
                _reconnectTimer = null;
            }
        }
    }

    /// <summary>
    /// Reconnect-timer callback: tries to re-establish the connection to the last remote end point.
    /// On success the timer is stopped, receiving restarts and queued packets are flushed; on failure
    /// the attempt is repeated on the next timer tick.
    /// </summary>
    private void TryConnect(object state)
    {
        // A slow Connect can outlast the timer period; never run two attempts at once.
        if (!Monitor.TryEnter(_connectLock))
            return;

        try
        {
            if (!_shouldReconnect || _socket != null)
            {
                StopReconnectTimer();
                return;
            }

            Socket socket = new Socket(_remoteEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            try
            {
                socket.Connect(_remoteEndPoint);
            }
            catch (SocketException)
            {
                socket.Close();
                return;
            }

            if (!_shouldReconnect)
            {
                // Close() was called (or ShouldReconnect cleared) while we were connecting.
                socket.Close();
                return;
            }

            StopReconnectTimer();
            _socket = socket;
            try
            {
                StartReceiving(socket);
            }
            catch (SocketException err)
            {
                HandleDisconnect(socket, err.SocketErrorCode);
                return;
            }

            Send();
        }
        finally
        {
            Monitor.Exit(_connectLock);
        }
    }

    /// <summary>
    /// Resets the incoming packet state and starts the receive loop on <paramref name="socket"/>.
    /// </summary>
    private void StartReceiving(Socket socket)
    {
        _inpacket.Clear();
        _inHeaderCount = 0;
        socket.BeginReceive(_inbuffer, 0, BufferSize, SocketFlags.None, OnReceiveComplete, socket);
    }

    /// <summary>
    /// Closes the connection and stops any pending reconnection attempts. <see cref="Disconnected"/> is
    /// not raised. Safe to call more than once. Queued packets are kept and sent after a later <see cref="Open"/>.
    /// </summary>
    public void Close()
    {
        _shouldReconnect = false;
        StopReconnectTimer();

        Socket socket = Interlocked.Exchange(ref _socket, null);
        if (socket != null)
            socket.Close();

        lock (_sendQueue)
        {
            _outPacket.Clear();
        }
    }

    /// <summary>
    /// Connect to an endpoint, start receiving and send any packets queued while unconnected.
    /// </summary>
    /// <param name="endPoint">Where to connect</param>
    /// <exception cref="ArgumentNullException"><paramref name="endPoint"/> is <c>null</c>.</exception>
    /// <exception cref="SocketException">if connection fails.</exception>
    public void Open(IPEndPoint endPoint)
    {
        if (endPoint == null)
            throw new ArgumentNullException("endPoint");

        Socket socket = _socket;
        bool created = socket == null;
        if (created)
            socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

        _remoteEndPoint = endPoint;
        try
        {
            socket.Connect(endPoint);
        }
        catch (SocketException)
        {
            // Do not keep a socket whose connect failed; the next Open call starts afresh.
            if (created)
                socket.Close();
            throw;
        }

        _socket = socket;
        StartReceiving(socket);
        Send();
    }

    #region Implementation of IDisposable

    /// <summary>
    /// Closes the channel (see <see cref="Close"/>), discards queued packets and detaches all
    /// <see cref="Disconnected"/> subscribers.
    /// </summary>
    /// <filterpriority>2</filterpriority>
    public void Dispose()
    {
        Close();
        lock (_sendQueue)
        {
            _sendQueue.Clear();
        }
        Disconnected = null;
    }

    #endregion

    #region Nested type: Packet

    /// <summary>
    /// A packet being sent or assembled: its payload, the number of bytes transferred so far and its total size.
    /// </summary>
    private class Packet
    {
        public byte[] buffer;
        public int index;
        public int size;

        /// <summary>
        /// Resets the packet to the empty state (no buffer).
        /// </summary>
        public void Clear()
        {
            size = 0;
            index = 0;
            buffer = null;
        }
    }

    #endregion
}
/// <summary>
/// Arguments passed to an <see cref="ObjectReceivedHandler"/>: the object that arrived from the remote end.
/// </summary>
public class ObjectReceivedEventArgs : EventArgs
{
    private readonly object _received;

    /// <summary>
    /// Creates the arguments for a newly received object.
    /// </summary>
    /// <param name="value">
    /// The object received from the remote end. It is validated with <c>Check.Require</c>
    /// (presumably rejecting <c>null</c> — confirm against the <c>Check</c> helper).
    /// </param>
    public ObjectReceivedEventArgs(object value)
    {
        Check.Require(value, "value");
        _received = value;
    }

    /// <summary>
    /// Gets the object received from the remote end.
    /// </summary>
    public object Object
    {
        get
        {
            return _received;
        }
    }
}
/// <summary>
/// Invoked when an object has been received from the remote end.
/// </summary>
/// <param name="source">The channel that received the object (<see cref="BinaryChannel"/> passes itself).</param>
/// <param name="args">Carries the received object.</param>
public delegate void ObjectReceivedHandler(object source, ObjectReceivedEventArgs args);
}