Click here to Skip to main content
15,885,365 members
Articles / Programming Languages / C++

RCF - Interprocess Communication for C++

Rate me:
Please Sign up or sign in to vote.
4.94/5 (147 votes)
25 Oct 2011CPOL20 min read 4.6M   8.4K   331  
A server/client IPC framework, using the C++ preprocessor as an IDL compiler.
//*****************************************************************************
// RCF - Remote Call Framework
// Copyright (c) 2005. All rights reserved.
// Developed by Jarl Lindrud.
// Contact: jlindrud@hotmail.com .
//*****************************************************************************

#include <RCF/TcpIocpServerTransport.hpp>

#include <RCF/RcfServer.hpp>
#include <RCF/TcpClientTransport.hpp>
#include <RCF/TcpEndpoint.hpp>
#include <RCF/Tools.hpp>
#include <RCF/UsingBsdSockets.hpp>

namespace RCF {

    // Constructs the wrapper with no completion port handle. If
    // nMaxConcurrency is anything other than -1, Create() is invoked
    // straight away with that value.
    Iocp::Iocp(int nMaxConcurrency)
    {
        m_hIOCP = NULL;
        if (nMaxConcurrency == -1)
        {
            // -1 means "defer creation": the caller will call Create() later
            return;
        }
        Create(nMaxConcurrency);
    }

    // Closes the completion port handle, if one was ever created.
    Iocp::~Iocp()
    {
        if (m_hIOCP == NULL)
        {
            // nothing was created, nothing to release
            return;
        }
        RCF_VERIFY(CloseHandle(m_hIOCP), "CloseHandle")(m_hIOCP);
    }

    // Calls CreateIoCompletionPort() with INVALID_HANDLE_VALUE and no existing
    // port, and stores the returned handle in m_hIOCP.
    // Returns non-zero if the stored handle is non-NULL.
    BOOL Iocp::Create(int nMaxConcurrency)
    {
        HANDLE hNewPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, nMaxConcurrency);
        m_hIOCP = hNewPort;
        RCF_VERIFY(m_hIOCP != NULL, "CreateIoCompletionPort()");
        if (m_hIOCP != NULL)
        {
            return TRUE;
        }
        return FALSE;
    }

    // Associates hDevice with this completion port under the key CompKey.
    // Succeeds when CreateIoCompletionPort() hands back our own port handle.
    BOOL Iocp::AssociateDevice(HANDLE hDevice, ULONG_PTR CompKey)
    {
        HANDLE hReturnedPort = CreateIoCompletionPort(hDevice, m_hIOCP, CompKey, 0);
        BOOL fOk = (hReturnedPort == m_hIOCP);
        RCF_VERIFY(fOk, "CreateIoCompletionPort()" );
        return fOk;
    }

    // Convenience overload of AssociateDevice() for sockets: the socket is
    // reinterpreted as a HANDLE and forwarded unchanged.
    BOOL Iocp::AssociateSocket(SOCKET hSocket, ULONG_PTR CompKey)
    {
        HANDLE hDevice = (HANDLE) hSocket;
        return AssociateDevice(hDevice, CompKey);
    }

    // Queues a completion packet (key, byte count, OVERLAPPED pointer) on the
    // port via PostQueuedCompletionStatus(). Returns that call's result.
    BOOL Iocp::PostStatus(ULONG_PTR CompKey, DWORD dwNumBytes, OVERLAPPED* po)
    {
        BOOL fOk = PostQueuedCompletionStatus(
            m_hIOCP,
            dwNumBytes,
            CompKey,
            po);
        RCF_ASSERT(fOk);
        return fOk;
    }

    // Dequeues one completion packet from the port, waiting at most
    // dwMilliseconds. The out-parameters and the return value are exactly
    // those of GetQueuedCompletionStatus().
    BOOL Iocp::GetStatus(ULONG_PTR* pCompKey, PDWORD pdwNumBytes, OVERLAPPED** ppo, DWORD dwMilliseconds)
    {
        BOOL fDequeued = GetQueuedCompletionStatus(
            m_hIOCP,
            pdwNumBytes,
            pCompKey,
            ppo,
            dwMilliseconds);
        return fDequeued;
    }

    // Binds the functor to a session's overlapped structure, its socket
    // descriptor and zombie flag (both held by reference), and the mutex
    // that guards reads of the zombie flag. The error code starts out
    // value-initialized.
    WsaRecvFunctor::WsaRecvFunctor(
        WSAOVERLAPPED *pOverlapped,
        const int &fd,
        const bool &zombie,
        ReadWriteMutex &rwm) :
            fd(fd),
            zombie(zombie),
            rwm(rwm),
            mpOverlapped(pOverlapped),
            mError()
    {
    }

    // Posts an overlapped WSARecv() of up to bufferLen bytes into buffer,
    // using the OVERLAPPED supplied at construction.
    //
    // Nothing is posted if the session has been flagged as a zombie; in that
    // case the stored error code remains 0.
    //
    // The outcome is available through getError(): 0 if the receive completed
    // immediately or is pending (WSA_IO_PENDING), otherwise the Winsock error.
    void WsaRecvFunctor::operator()(char *buffer, std::size_t bufferLen)
    {
        RCF_ASSERT(mpOverlapped);
        WSABUF wsabuf = { static_cast<u_long>(bufferLen), buffer};
        DWORD dwReceived = 0;
        DWORD dwFlags = 0;
        int ret = -1;
        mError = 0;
        {
            ReadLock lock(rwm); // since we're reading zombie
            if (!zombie)
            {
                RCF_ASSERT(fd >0);
                ret = WSARecv(fd, &wsabuf, 1, &dwReceived, &dwFlags, mpOverlapped, NULL);

                // The last-error value is only meaningful when the call reports
                // failure; after a successful call it may still hold a stale
                // code from an earlier operation on this thread.
                mError = (ret == -1) ? WSAGetLastError() : 0;
            }
            else
            {
                RCF_TRACE("fd closed by external thread")(fd);
            }
        }
        if (mError == WSA_IO_PENDING)
        {
            // the operation was queued successfully - not an error
            mError = 0;
        }
        else if (mError != 0)
        {
            RCF_TRACE("")(mError);
        }
    }

    // Wraps operator() in a Filter::ReadFunction bound to this object.
    // The returned callable refers to *this through a raw pointer.
    Filter::ReadFunction WsaRecvFunctor::getReadFunction()
    {
        Filter::ReadFunction readFunction =
            boost::bind( &WsaRecvFunctor::operator(), this, _1, _2);
        return readFunction;
    }

    int WsaRecvFunctor::getError()
    {
        return mError;
    }

    // Binds the functor to a session's overlapped structure, its socket
    // descriptor and zombie flag (both held by reference), and the mutex
    // that guards reads of the zombie flag. The error code starts out
    // value-initialized.
    WsaSendFunctor::WsaSendFunctor(
        WSAOVERLAPPED *pOverlapped,
        const int &fd,
        const bool &zombie,
        ReadWriteMutex &rwm) :
            fd(fd),
            zombie(zombie),
            rwm(rwm),
            mError(),
            mpOverlapped(pOverlapped)
    {
    }

    // Posts an overlapped WSASend() of bufferLen bytes from buffer, using the
    // OVERLAPPED supplied at construction.
    //
    // Nothing is posted if the session has been flagged as a zombie; in that
    // case the stored error code remains 0.
    //
    // The outcome is available through getError(): 0 if the send completed
    // immediately or is pending (WSA_IO_PENDING), otherwise the Winsock error.
    void WsaSendFunctor::operator()(const char *buffer, std::size_t bufferLen)
    {
        RCF_ASSERT(mpOverlapped);
        WSABUF wsabuf = { static_cast<u_long>(bufferLen), (char *) buffer };
        DWORD dwSent = 0;
        DWORD dwFlags = 0;
        int ret = -1;
        mError = 0;
        {
            ReadLock lock(rwm); // since we're reading zombie
            if (!zombie)
            {
                RCF_ASSERT(fd >0);
                ret = WSASend(fd, &wsabuf, 1, &dwSent, dwFlags, mpOverlapped, NULL);

                // The last-error value is only meaningful when the call reports
                // failure; after a successful call it may still hold a stale
                // code from an earlier operation on this thread.
                mError = (ret == -1) ? WSAGetLastError() : 0;
            }
            else
            {
                RCF_TRACE("fd closed by external thread")(fd);
            }
        }
        if (mError == WSA_IO_PENDING)
        {
            // the operation was queued successfully - not an error
            mError = 0;
        }
        else if (mError != 0)
        {
            RCF_TRACE("")(mError);
        }
    }

    // Wraps operator() in a Filter::WriteFunction bound to this object.
    // The returned callable refers to *this through a raw pointer.
    Filter::WriteFunction WsaSendFunctor::getWriteFunction()
    {
        Filter::WriteFunction writeFunction =
            boost::bind( &WsaSendFunctor::operator(), this, _1, _2);
        return writeFunction;
    }

    int WsaSendFunctor::getError()
    {
        return mError;
    }


    // Synchronization -
    // SessionState::fd is immutable, and is only physically closed (closesocket()) from the SessionState destructor, which is thread-safe
    // since we only refer to SessionState objects through shared_ptrs. Any thread wishing to close a session must set the
    // SessionState::zombie flag. Eventually the SessionState will be ejected from the server by a server thread, and the destructor will close the connection.
    // Ownership of the fd can be taken from the SessionState by setting ownFd to false. The server-wide read_write_mutex rwm is used to
    // synchronize reads and writes of ownFd and zombie for all SessionState objects.

    // SessionStates whose zombie flag is set cannot be summarily removed from the server's session map, since that might delete the SessionState object
    // while the iocp is polling it, => core dump. They can only be safely removed when it is known that the iocp is not polling them.

    // Since we only handle SessionState's through shared_ptr's, we don't need to synchronize access to ownFd and zombie.
    // Closes the socket if this object still owns it, it has not been marked
    // as a zombie, and the descriptor is valid; then marks the session as a
    // zombie.
    TcpIocpServerTransport::SessionState::~SessionState()
    {
        RCF_TRACE("")(sessionPtr.get())(ownFd)(zombie)(fd);

        const bool shouldCloseSocket = ownFd && !zombie && fd != -1;
        if (shouldCloseSocket)
        {
            RCF_VERIFY_SOCKETS(0 == Platform::OS::BsdSockets::closesocket(fd), "closesocket() failure")(fd);
        }

        zombie = true;
    }

    // Replaces this session's transport filter chain with a copy of filters
    // and wires the chain up: the socket recv/send functors sit at one end,
    // and the transport's onReadWriteCompleted() (bound to this session)
    // at the other.
    void TcpIocpServerTransport::SessionState::setTransportFilters(const std::vector<FilterPtr> &filters)
    {
        mTransportFilters.assign(filters.begin(), filters.end());

        RCF::connectFilters(
            mTransportFilters,
            wsaRecvFunctor.getReadFunction(),
            wsaSendFunctor.getWriteFunction(),
            boost::bind(
                &TcpIocpServerTransport::onReadWriteCompleted,
                &transport,
                getWeakThisPtr().lock(),
                _1,
                _2));
    }

    // Starts a read of up to bufferLen bytes into buffer. With no transport
    // filters installed the socket functor is invoked directly; otherwise the
    // request enters the chain at its first filter.
    // Returns -1 if the socket functor has an error recorded, 0 otherwise.
    int TcpIocpServerTransport::SessionState::read(char *buffer, std::size_t bufferLen)
    {
        if (mTransportFilters.empty())
        {
            wsaRecvFunctor.getReadFunction()(buffer, bufferLen);
        }
        else
        {
            mTransportFilters.front()->read(buffer, bufferLen);
        }

        if (wsaRecvFunctor.getError())
        {
            return -1;
        }
        return 0;
    }

    // Starts a write of bufferLen bytes from buffer. With no transport
    // filters installed the socket functor is invoked directly; otherwise the
    // request enters the chain at its first filter.
    // Returns -1 if the socket functor has an error recorded, 0 otherwise.
    int TcpIocpServerTransport::SessionState::write(char *buffer, std::size_t bufferLen)
    {
        if (mTransportFilters.empty())
        {
            wsaSendFunctor.getWriteFunction()(buffer, bufferLen);
        }
        else
        {
            mTransportFilters.front()->write(buffer, bufferLen);
        }

        if (wsaSendFunctor.getError())
        {
            return -1;
        }
        return 0;
    }

    void TcpIocpServerTransport::SessionState::onReadWriteCompleted(std::size_t bytesTransferred, int error)
    {
        mTransportFilters.empty() ?
            transport.onReadWriteCompleted(this->getWeakThisPtr().lock(), bytesTransferred, error):
            mTransportFilters.back()->onReadWriteCompleted(bytesTransferred, error);
    }

    void TcpIocpServerTransport::SessionState::clearOverlapped()
    {
        memset(static_cast<OVERLAPPED *>(this), 0, sizeof(OVERLAPPED));
    }

    bool TcpIocpServerTransport::SessionState::isReflecting()
    {
        return reflectionFd != 0;
    }

    // Returns the stored remote address of this session.
    const I_RemoteAddress &TcpIocpServerTransport::SessionState::getRemoteAddress()
    {
        return this->remoteAddress;
    }

    // Returns the socket descriptor associated with this session.
    TcpIocpServerTransport::Fd TcpIocpServerTransport::SessionState::getFd() const
    {
        return this->fd;
    }

    // Ties the proactor to its owning transport and to the session state it
    // acts on. The member sessionStatePtr is lock()ed on every use, so the
    // session state is not kept alive by this object.
    TcpIocpServerTransport::TcpIocpProactor::TcpIocpProactor(
        TcpIocpServerTransport &transport,
        boost::shared_ptr<SessionState> sessionStatePtr) :
            transport(transport),
            sessionStatePtr(sessionStatePtr)
    {
    }

    void TcpIocpServerTransport::TcpIocpProactor::postRead()
    {
        transport.postRead(sessionStatePtr.lock());
    }

    void TcpIocpServerTransport::TcpIocpProactor::postWrite()
    {
        transport.postWrite(sessionStatePtr.lock());
    }

    // Intentionally does nothing in this transport.
    void TcpIocpServerTransport::TcpIocpProactor::postClose()
    {
        // no-op
    }

    std::vector<char> &TcpIocpServerTransport::TcpIocpProactor::getWriteBuffer()
    {
        return sessionStatePtr.lock()->getWriteBuffer();
    }

    // Offset into the write buffer reported to callers; always 4.
    // NOTE(review): presumably leaves room for the 4-byte length field that
    // the read path parses - confirm against postWrite().
    std::size_t TcpIocpServerTransport::TcpIocpProactor::getWriteOffset()
    {
        const std::size_t writeOffset = 4;
        return writeOffset;
    }

    std::vector<char> &TcpIocpServerTransport::TcpIocpProactor::getReadBuffer()
    {
        return sessionStatePtr.lock()->getReadBuffer();
    }

    // Offset into the read buffer reported to callers; always 0.
    std::size_t TcpIocpServerTransport::TcpIocpProactor::getReadOffset()
    {
        const std::size_t readOffset = 0;
        return readOffset;
    }

    // Exposes the owning transport through its I_ServerTransport interface.
    I_ServerTransport &TcpIocpServerTransport::TcpIocpProactor::getServerTransport()
    {
        return this->transport;
    }

    // Returns a reference to the session state obtained from
    // getSessionStatePtr().
    TcpIocpServerTransport::SessionState &TcpIocpServerTransport::TcpIocpProactor::getSessionState()
    {
        SessionStatePtr lockedSession = getSessionStatePtr();
        return *lockedSession;
    }

    // Returns the result of lock()ing the stored session state pointer.
    TcpIocpServerTransport::SessionStatePtr TcpIocpServerTransport::TcpIocpProactor::getSessionStatePtr()
    {
        SessionStatePtr lockedSession = sessionStatePtr.lock();
        return lockedSession;
    }

    // Returns the remote address recorded on this proactor's session.
    const I_RemoteAddress &TcpIocpServerTransport::TcpIocpProactor::getRemoteAddress()
    {
        SessionStatePtr lockedSession = sessionStatePtr.lock();
        return lockedSession->getRemoteAddress();
    }

    void TcpIocpServerTransport::TcpIocpProactor::setTransportFilters(const std::vector<FilterPtr> &filters)
    {
        return sessionStatePtr.lock()->setTransportFilters(filters);
    }

    // Creates a transport that will listen on the given port once open() is
    // called. No sockets or completion ports are created here.
    TcpIocpServerTransport::TcpIocpServerTransport(int port) :
        rwm(ReaderPriority),
        pSessionManager(),

        // tunables
        maxPendingConnectionCount(100),     // passed to listen() by open()
        fdPartitionCount(10),               // number of session maps, see hash()
        port(port),

        // lifecycle state
        mStopFlag(),
        mOpen(),
        acceptorFd(-1),                     // -1 == no listener socket
        iocp(),

        // accept queue bookkeeping
        mQueuedAccepts(0),
        mQueuedAcceptsThreshold(10),        // refill when the count drops below this
        mQueuedAcceptsAugment(10),          // accepts posted per refill

        // extension function pointers, loaded in open()
        mlpfnAcceptEx(),
        mlpfnGetAcceptExSockAddrs()
    {
    }

    void TcpIocpServerTransport::setPort(int port)
    {
        this->port = port;
    }

    int TcpIocpServerTransport::getPort()
    {
        return port;
    }

    void TcpIocpServerTransport::setMaxPendingConnectionCount(unsigned int maxPendingConnectionCount)
    {
        this->maxPendingConnectionCount = maxPendingConnectionCount;
    }

    unsigned int TcpIocpServerTransport::getMaxPendingConnectionCount()
    {
        return maxPendingConnectionCount;
    }

    // Brings the transport up:
    //   1. creates fdPartitionCount mutex-protected session maps,
    //   2. creates the listener socket, binds it to the configured network
    //      interface and port, and starts listening,
    //   3. creates the I/O completion port and associates the listener with
    //      it (the listener fd is used as its completion key),
    //   4. loads the AcceptEx() and GetAcceptExSockAddrs() extension function
    //      pointers through WSAIoctl().
    //
    // Preconditions (asserted): the transport is not already open and the
    // port is positive.
    //
    // Throws ServerTransportException if socket(), bind() or listen() fails.
    void TcpIocpServerTransport::open()
    {
        RCF_ASSERT(iocp.get() == NULL);
        RCF_ASSERT(acceptorFd == -1);
        RCF_ASSERT(sessionMaps.empty());
        RCF_ASSERT(port > 0);
        RCF_ASSERT(mQueuedAccepts == 0);
        
        // setup synchronized session maps
        for (unsigned int i=0; i<fdPartitionCount; i++)
        {
            sessionMaps.push_back( std::make_pair(MutexPtr(new Mutex), SessionStateMap()) );
        }

        // create listener socket
        int ret = 0;
        int err = 0;
        acceptorFd = static_cast<int>(socket(PF_INET, SOCK_STREAM, IPPROTO_TCP));
        if (acceptorFd == -1) 
        {
            err = Platform::OS::BsdSockets::GetLastError(); 
            RCF_THROW(ServerTransportException, "socket() failed")(acceptorFd)(err)(Platform::OS::GetErrorString(err)); 
        }

        // bind listener socket
        // The interface is first tried as a dotted-quad literal; if that fails
        // it is resolved as a host name and the first returned address is used.
        std::string networkInterface = getNetworkInterface();
        unsigned long ul_addr = inet_addr( networkInterface.c_str() );
        if (ul_addr == INADDR_NONE)
        {
            hostent *hostDesc = gethostbyname(networkInterface.c_str());
            if (hostDesc)
            {
                char *szIp = ::inet_ntoa( * (in_addr*) hostDesc->h_addr_list[0]);
                ul_addr = ::inet_addr(szIp);
            }
            // if resolution failed, ul_addr is still INADDR_NONE and is passed
            // to bind() as-is
        }
        sockaddr_in serverAddr;
        memset(&serverAddr, 0, sizeof(serverAddr));
        serverAddr.sin_family = AF_INET;
        serverAddr.sin_addr.s_addr = ul_addr;
        serverAddr.sin_port = htons(port);
        ret = bind(acceptorFd, (struct sockaddr*) &serverAddr, sizeof(serverAddr));
        if (ret < 0) 
        {
            err = Platform::OS::BsdSockets::GetLastError(); 
            RCF_THROW(ServerTransportException, "bind() failed")(acceptorFd)(port)(networkInterface)(ret)(err)(Platform::OS::GetErrorString(err)); 
        }

        // listen on listener socket
        ret = listen(acceptorFd, maxPendingConnectionCount);
        if (ret < 0) 
        {
            err = Platform::OS::BsdSockets::GetLastError(); 
            // report the call that actually failed (this used to say "bind() failed")
            RCF_THROW(ServerTransportException, "listen() failed")(acceptorFd)(ret)(err)(Platform::OS::GetErrorString(err)); 
        }
        RCF_ASSERT( acceptorFd != -1 )(acceptorFd);

        // create io completion port and associate the listener socket
        iocp.reset( new Iocp );
        iocp->Create();
        iocp->AssociateDevice( (HANDLE) acceptorFd, (ULONG_PTR) acceptorFd);

        // load AcceptEx() function
        GUID GuidAcceptEx = WSAID_ACCEPTEX;
        DWORD dwBytes;
        RCF_VERIFY_SOCKETS(
            0 == WSAIoctl(
                acceptorFd, 
                SIO_GET_EXTENSION_FUNCTION_POINTER, 
                &GuidAcceptEx, 
                sizeof(GuidAcceptEx),
                &mlpfnAcceptEx, 
                sizeof(mlpfnAcceptEx), 
                &dwBytes, 
                NULL, 
                NULL), 
            "WSAIoctl()");

        // load GetAcceptExSockAddrs() function
        GUID GuidGetAcceptExSockAddrs = WSAID_GETACCEPTEXSOCKADDRS;
        RCF_VERIFY_SOCKETS(
            0 == WSAIoctl(
                acceptorFd, 
                SIO_GET_EXTENSION_FUNCTION_POINTER, 
                &GuidGetAcceptExSockAddrs, 
                sizeof(GuidGetAcceptExSockAddrs),
                &mlpfnGetAcceptExSockAddrs, 
                sizeof(mlpfnGetAcceptExSockAddrs), 
                &dwBytes, 
                NULL, 
                NULL), 
            "WsaIoctl()");

    }

    // Tears the transport down in this order: completion port, listener
    // socket, queued-accept counter, session maps.
    void TcpIocpServerTransport::close()
    {
        // release the completion port first
        iocp.reset();

        // then the listener, if there is one
        const bool listenerIsOpen = !(acceptorFd == -1);
        if (listenerIsOpen)
        {
            RCF_VERIFY_SOCKETS(0 == closesocket(acceptorFd), "closesocket() failure")(acceptorFd);
            acceptorFd = -1;
        }

        // no accepts are outstanding any more
        mQueuedAccepts = 0;

        // dropping the maps drops every session they hold
        sessionMaps.clear();
    }

    // Maps a descriptor to the index of the session map that holds it:
    // the descriptor modulo fdPartitionCount.
    TcpIocpServerTransport::Fd TcpIocpServerTransport::hash(Fd fd)
    {
        Fd partitionIndex = fd % fdPartitionCount;
        return partitionIndex;
    }

    // synchronized - no shared resources
    // Builds the object graph for one connection on descriptor fd:
    // a SessionState, a proactor pointing at it, and a session obtained from
    // the session manager. The pieces are linked to each other and the
    // SessionState is given a weak pointer to itself.
    TcpIocpServerTransport::SessionStatePtr TcpIocpServerTransport::createSession(int fd)
    {
        SessionStatePtr newSessionState( new SessionState(*this, fd) );
        ProactorPtr newProactor( new TcpIocpProactor(*this, newSessionState) );

        SessionPtr newSession = getSessionManager().createSession();
        newSession->setProactorPtr(newProactor);

        newSessionState->setSessionPtr(newSession);
        newSessionState->setWeakThisPtr( SessionStateWeakPtr(newSessionState) );

        return newSessionState;
    }

    // synchronized
    // Registers sessionStatePtr in the session map for its descriptor.
    //
    // Returns false if the map already holds a zombie session under that
    // descriptor; true if this session is already registered or has now
    // been stored.
    bool TcpIocpServerTransport::monitorSession(SessionStatePtr sessionStatePtr)
    {
        ReadLock readLock(rwm); // since we're reading sessionStateMap[fd]->zombie

        int fd = sessionStatePtr->getFd();
        Mutex &sessionMapMutex = *sessionMaps[ hash(fd) ].first;
        Lock lock(sessionMapMutex); RCF_UNUSED_VARIABLE(lock);
        SessionStateMap &sessionStateMap = sessionMaps[ hash(fd) ].second;

        // operator[] creates an empty entry if the descriptor is new
        SessionStatePtr &slot = sessionStateMap[fd];
        if (slot.get())
        {
            if (slot->zombie)
            {
                return false;
            }
            if (slot == sessionStatePtr)
            {
                return true; // we're already monitoring this session
            }
        }

        // at this point the slot is expected to be empty
        RCF_ASSERT(sessionStateMap[fd].get() == NULL);
        slot = sessionStatePtr;
        return true;
    }

    // synchronized
    // Clears whatever session is stored under sessionStatePtr's descriptor
    // in the corresponding session map. Always returns true.
    bool TcpIocpServerTransport::unmonitorSession(SessionStatePtr sessionStatePtr)
    {
        int fd = sessionStatePtr->getFd();
        Mutex &sessionMapMutex = *sessionMaps[ hash(fd) ].first;
        Lock lock(sessionMapMutex); RCF_UNUSED_VARIABLE(lock);
        SessionStateMap &sessionStateMap = sessionMaps[ hash(fd) ].second;

        // operator[] creates an empty entry if the descriptor is unknown
        SessionStatePtr &slot = sessionStateMap[fd];
        if (slot.get())
        {
            slot.reset();
        }
        return true;
    }

    // Removes the session stored under sessionStatePtr's descriptor from its
    // session map. If the stored session was still in the Accepting state,
    // the queued-accept count is decremented and the accept generator is
    // signalled when the count has dropped below the threshold. If the
    // session passed in is one half of a reflection pair, the other half is
    // closed through externalCloseSession().
    //
    // If nothing is stored under the descriptor, only a trace is emitted.
    void TcpIocpServerTransport::closeSession(SessionStatePtr sessionStatePtr)
    {
        RCF_TRACE("")(sessionStatePtr->getSessionPtr().get());

        int fd = sessionStatePtr->getFd();
        Mutex &sessionMapMutex = *sessionMaps[ hash(fd) ].first;
        Lock lock(sessionMapMutex); RCF_UNUSED_VARIABLE(lock);
        SessionStateMap &sessionStateMap = sessionMaps[ hash(fd) ].second;

        SessionStatePtr &slot = sessionStateMap[fd];
        if (!slot.get())
        {
            RCF_TRACE("server transport cannot close session - foreign fd");
            return;
        }

        if (slot->getState() == SessionState::Accepting)
        {
            // an outstanding accept is being abandoned
            InterlockedDecrement( (LONG *) &mQueuedAccepts);
            if (mQueuedAccepts < mQueuedAcceptsThreshold)
            {
                mQueuedAcceptsCondition.notify_one();
            }
        }
        slot = SessionStatePtr();

        //sessionStatePtr->close();
        // NB: socket isn't closed until the SessionState destructor is executed, which should be when the arg to this function goes out of scope

        if (!sessionStatePtr->isReflecting())
        {
            return;
        }

        SessionStatePtr reflectionSessionStatePtr = sessionStatePtr->getReflectionSessionStateWeakPtr().lock();
        if (reflectionSessionStatePtr.get())
        {
            externalCloseSession(reflectionSessionStatePtr);
        }
    }

    // Advances the per-session state machine by one step, based on the
    // session's current state and its remaining read/write byte counts.
    //
    //   Accepting        - an AcceptEx() completed: record local/remote
    //                      addresses, update the queued-accept count, check
    //                      the client address, associate the socket with the
    //                      iocp and then fake a completed write.
    //   ReadingDataCount - reading the 4-byte length field; once complete,
    //                      size the read buffer and move to ReadingData.
    //   ReadingData      - reading the message body; once complete, mark the
    //                      session Ready and notify the session manager.
    //   WritingData      - either start the write, or (nothing remaining)
    //                      mark the session Ready and notify the session
    //                      manager.
    //
    // The function calls itself recursively when a state change can be acted
    // on immediately. Any failed read()/write() closes the session.
    void TcpIocpServerTransport::transition(SessionStatePtr sessionStatePtr)
    {
        // Snapshot of the session's buffers and counters at entry.
        std::vector<char> &readBuffer = sessionStatePtr->getReadBuffer();
        std::size_t readBufferRemaining = sessionStatePtr->getReadBufferRemaining();
        std::vector<char> &writeBuffer = sessionStatePtr->getWriteBuffer();
        std::size_t writeBufferRemaining = sessionStatePtr->getWriteBufferRemaining();
        // NOTE(review): pOverlapped is not used anywhere below.
        OVERLAPPED *pOverlapped = static_cast<OVERLAPPED *>(sessionStatePtr.get());

        switch(sessionStatePtr->getState())
        {
        case SessionState::Accepting:

            // parse the local and remote address info
            {
                SOCKADDR *pLocalAddr = NULL;
                SOCKADDR *pRemoteAddr = NULL;

                int localAddrLen = 0;
                int remoteAddrLen = 0;

                // The buffer layout and lengths here mirror the ones passed to
                // AcceptEx() in generateAccepts().
                mlpfnGetAcceptExSockAddrs( 
                    &readBuffer[0], 
                    0, 
                    sizeof(sockaddr_in) + 16, 
                    sizeof(sockaddr_in) + 16, 
                    &pLocalAddr, 
                    &localAddrLen, 
                    &pRemoteAddr, 
                    &remoteAddrLen);

                sockaddr_in *pLocalSockAddr = reinterpret_cast<sockaddr_in *>(pLocalAddr);
                sessionStatePtr->setLocalAddress( IpAddress(*pLocalSockAddr) );
                
                sockaddr_in *pRemoteSockAddr = reinterpret_cast<sockaddr_in *>(pRemoteAddr);
                sessionStatePtr->setRemoteAddress( IpAddress(*pRemoteSockAddr) );
            }

            // one fewer accept is outstanding now
            InterlockedDecrement( (LONG *) &mQueuedAccepts);

            // wake the accept generator if the pool is running low
            if (mQueuedAccepts < mQueuedAcceptsThreshold)
            {
                mQueuedAcceptsCondition.notify_one();
            }

            // is this ip allowed?
            if (isClientAddrAllowed(sessionStatePtr->getRemoteSockAddr()))
            {
                // associate fd with iocp
                {
                    ReadLock lock(rwm); // since we're reading sessionStatePtr->zombie
                    if (sessionStatePtr->zombie)
                    {
                        // NOTE: closeSession() runs while the read lock is still held
                        closeSession(sessionStatePtr);
                        return;
                    }
                    else
                    {
                        // the fd doubles as the completion key
                        int fd = sessionStatePtr->getFd();
                        RCF_VERIFY(1 == iocp->AssociateSocket(fd, fd), "AssociateSocket() failed")(fd);
                    }
                }

                // fake a write completion to get things moving
                sessionStatePtr->setState(SessionState::WritingData);
                sessionStatePtr->setWriteBufferRemaining(0);
                transition(sessionStatePtr);
            }
            else
            {
                closeSession(sessionStatePtr);
            }

            break;

        case SessionState::ReadingDataCount:

            // readBufferRemaining is a std::size_t, so the lower bound always holds
            RCF_ASSERT(0 <= readBufferRemaining && readBufferRemaining <= 4);
            if (readBufferRemaining == 0)
            {
                // all 4 bytes of the length field have arrived: decode it
                unsigned int packetLength = * (unsigned int *) (&readBuffer[0]);
                networkToMachineOrder(&packetLength, 4, 1);
                if (packetLength <= getMaxMessageLength())
                {
                    sessionStatePtr->getReadBuffer().resize(packetLength); // TODO: configurable limit on packetLength
                    sessionStatePtr->setReadBufferRemaining(packetLength);
                    sessionStatePtr->setState( SessionState::ReadingData );
                    transition(sessionStatePtr);
                }
                else
                {
                    // declared length exceeds the configured maximum
                    closeSession(sessionStatePtr);
                }
            }
            else if (0 < readBufferRemaining && readBufferRemaining < 4)
            {
                // partial length field: continue reading where we left off
                char *readPos = & readBuffer[readBuffer.size() - readBufferRemaining];
                int ret = sessionStatePtr->read(readPos, readBufferRemaining);
                if (ret == -1)
                {
                    closeSession(sessionStatePtr);
                }

            }
            else if (readBufferRemaining == 4)
            {
                // nothing read yet: start at the beginning of the buffer
                char *readPos = & readBuffer[0];
                int ret = sessionStatePtr->read(readPos, readBufferRemaining);
                if (ret == -1)
                {
                    closeSession(sessionStatePtr);
                }
            }

            break;

        case SessionState::ReadingData:

            if (readBufferRemaining == 0)
            {
                // whole message received: hand it to the session manager
                sessionStatePtr->setState( SessionState::Ready );
                getSessionManager().onReadCompleted( sessionStatePtr->getSessionPtr() );
            }
            else 
            {
                // keep filling the tail of the read buffer
                RCF_ASSERT( readBufferRemaining <= readBuffer.size() );
                char *readPos = & readBuffer[readBuffer.size() - readBufferRemaining];
                int ret = sessionStatePtr->read(readPos, readBufferRemaining);
                if (ret == -1)
                {
                    closeSession(sessionStatePtr);
                }
            }

            break;

        case SessionState::WritingData:

            // writes are all-or-nothing here: either done or not yet started
            RCF_ASSERT(writeBufferRemaining == 0 || writeBufferRemaining == writeBuffer.size());
            if (writeBufferRemaining == 0)
            {
                // write finished (or was faked by the Accepting case)
                sessionStatePtr->setState( SessionState::Ready );
                getSessionManager().onWriteCompleted( sessionStatePtr->getSessionPtr() );
                return;
            }
            if (writeBufferRemaining == writeBuffer.size())
            {
                // start writing the whole buffer
                sessionStatePtr->clearOverlapped();
                int ret = sessionStatePtr->write(&writeBuffer[0], writeBufferRemaining);
                if (ret == -1)
                {
                    closeSession(sessionStatePtr);
                }
            }

            break;

        default:

            // no other state is expected to reach this function
            RCF_ASSERT(0)(sessionStatePtr->getState())(sessionStatePtr->getFd());
        }

    }

    // Handles an I/O completion on a session that is part of a reflection
    // pair.
    //
    //   after a write completes - switch to ReadingData and post a 1024-byte
    //                             overlapped receive on the session's own fd;
    //   after a read completes  - switch to WritingData and send the
    //                             bytesRead bytes just received to the
    //                             reflection fd.
    //
    // Both operations use the session's own OVERLAPPED. If the Winsock call
    // fails with anything other than WSA_IO_PENDING, the session is closed.
    // completionKey is unused.
    void TcpIocpServerTransport::reflectSession(SessionStatePtr sessionStatePtr, DWORD bytesRead, ULONG_PTR completionKey)
    {
        RCF_UNUSED_VARIABLE(completionKey);
        int fd = sessionStatePtr->getFd();
        int reflectionFd = sessionStatePtr->getReflectionFd();
        SessionState::State state = sessionStatePtr->getState();
        RCF_ASSERT(state == SessionState::ReadingData || state == SessionState::ReadingDataCount || state == SessionState::WritingData);
        RCF_ASSERT(fd > 0 && reflectionFd > 0);

        const bool writeJustCompleted = (state == SessionState::WritingData);
        const bool readJustCompleted =
            (state == SessionState::ReadingData || state == SessionState::ReadingDataCount);
        if (!writeJustCompleted && !readJustCompleted)
        {
            return;
        }

        OVERLAPPED *pOverlapped = sessionStatePtr.get();
        WSAOVERLAPPED *pWsaOverlapped = reinterpret_cast<WSAOVERLAPPED *>(pOverlapped);
        DWORD dwTransferred = 0;
        DWORD dwFlags = 0;
        int ret = 0;
        int err = 0;

        if (writeJustCompleted)
        {
            sessionStatePtr->setState( SessionState::ReadingData );
            std::vector<char> &readBuffer = sessionStatePtr->getReadBuffer();
            readBuffer.resize(1024); // decent STL implementation will only allocate memory once
            WSABUF wsabuf = { static_cast<u_long>(readBuffer.size()), &readBuffer[0] };
            ret = WSARecv(fd, &wsabuf, 1, &dwTransferred, &dwFlags, pWsaOverlapped, NULL);
            err = WSAGetLastError();
        }
        else
        {
            sessionStatePtr->setState( SessionState::WritingData );
            WSABUF wsabuf = { bytesRead, (char *) &sessionStatePtr->getReadBuffer()[0] };
            ret = WSASend(reflectionFd, &wsabuf, 1, &dwTransferred, dwFlags, pWsaOverlapped, NULL);
            err = WSAGetLastError();
        }

        if (ret == -1 && err != WSA_IO_PENDING)
        {
            RCF2_TRACE("")(ret)(err);
            closeSession(sessionStatePtr);
        }
    }

    // One iteration of the accept-generating task.
    //
    // timeoutMs == 0: generate accepts immediately, without blocking.
    // otherwise     : unless a stop has already been requested, block on the
    //                 queued-accepts condition until signalled, then generate
    //                 accepts (timeoutMs is not used as a duration).
    //
    // Returns true if either stop flag is set, i.e. the caller should stop
    // cycling.
    bool TcpIocpServerTransport::cycleAccepts(int timeoutMs, const volatile bool &stopFlag)
    {
        if (timeoutMs != 0)
        {
            Lock lock(mQueuedAcceptsMutex);
            if (!stopFlag && !mStopFlag)
            {
                mQueuedAcceptsCondition.wait(lock);

                // re-check after waking up: we may have been woken to stop
                if (stopFlag || mStopFlag)
                {
                    return true;
                }
                generateAccepts();
            }
        }
        else
        {
            generateAccepts();
        }

        return stopFlag || mStopFlag;
    }

    // Requests that accept generation stop: raises the internal stop flag
    // and then, holding the queued-accepts mutex, signals the condition that
    // cycleAccepts() waits on.
    void TcpIocpServerTransport::stopAccepts()
    {
        mStopFlag = true;

        {
            Lock lock(mQueuedAcceptsMutex);
            mQueuedAcceptsCondition.notify_one();
        }
    }

    // Tops up the pool of outstanding AcceptEx() calls.
    //
    // Does nothing unless the queued-accept count is below the threshold.
    // Otherwise, mQueuedAcceptsAugment times: creates a socket and a session
    // for it, registers the session, zero-fills an address buffer of
    // 2*(sizeof(sockaddr_in)+16) bytes, and posts an overlapped AcceptEx()
    // on the listener using the session's OVERLAPPED.
    //
    // Throws ServerTransportException if AcceptEx() neither goes pending nor
    // reports received bytes.
    void TcpIocpServerTransport::generateAccepts()
    {
        if (!(mQueuedAccepts < mQueuedAcceptsThreshold))
        {
            return;
        }

        // one address slot (local or remote) and the total buffer size
        const std::size_t addressSlotLen = sizeof(sockaddr_in) + 16;
        const std::size_t acceptBufferLen = 2*addressSlotLen;

        for (unsigned int acceptIndex=0; acceptIndex<mQueuedAcceptsAugment; acceptIndex++)
        {
            Fd fd = static_cast<Fd>(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
            RCF_VERIFY_SOCKETS(fd != -1, "socket() failed");
            SessionStatePtr sessionStatePtr = createSession(fd);

            if (!monitorSession(sessionStatePtr))
            {
                RCF2_TRACE("");
                continue;
            }

            // fresh, zeroed address buffer for this accept
            std::vector<char> &acceptBuffer = sessionStatePtr->getReadBuffer();
            acceptBuffer.assign(acceptBufferLen, '\0');
            sessionStatePtr->setReadBufferRemaining( acceptBufferLen );

            sessionStatePtr->clearOverlapped();

            DWORD dwBytes = 0;
            BOOL ret = mlpfnAcceptEx(
                acceptorFd,
                fd,
                &acceptBuffer[0],
                0,
                addressSlotLen,
                addressSlotLen,
                &dwBytes,
                static_cast<OVERLAPPED *>(sessionStatePtr.get()));
            int err = WSAGetLastError();

            const bool acceptPending = (ret == FALSE && err == ERROR_IO_PENDING);
            if (!acceptPending)
            {
                if (dwBytes > 0)
                {
                    RCF_ASSERT(0);
                    transition(sessionStatePtr);
                }
                else
                {
                    RCF_THROW(ServerTransportException, "AcceptEx failed")(err);
                }
            }
            // otherwise: async accept initiated successfully

            BOOST_STATIC_ASSERT( sizeof(LONG) == sizeof(mQueuedAccepts) );
            InterlockedIncrement( (LONG *) &mQueuedAccepts);
        }
    }
  
    // One iteration of the I/O task: dequeues at most one completion packet
    // from the iocp (waiting up to timeoutMs, or indefinitely if timeoutMs is
    // negative) and dispatches it.
    //
    // Dispatch, in order:
    //   - no packet            : timeout is ignored, anything else asserts
    //   - session not found    : ignored
    //   - operation failed     : session closed
    //   - session is a zombie  : session closed
    //   - key == listener fd   : an accept completed
    //   - zero bytes           : connection closed, session closed
    //   - reflecting session   : reflectSession()
    //   - otherwise            : regular read/write completion
    //
    // stopFlag is unused.
    void TcpIocpServerTransport::cycle(int timeoutMs, const volatile bool &stopFlag)
    {

        RCF_UNUSED_VARIABLE(stopFlag);

        // nudge the accept generator if the accept pool is running low
        if (mQueuedAccepts < mQueuedAcceptsThreshold)
        {
            mQueuedAcceptsCondition.notify_one();
        }

        // extract a completed io operation from the iocp
        DWORD dwMilliseconds = timeoutMs < 0 ? INFINITE : timeoutMs;
        DWORD dwNumBytes = 0;
        ULONG_PTR completionKey = 0;
        OVERLAPPED *pOverlapped = 0;
        BOOL ret = iocp->GetStatus(&completionKey, &dwNumBytes, &pOverlapped, dwMilliseconds);
        // must be read immediately after the call above
        DWORD dwErr = GetLastError();
        if (pOverlapped == NULL)
        {
            if (dwErr == WAIT_TIMEOUT) 
            {
                // timed out
            } 
            else 
            {
                // Bad call to GetQueuedCompletionStatus, dwErr contains the reason for the bad call
                RCF_ASSERT(0)(dwErr);
            }
        }
        else
        {
            // The OVERLAPPED handed to every async call is the SessionState
            // itself, so the pointer can be cast back.
            SessionState *pSessionState = static_cast<SessionState *>(pOverlapped);
            int fd = pSessionState->getFd(); // if we're completing an AcceptEx call, this won't be the same as completion key
            SessionStatePtr sessionStatePtr = getSessionStatePtr(fd);
            if (sessionStatePtr.get() == NULL)
            {
                // somebody removed the session from the server transport
            }
            // NOTE(review): this condition is equivalent to plain (ret == 0)
            else if ( (ret == 0 && dwErr == WSA_OPERATION_ABORTED) || ret == 0)
            {
                closeSession(sessionStatePtr);
            }
            else 
            {
                ReadLock lock(rwm);
                if (sessionStatePtr->zombie)
                {
                    // NOTE: closeSession() runs while the read lock is still held
                    closeSession(sessionStatePtr);
                }
                else
                {
                    // release the lock before dispatching
                    lock.unlock();

                    RCF_ASSERT(sessionStatePtr.get());
                    if (completionKey == acceptorFd)
                    {
                        // accept completed
                        onAcceptCompleted(sessionStatePtr);
                    }
                    else if (dwNumBytes == 0)
                    {
                        // socket was closed, probably by the remote host
                        closeSession(sessionStatePtr);
                    }
                    else if (sessionStatePtr->isReflecting())
                    {
                        // read or write completed on a reflection pair
                        reflectSession(sessionStatePtr, dwNumBytes, completionKey);
                    }
                    else
                    {
                        // read or write completed
                        int bytesRead = dwNumBytes;
                        sessionStatePtr->onReadWriteCompleted(bytesRead, 0);
                    }
                }
            }
        }
    }

    // Invoked when an overlapped read or write on a session has completed.
    //   sessionStatePtr  - the session the completed operation belongs to
    //   bytesTransferred - number of bytes reported for the completed operation
    //   error            - not consulted by this function
    // A session that is neither in one of the two reading states nor in WritingData is left untouched.
    void TcpIocpServerTransport::onReadWriteCompleted(SessionStatePtr sessionStatePtr, std::size_t bytesTransferred, int error)
    {
        const bool readCompleted =
            sessionStatePtr->getState() == SessionState::ReadingData ||
            sessionStatePtr->getState() == SessionState::ReadingDataCount;

        if (readCompleted)
        {
            // deduct what has just arrived from the outstanding read count, then move on
            sessionStatePtr->setReadBufferRemaining( sessionStatePtr->getReadBufferRemaining() - bytesTransferred);
            transition(sessionStatePtr);
            return;
        }

        if (sessionStatePtr->getState() == SessionState::WritingData)
        {
            // TODO: handle partially completed writes
            // (for now a write is required to have gone out in a single completion)
            RCF_ASSERT(bytesTransferred == sessionStatePtr->getWriteBufferRemaining());
            sessionStatePtr->setWriteBufferRemaining(0);
            transition(sessionStatePtr);
        }
    }
  
    // Invoked when an accept operation has completed for the given session.
    // Does nothing beyond forwarding the session to transition().
    void TcpIocpServerTransport::onAcceptCompleted(SessionStatePtr sessionStatePtr)
    {
        transition(sessionStatePtr);
    }

    // Starts sending the session's write buffer.
    // Expects the session to be Ready and the buffer to hold more than 4 bytes. The first
    // 4 bytes are overwritten with the size of the rest of the buffer (total size minus 4),
    // passed through machineToNetworkOrder(), before the session is handed to transition().
    void TcpIocpServerTransport::postWrite(SessionStatePtr sessionStatePtr)
    {
        // preconditions
        RCF_ASSERT(sessionStatePtr->getState() == SessionState::Ready);
        RCF_ASSERT(sizeof(unsigned int) == 4);
        RCF_ASSERT(sessionStatePtr->getWriteBuffer().size() > 4);

        sessionStatePtr->setState( SessionState::WritingData );

        // the whole buffer, including the 4 leading bytes, is still to be written
        const unsigned int totalBytes = static_cast<unsigned int>(sessionStatePtr->getWriteBuffer().size());
        sessionStatePtr->setWriteBufferRemaining(totalBytes);
        RCF_ASSERT(sessionStatePtr->getWriteBuffer().size() >= 4);

        // stamp the payload size into the leading 4 bytes and convert it for the wire
        unsigned int *pLeadingCount = (unsigned int*) &sessionStatePtr->getWriteBuffer()[0];
        *pLeadingCount = totalBytes - 4;
        RCF::machineToNetworkOrder(&sessionStatePtr->getWriteBuffer()[0], 4, 1);

        transition(sessionStatePtr);
    }

    // Starts reading on a session that is Ready: switches it to ReadingDataCount,
    // sizes the read buffer to 4 bytes with all 4 still outstanding, and hands the
    // session to transition().
    void TcpIocpServerTransport::postRead(SessionStatePtr sessionStatePtr)
    {
        const int countBytes = 4;

        RCF_ASSERT(sessionStatePtr->getState() == SessionState::Ready);

        sessionStatePtr->setState(SessionState::ReadingDataCount);
        sessionStatePtr->getReadBuffer().resize(countBytes);
        sessionStatePtr->setReadBufferRemaining(countBytes);

        transition(sessionStatePtr);
    }

    // synchronized
    // Looks up the session registered for fd. The lookup is done in the session map selected
    // by hash(fd), while holding that map's own lock; the result is copied before the lock is released.
    // NOTE(review): the lookup uses operator[]; if the per-bucket container is a std::map this
    // inserts an empty entry for an unknown fd instead of just reporting "not found" - confirm.
    TcpIocpServerTransport::SessionStatePtr TcpIocpServerTransport::getSessionStatePtr(Fd fd) 
    {
        Lock lock( *sessionMaps[hash(fd)].first );
        SessionStatePtr foundSession = sessionMaps[hash(fd)].second[fd];
        return foundSession;
    }

    // Convenience overload: closes the session using the session's own fd.
    // The two-argument overload takes the fd by reference, so a scratch copy is passed in.
    void TcpIocpServerTransport::externalCloseSession(SessionStatePtr sessionStatePtr)
    {
        int sessionFd = sessionStatePtr->getFd();
        externalCloseSession(sessionStatePtr, sessionFd);
    }

    // Closes a session's socket on behalf of an external holder of the fd and marks the
    // session as a zombie.
    //   sessionStatePtr - session whose socket is to be closed; expected to be non-null
    //   fd              - the caller's copy of the descriptor; set to -1 once the close has been handled
    //
    // A session that is already a zombie is not closed a second time.
    // Actual removal of the SessionState doesn't occur until the session receives a close event from the iocp.
    // Generally that would be the last reference to the SessionState, and so the SessionState would then also be deallocated.
    void TcpIocpServerTransport::externalCloseSession(SessionStatePtr sessionStatePtr, int &fd)
    {
        WriteLock lock(rwm); // since we're writing sessionStatePtr->zombie
        if (sessionStatePtr.get())
        {
            if (!sessionStatePtr->zombie)
            {
                RCF_VERIFY_SOCKETS(0 == closesocket(sessionStatePtr->getFd()), "closesocket() failure")(fd)(sessionStatePtr->getFd());
                sessionStatePtr->zombie = true;
            }
        }
        else
        {
            // Not expected to happen. There is no session to consult here, so both the close
            // and the diagnostics must use the caller's fd; dereferencing sessionStatePtr in
            // this branch would be a null pointer access.
            RCF_ASSERT(0);
            RCF_VERIFY_SOCKETS(0 == closesocket(fd), "closesocket() failure")(fd);
        }

        // either way the caller must not use its descriptor any more
        fd = -1;
    }

    // alias for boost.bind
    // Forwards unchanged to the two-argument externalCloseSession(). Having a distinct name
    // lets its address be taken without having to pick between the externalCloseSession() overloads.
    void TcpIocpServerTransport::externalCloseSession0(SessionStatePtr sessionStatePtr, int &fd)
    {
        externalCloseSession(sessionStatePtr, fd);
    }

    // Creates a server-aware client transport on the connection associated with this session.
    // From then on the fd is owned by the client transport, not by the server session; the transport
    // is given a close functor that routes back into externalCloseSession0() for this session.
    // Only the first call for a session yields a transport; later calls return an empty auto_ptr.
    ClientTransportAutoPtr TcpIocpServerTransport::createClientTransport(boost::shared_ptr<I_Session> sessionPtr)
    {
        ProactorPtr sessionProactor = sessionPtr->getProactorPtr();
        TcpIocpProactor &iocpProactor = dynamic_cast<TcpIocpProactor &>(*sessionProactor);
        SessionStatePtr state = iocpProactor.getSessionStatePtr();
        std::auto_ptr<TcpClientTransport> transportPtr;

        WriteLock lock(rwm); // since we're reading and writing state->ownFd
        if (!state->ownFd)
        {
            // the fd has already been handed out once
            return ClientTransportAutoPtr(transportPtr.release());
        }

        int sessionFd = state->getFd();
        state->ownFd = false;

        transportPtr.reset( new TcpClientTransport(sessionFd) );
        transportPtr->setRemoteAddr( state->getRemoteSockAddr() );
        transportPtr->setCloseFunctor( boost::bind(&TcpIocpServerTransport::externalCloseSession0, this, state, _1) );

        return ClientTransportAutoPtr(transportPtr.release());
    }

    // Creates a server session on the connection associated with the client transport.
    // The fd is released from the client transport and wrapped in a new session, which is set up
    // as if a write had just finished (WritingData with nothing remaining) before being handed to
    // transition(). Returns an empty pointer if monitorSession() rejects the new session.
    boost::shared_ptr<I_Session> TcpIocpServerTransport::createServerSession(ClientTransportAutoPtr clientTransportAutoPtr)
    {
        TcpClientTransport &clientTransport = dynamic_cast<TcpClientTransport &>(*clientTransportAutoPtr);
        int fd = clientTransport.releaseFd();
        RCF_ASSERT(fd > 0);

        SessionStatePtr sessionStatePtr = createSession(fd);
        sessionStatePtr->setRemoteAddress( IpAddress(clientTransport.getRemoteAddr()) );
        sessionStatePtr->setState(SessionState::WritingData);
        sessionStatePtr->setWriteBufferRemaining(0);

        if (!monitorSession(sessionStatePtr))
        {
            return boost::shared_ptr<I_Session>();
        }

        {
            // hook the socket up to the completion port, unless the session has been closed in the meantime
            ReadLock lock(rwm); // since we're reading sessionStatePtr->zombie
            if (!sessionStatePtr->zombie)
            {
                int fd = sessionStatePtr->getFd();
                RCF_VERIFY_SOCKETS(1 == iocp->AssociateSocket(fd, fd), "Iocp::AssociateSocket() failed()");
            }
        }

        transition(sessionStatePtr);
        return sessionStatePtr->getSessionPtr();
    }

    // Starts reflecting data between the two given sessions.
    // Resolves each session to its SessionState through its proactor and defers to the
    // SessionStatePtr overload, whose result is returned.
    bool TcpIocpServerTransport::reflect(boost::shared_ptr<I_Session> sessionPtr1, boost::shared_ptr<I_Session> sessionPtr2)
    {
        ProactorPtr firstProactor = sessionPtr1->getProactorPtr();
        ProactorPtr secondProactor = sessionPtr2->getProactorPtr();

        TcpIocpProactor &firstIocpProactor = dynamic_cast<TcpIocpProactor &>(*firstProactor);
        TcpIocpProactor &secondIocpProactor = dynamic_cast<TcpIocpProactor &>(*secondProactor);

        SessionStatePtr firstState = firstIocpProactor.getSessionStatePtr();
        SessionStatePtr secondState = secondIocpProactor.getSessionStatePtr();

        return reflect(firstState, secondState);
    }

    // Cross-links two sessions as a reflection pair: each one records a weak pointer to,
    // and the fd of, the other. Returns false, leaving both sessions unlinked, if
    // monitorSession() fails for either of them (the second is not tried if the first fails).
    bool TcpIocpServerTransport::reflect(SessionStatePtr sessionStatePtr1, SessionStatePtr sessionStatePtr2)
    {
        RCF_ASSERT(sessionStatePtr1.get() && sessionStatePtr2.get())(sessionStatePtr1.get())(sessionStatePtr2.get());

        if (!monitorSession(sessionStatePtr1) || !monitorSession(sessionStatePtr2))
        {
            return false;
        }

        WriteLock lock(rwm); // what sync do we need here?

        // first -> second
        sessionStatePtr1->setReflectionSessionStateWeakPtr( SessionStateWeakPtr(sessionStatePtr2) );
        sessionStatePtr1->setReflectionFd( sessionStatePtr2->getFd() );

        // second -> first
        sessionStatePtr2->setReflectionSessionStateWeakPtr( SessionStateWeakPtr(sessionStatePtr1) );
        sessionStatePtr2->setReflectionFd( sessionStatePtr1->getFd() );

        // all pending events for sessionStatePtr1 and sessionStatePtr2 will now be reflected
        return true;
    }

    // Checks if a server session is still connected: its session state must have a
    // positive fd and must not have been marked as a zombie.
    bool TcpIocpServerTransport::isConnected(boost::shared_ptr<I_Session> sessionPtr)
    {
        ProactorPtr sessionProactor = sessionPtr->getProactorPtr();
        TcpIocpProactor &iocpProactor = dynamic_cast<TcpIocpProactor &>(*sessionProactor);
        SessionStatePtr state = iocpProactor.getSessionStatePtr();

        ReadLock lock(rwm); // since we're reading state->zombie
        const bool hasFd = state->getFd() > 0;
        return hasFd && !state->zombie;
    }

    // Creates a server-aware client transport to the given endpoint.
    // The endpoint must be a TcpEndpoint (the dynamic_cast to a reference throws otherwise);
    // its ip and port are used to construct a TcpClientTransport.
    ClientTransportAutoPtr TcpIocpServerTransport::createClientTransport(const I_Endpoint &endpoint)
    {
        const TcpEndpoint &tcpEndpoint = dynamic_cast<const TcpEndpoint &>(endpoint);
        std::auto_ptr<I_ClientTransport> transportPtr(
            new TcpClientTransport(tcpEndpoint.getIp(), tcpEndpoint.getPort()) );
        return transportPtr;
    }

    // Records the session manager to use. Only the address of sessionManager is stored,
    // so the referenced object has to outlive its use through getSessionManager().
    void TcpIocpServerTransport::setSessionManager(I_SessionManager &sessionManager)
    {
        pSessionManager = &sessionManager;
    }

    // Returns the session manager recorded by setSessionManager().
    // The stored pointer is dereferenced without a check.
    // NOTE(review): presumably must not be called before setSessionManager() - confirm how pSessionManager is initialized.
    I_SessionManager &TcpIocpServerTransport::getSessionManager()
    {
        return *pSessionManager;
    }

    // Runs one combined cycle: first the transport, then the server's sessions, each
    // being given half of timeoutMs. Nothing is cycled if either stop flag is already set.
    // Returns true if a stop has been requested, through stopFlag or the transport's own mStopFlag.
    bool TcpIocpServerTransport::cycleTransportAndServer(RcfServer &server, int timeoutMs, const volatile bool &stopFlag)
    {
        if (stopFlag || mStopFlag)
        {
            return stopFlag || mStopFlag;
        }

        const int halfTimeoutMs = timeoutMs/2;
        cycle(halfTimeoutMs, stopFlag);
        server.cycleSessions(halfTimeoutMs, stopFlag);

        // the flags may have been raised while cycling, so look at them afresh
        return stopFlag || mStopFlag;
    }

    // Called when this transport is added to a server as a service.
    // Makes the server the session manager, then replaces the task entries with two tasks:
    // one that cycles the transport together with the server, and one that cycles accepts
    // (stopped through stopAccepts()). Finally clears the transport's stop flag.
    void TcpIocpServerTransport::onServiceAdded(RcfServer &server)
    {
        setSessionManager(server);

        WriteLock writeLock( getTaskEntriesMutex() );
        getTaskEntries().clear();

        // transport + server cycling; no dedicated stop functor
        TaskEntry cycleTask(
            boost::bind(&TcpIocpServerTransport::cycleTransportAndServer, this, boost::ref(server), _1, _2),
            StopFunctor());
        getTaskEntries().push_back(cycleTask);

        // accept cycling
        TaskEntry acceptTask(
            boost::bind(&TcpIocpServerTransport::cycleAccepts, this, _1, _2),
            boost::bind(&TcpIocpServerTransport::stopAccepts, this));
        getTaskEntries().push_back(acceptTask);

        mStopFlag = false;
    }

    // Called when this transport is removed from a server. Intentionally does nothing;
    // the server argument is unused.
    void TcpIocpServerTransport::onServiceRemoved(RcfServer &)
    {}
    
#ifdef _MSC_VER
#pragma warning( push )
#pragma warning( disable : 4355 )  // warning C4355: 'this' : used in base member initializer list
#endif

    // Constructs the state for one session of the given transport, on socket fd.
    // A new session starts in the Accepting state and owning its fd (ownFd is true).
    // reflectionFd, the two buffer-remaining counters and zombie are value-initialized.
    // The receive and send functors are given this object, its fd and zombie members, and
    // the transport's rwm.
    // NOTE(review): members are initialized in class declaration order, not in the order
    // listed here; the functors presumably need fd and zombie to be declared before them - confirm.
    TcpIocpServerTransport::SessionState::SessionState(TcpIocpServerTransport &transport, Fd fd) :
        state(SessionState::Accepting),
        reflectionFd(),
        readBufferRemaining(),
        writeBufferRemaining(),
        fd(fd),
        ownFd(true),
        zombie(),
        // 'this' is passed along before the object is fully constructed (hence the C4355 suppression around this constructor)
        wsaRecvFunctor(this, this->fd, zombie, transport.rwm),
        wsaSendFunctor(this, this->fd, zombie, transport.rwm),
        transport(transport)
    {
        // blank the OVERLAPPED structure
        clearOverlapped();
    }

#ifdef _MSC_VER
#pragma warning( pop )
#endif

    // Server start notification: opens the transport, unless it is already open.
    void TcpIocpServerTransport::onServerStart(RcfServer &)
    {
        if (mOpen)
        {
            return;
        }

        open();
        mOpen = true;
    }

    // Server stop notification: closes the transport if it is open, and in that case
    // also clears the transport's stop flag.
    void TcpIocpServerTransport::onServerStop(RcfServer &)
    {
        if (!mOpen)
        {
            return;
        }

        close();
        mOpen = false;
        mStopFlag = false;
    }

    // Server open notification: opens the transport, unless it is already open.
    void TcpIocpServerTransport::onServerOpen(RcfServer &)
    {
        if (mOpen)
        {
            return;
        }

        open();
        mOpen = true;
    }

    // Server close notification: closes the transport if it is currently open.
    void TcpIocpServerTransport::onServerClose(RcfServer &)
    {
        if (!mOpen)
        {
            return;
        }

        close();
        mOpen = false;
    }

} // namespace RCF

By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.

If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


Written By
Australia
Software developer, from Sweden and now living in Canberra, Australia, working on distributed C++ applications. When he is not programming, Jarl enjoys skiing and playing table tennis. He derives immense satisfaction from referring to himself in third person.

Comments and Discussions