Click here to Skip to main content
15,897,334 members
Articles / Web Development / HTML

A C++ Websocket server for realtime interaction with Web clients

Rate me:
Please Sign up or sign in to vote.
4.81/5 (23 votes)
17 May 2012Apache5 min read 179.5K   22.3K   105  
A WebSocket protocol implementation atop the Push Framework real-time library, plus a demo example featuring four types of communication workflows between the HTML5 web client and the server.
// WebsocketProtocol.cpp : Defines the exported functions for the DLL application.
//

#include "stdafx.h"
#include "WebsocketProtocol.h"
#include "base64/base64.h"
#include "sha1/sha1.h"


#include "WebsocketHandshakeMessage.h"
#include "WebsocketDataMessage.h"
#include "WebsocketConnectionContext.h"




// Default constructor. It performs no work of its own.
// The class itself is declared in WebsocketProtocol.h.
WebsocketProtocol::WebsocketProtocol()
{
}

// Tries to extract one message from the bytes accumulated in `buffer`.
//
// While the connection context is in HandshakeStage, every byte currently in
// the buffer is handed to a new WebsocketHandshakeMessage (serviceId = 0).
// Otherwise the bytes are parsed as a single masked WebSocket frame whose
// first byte is 0x81 (129); the unmasked payload is wrapped in a new
// WebsocketDataMessage (serviceId = 1).
//
// Parameters:
//   buffer          - received bytes; only read here, never modified.
//   pPacket         - out: the newly allocated message on Success.
//   serviceId       - out: 0 for a handshake message, 1 for a data message.
//   nExtractedBytes - out: how many bytes of `buffer` the message occupied.
//   pContext        - connection context; cast to WebsocketConnectionContext.
//
// Returns:
//   Protocol::Success           - a message was produced.
//   Protocol::eIncompletePacket - more bytes are needed for a whole frame.
//   Protocol::eUndefinedFailure - first byte is not 129, the MASK bit is
//                                 clear, or the declared length cannot be
//                                 represented in nExtractedBytes.
int WebsocketProtocol::tryDeframeIncomingPacket( PushFramework::DataBuffer& buffer, PushFramework::IncomingPacket*& pPacket, int& serviceId, unsigned int& nExtractedBytes, ConnectionContext* pContext )
{
    if (buffer.GetDataSize() == 0)
        return Protocol::eIncompletePacket;

    WebsocketConnectionContext* pCxt = (WebsocketConnectionContext*) pContext;

    if (pCxt->GetStage() == WebsocketConnectionContext::HandshakeStage)
    {
        // NOTE(review): the whole buffer is consumed without checking that the
        // HTTP request is complete (terminated by an empty line) - confirm that
        // WebsocketHandshakeMessage copes with a partially received request.
        WebsocketHandshakeMessage* pMessage = new WebsocketHandshakeMessage(buffer.GetBuffer(), buffer.GetDataSize());
        serviceId = 0;
        nExtractedBytes = buffer.GetDataSize();
        pPacket = pMessage;
        return Protocol::Success;
    }

    // In the other cases, we should expect a data message.
    const unsigned __int64 availableSize = buffer.GetDataSize();

    // Smallest possible frame: 2 header bytes followed by the 4 masking bytes.
    unsigned __int64 requiredSize = 6;
    if (availableSize < requiredSize)
        return Protocol::eIncompletePacket;

    // Only the byte value 129 (0x81) is accepted as the first byte.
    const BYTE payloadFlags = (BYTE) buffer.getAt(0);
    if (payloadFlags != 129)
        return Protocol::eUndefinedFailure;

    // The masking key is read unconditionally below, so a frame whose MASK bit
    // is clear must be rejected rather than "unmasked" with payload bytes.
    const BYTE lengthByte = (BYTE) buffer.getAt(1);
    if ((lengthByte & 0x80) == 0)
        return Protocol::eUndefinedFailure;

    // The low 7 bits are either the length itself (0..125) or a marker for an
    // extended big-endian length field (126 -> 2 bytes, 127 -> 8 bytes).
    const BYTE basicSize = lengthByte & 0x7F;
    unsigned __int64 payloadSize = 0;
    int masksOffset = 2;

    if (basicSize == 126)
    {
        requiredSize += 2;
        if (availableSize < requiredSize)
            return Protocol::eIncompletePacket;

        // Assemble byte by byte: no unaligned pointer cast into the buffer.
        payloadSize = ((unsigned __int64) (BYTE) buffer.getAt(2) << 8)
                    |  (unsigned __int64) (BYTE) buffer.getAt(3);
        masksOffset = 4;
    }
    else if (basicSize == 127)
    {
        requiredSize += 8;
        if (availableSize < requiredSize)
            return Protocol::eIncompletePacket;

        // All 8 bytes form the length; reading just the first 4 would yield
        // only the high half of the value.
        for (int i = 0; i < 8; i++)
            payloadSize = (payloadSize << 8) | (BYTE) buffer.getAt(2 + i);
        masksOffset = 10;
    }
    else
    {
        payloadSize = basicSize;
    }

    // nExtractedBytes is an unsigned int, so the complete frame must fit into
    // one. Checking before the addition also rules out arithmetic overflow.
    const unsigned __int64 maxFrameSize = 0xFFFFFFFFu;
    if (payloadSize > maxFrameSize - requiredSize)
        return Protocol::eUndefinedFailure;

    requiredSize += payloadSize;
    if (availableSize < requiredSize)
        return Protocol::eIncompletePacket;

    BYTE masks[4];
    memcpy(masks, buffer.GetBuffer() + masksOffset, 4);

    // One extra element keeps the payload NUL-terminated. The std::string owns
    // the storage, so it is released on every exit path (the original paired
    // new[] with a scalar delete).
    std::string payload((size_t) payloadSize + 1, '\0');
    memcpy(&payload[0], buffer.GetBuffer() + masksOffset + 4, (size_t) payloadSize);
    for (size_t i = 0; i < (size_t) payloadSize; i++)
        payload[i] = (char) (payload[i] ^ masks[i % 4]);

    WebsocketDataMessage* pMessage = new WebsocketDataMessage(&payload[0]);
    serviceId = 1;
    nExtractedBytes = (unsigned int) requiredSize;
    pPacket = pMessage;

    return Protocol::Success;
}

// Serializes an outgoing message into `buffer`.
//
// A handshake message is appended exactly as returned by Serialize(). A data
// message is written as one unmasked frame: the byte 0x81 (129), a length
// field, then the encoded data. The length field is
//   - 1 byte  holding the size itself           when size <= 125,
//   - 126 followed by a 2-byte big-endian size  when size <= 65535,
//   - 127 followed by an 8-byte big-endian size otherwise.
//
// Parameters:
//   packet        - message to write; cast to WebsocketMessage.
//   buffer        - destination; nothing is appended unless the whole
//                   message fits into its remaining size.
//   nWrittenBytes - out: number of bytes appended on Success.
//
// Returns:
//   Protocol::Success             - the message was appended.
//   Protocol::eInsufficientBuffer - the remaining buffer space is too small.
//   Protocol::eUndefinedFailure   - the message type is neither Handshake
//                                   nor DataMessage.
int WebsocketProtocol::frameOutgoingPacket( PushFramework::OutgoingPacket& packet, PushFramework::DataBuffer& buffer, unsigned int& nWrittenBytes )
{
    WebsocketMessage& message = (WebsocketMessage&) packet;

    if (message.GetType() == WebsocketMessage::Handshake)
    {
        WebsocketHandshakeMessage& handshake = (WebsocketHandshakeMessage&) message;

        std::string strRaw = handshake.Serialize();
        if (strRaw.size() > buffer.getRemainingSize())
            return Protocol::eInsufficientBuffer;

        buffer.Append((char*) strRaw.c_str(), strRaw.size());
        nWrittenBytes = strRaw.size();
        return Protocol::Success;
    }

    if (message.GetType() == WebsocketMessage::DataMessage)
    {
        WebsocketDataMessage& dataMessage = (WebsocketDataMessage&) message;
        std::string strData = dataMessage.GetEncodedData();

        const unsigned __int64 payloadSize = strData.size();

        // Build the header once. Its length then drives both the capacity
        // check and the write, so the two cannot disagree.
        char header[10];
        unsigned int headerSize = 0;

        // Flags byte.
        header[headerSize++] = (char) 129;

        // Length field.
        if (payloadSize <= 125)
        {
            header[headerSize++] = (char) payloadSize;
        }
        else if (payloadSize <= 65535)
        {
            header[headerSize++] = (char) 126;
            header[headerSize++] = (char) ((payloadSize >> 8) & 255);
            header[headerSize++] = (char) (payloadSize & 255);
        }
        else
        {
            header[headerSize++] = (char) 127;
            // Most significant byte first.
            for (int shift = 56; shift >= 0; shift -= 8)
                header[headerSize++] = (char) ((payloadSize >> shift) & 255);
        }

        // Kept in 64 bits so that a very large payload cannot wrap the total.
        const unsigned __int64 frameSize = payloadSize + headerSize;
        if (frameSize > buffer.getRemainingSize())
            return Protocol::eInsufficientBuffer;

        buffer.Append(header, headerSize);
        buffer.Append((char*) strData.c_str(), strData.size());
        nWrittenBytes = (unsigned int) frameSize;
        return Protocol::Success;
    }

    return Protocol::eUndefinedFailure;
}

// Fills in the server's reply to a client opening handshake.
//
// The accept token is base64( SHA-1( <Sec-WebSocket-Key> + fixed GUID ) ).
// On success the fields "Upgrade", "Connection" and "Sec-WebSocket-Accept"
// are set on `response`.
//
// Parameters:
//   request  - parsed client handshake; only "Sec-WebSocket-Key" is read.
//   response - handshake message receiving the three reply fields.
//
// Returns true when the reply was filled in, false when the request carries
// no Sec-WebSocket-Key; in that case `response` is left untouched.
bool WebsocketProtocol::ProcessHandshake( WebsocketHandshakeMessage& request, WebsocketHandshakeMessage& response )
{
    std::string acceptKey = request.GetField("Sec-WebSocket-Key");

    // NOTE(review): assumes GetField yields an empty string for a missing
    // field - confirm against WebsocketHandshakeMessage. Without a key there
    // is nothing meaningful to hash, so the handshake is refused.
    if (acceptKey.empty())
        return false;

    // Fixed GUID that is appended to the client's key before hashing.
    acceptKey += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    SHA1 sha;
    unsigned int message_digest[5];

    sha.Reset();
    sha << acceptKey.c_str();
    sha.Result(message_digest);

    // convert sha1 hash bytes to network byte order because this sha1
    // library works on ints rather than bytes
    for (int i = 0; i < 5; i++)
        message_digest[i] = htonl(message_digest[i]);

    // 5 words x 4 bytes = the 20-byte digest.
    acceptKey = base64_encode(
        reinterpret_cast<const unsigned char*>(message_digest), 20);

    response.SetField("Upgrade", "websocket");
    response.SetField("Connection", "Upgrade");
    response.SetField("Sec-WebSocket-Accept", acceptKey);

    return true;
}

// Decodes a deframed incoming message.
//
// Only data messages are decoded; any other message type is reported as
// Success without being touched. For a data message, `serviceId` is set from
// the message once Decode() succeeds.
//
// Returns Protocol::Success, or Protocol::eDecodingFailure when Decode()
// reports failure (serviceId is then left unchanged).
int WebsocketProtocol::decodeIncomingPacket( PushFramework::IncomingPacket* pPacket, int& serviceId )
{
    WebsocketMessage* pIncoming = (WebsocketMessage*) pPacket;

    // Nothing to decode unless this is a data message.
    if (pIncoming->GetType() != WebsocketMessage::DataMessage)
        return Protocol::Success;

    WebsocketDataMessage* pData = (WebsocketDataMessage*) pIncoming;
    if (pData->Decode())
    {
        serviceId = pData->GetServiceId();
        return Protocol::Success;
    }

    return Protocol::eDecodingFailure;
}

// Encodes an outgoing message before it is framed.
//
// Only data messages are encoded; any other message type is reported as
// Success without being touched.
//
// Returns Protocol::Success, or Protocol::eEncodingFailure when Encode()
// reports failure.
int WebsocketProtocol::encodeOutgoingPacket( PushFramework::OutgoingPacket& packet )
{
    WebsocketMessage& outgoing = (WebsocketMessage&) packet;

    // Nothing to encode unless this is a data message.
    if (outgoing.GetType() != WebsocketMessage::DataMessage)
        return Protocol::Success;

    WebsocketDataMessage& data = (WebsocketDataMessage&) outgoing;
    if (data.Encode())
        return Protocol::Success;

    return Protocol::eEncodingFailure;
}




By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.

If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.

License

This article, along with any associated source code and files, is licensed under The Apache License, Version 2.0


Written By
Technical Lead
Tunisia Tunisia
Services:
http://www.pushframework.com/?page_id=890

Comments and Discussions