//*****************************************************************************
// RCF - Remote Call Framework
// Copyright (c) 2005. All rights reserved.
// Developed by Jarl Lindrud.
// Contact: jlindrud@hotmail.com .
//*****************************************************************************
#include <RCF/TcpIocpServerTransport.hpp>
#include <RCF/RcfServer.hpp>
#include <RCF/TcpClientTransport.hpp>
#include <RCF/TcpEndpoint.hpp>
#include <RCF/Tools.hpp>
#include <RCF/UsingBsdSockets.hpp>
namespace RCF {
// Constructs the completion-port wrapper. Passing -1 defers creation of the
// port; any other value creates it immediately via Create().
Iocp::Iocp(int nMaxConcurrency) :
    m_hIOCP(NULL)
{
    if (nMaxConcurrency == -1)
    {
        return;
    }
    Create(nMaxConcurrency);
}
// Releases the completion port handle, if one was ever created.
Iocp::~Iocp()
{
    if (m_hIOCP == NULL)
    {
        return;
    }
    RCF_VERIFY(CloseHandle(m_hIOCP), "CloseHandle")(m_hIOCP);
}
// Creates a fresh I/O completion port that is not yet associated with any
// device. nMaxConcurrency is passed straight to CreateIoCompletionPort().
// Returns TRUE if the port was created.
BOOL Iocp::Create(int nMaxConcurrency)
{
    HANDLE hPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, nMaxConcurrency);
    m_hIOCP = hPort;
    RCF_VERIFY(m_hIOCP != NULL, "CreateIoCompletionPort()");
    return m_hIOCP != NULL;
}
// Associates hDevice with this completion port; its completions will carry
// CompKey as the completion key. Returns TRUE if the association succeeded.
BOOL Iocp::AssociateDevice(HANDLE hDevice, ULONG_PTR CompKey)
{
    HANDLE hResult = CreateIoCompletionPort(hDevice, m_hIOCP, CompKey, 0);
    BOOL fOk = (hResult == m_hIOCP);
    RCF_VERIFY(fOk, "CreateIoCompletionPort()" );
    return fOk;
}
// Socket flavour of AssociateDevice(): the SOCKET value is reinterpreted as a
// HANDLE and associated with the port under CompKey.
BOOL Iocp::AssociateSocket(SOCKET hSocket, ULONG_PTR CompKey)
{
    HANDLE hDevice = reinterpret_cast<HANDLE>(hSocket);
    return AssociateDevice(hDevice, CompKey);
}
// Manually queues a completion packet (key, byte count, OVERLAPPED pointer)
// on the port. Returns the result of PostQueuedCompletionStatus().
BOOL Iocp::PostStatus(ULONG_PTR CompKey, DWORD dwNumBytes, OVERLAPPED* po)
{
    BOOL fOk = PostQueuedCompletionStatus(m_hIOCP, dwNumBytes, CompKey, po);
    RCF_ASSERT(fOk);
    return fOk;
}
// Dequeues one completion packet, waiting at most dwMilliseconds.
// Thin wrapper over GetQueuedCompletionStatus(); note that the byte-count and
// completion-key arguments are ordered differently from the Win32 call.
BOOL Iocp::GetStatus(ULONG_PTR* pCompKey, PDWORD pdwNumBytes, OVERLAPPED** ppo, DWORD dwMilliseconds)
{
    BOOL dequeued = GetQueuedCompletionStatus(m_hIOCP, pdwNumBytes, pCompKey, ppo, dwMilliseconds);
    return dequeued;
}
// Binds the receive functor to a session's overlapped block, its socket, its
// zombie flag and the server-wide lock that guards that flag.
// NOTE(review): fd and zombie are taken by reference; presumably the members
// are references too, so later changes in the owning session are seen -- confirm in the header.
WsaRecvFunctor::WsaRecvFunctor(WSAOVERLAPPED *pOverlapped, const int &fd, const bool &zombie, ReadWriteMutex &rwm) :
    fd(fd),
    zombie(zombie),
    rwm(rwm),
    mpOverlapped(pOverlapped),
    mError(0)
{
}
// Posts an overlapped WSARecv() for up to bufferLen bytes into buffer, using
// mpOverlapped. The data itself is delivered later through the completion port.
// Afterwards getError() returns 0 if the receive was posted (pending or
// completed immediately) and the Winsock error code otherwise. If the session
// has already been marked zombie nothing is posted and the error stays 0.
void WsaRecvFunctor::operator()(char *buffer, std::size_t bufferLen)
{
    RCF_ASSERT(mpOverlapped);
    WSABUF wsabuf = { static_cast<u_long>(bufferLen), buffer };
    DWORD dwReceived = 0;
    DWORD dwFlags = 0;
    mError = 0;
    {
        ReadLock lock(rwm); // since we're reading zombie
        if (!zombie)
        {
            RCF_ASSERT(fd >0);
            int ret = WSARecv(fd, &wsabuf, 1, &dwReceived, &dwFlags, mpOverlapped, NULL);
            // WSAGetLastError() is only meaningful when the call reports failure;
            // after a successful call it may still hold a stale code.
            if (ret == SOCKET_ERROR)
            {
                int err = WSAGetLastError();
                if (err != WSA_IO_PENDING)
                {
                    mError = err;
                }
            }
        }
        else
        {
            RCF_TRACE("fd closed by external thread")(fd);
        }
    }
    if (mError != 0)
    {
        RCF_TRACE("")(mError);
    }
}
// Wraps operator() as a Filter::ReadFunction so it can terminate a filter
// chain. The returned function refers to this functor and must not outlive it.
Filter::ReadFunction WsaRecvFunctor::getReadFunction()
{
    Filter::ReadFunction readFunction = boost::bind( &WsaRecvFunctor::operator(), this, _1, _2);
    return readFunction;
}
// Error recorded by the most recent operator() call; 0 means no error.
int WsaRecvFunctor::getError()
{
    const int lastError = mError;
    return lastError;
}
// Binds the send functor to a session's overlapped block, its socket, its
// zombie flag and the server-wide lock that guards that flag.
// NOTE(review): fd and zombie are taken by reference; presumably the members
// are references too -- confirm in the header.
WsaSendFunctor::WsaSendFunctor(WSAOVERLAPPED *pOverlapped, const int &fd, const bool &zombie, ReadWriteMutex &rwm) :
    fd(fd),
    zombie(zombie),
    rwm(rwm),
    mError(0),
    mpOverlapped(pOverlapped)
{
}
// Posts an overlapped WSASend() of bufferLen bytes from buffer, using
// mpOverlapped. Completion is delivered later through the completion port.
// Afterwards getError() returns 0 if the send was posted (pending or completed
// immediately) and the Winsock error code otherwise. If the session has
// already been marked zombie nothing is posted and the error stays 0.
void WsaSendFunctor::operator()(const char *buffer, std::size_t bufferLen)
{
    RCF_ASSERT(mpOverlapped);
    // WSABUF::buf is non-const, but WSASend() does not write through it.
    WSABUF wsabuf = { static_cast<u_long>(bufferLen), const_cast<char *>(buffer) };
    DWORD dwSent = 0;
    DWORD dwFlags = 0;
    mError = 0;
    {
        ReadLock lock(rwm); // since we're reading zombie
        if (!zombie)
        {
            RCF_ASSERT(fd >0);
            int ret = WSASend(fd, &wsabuf, 1, &dwSent, dwFlags, mpOverlapped, NULL);
            // WSAGetLastError() is only meaningful when the call reports failure;
            // after a successful call it may still hold a stale code.
            if (ret == SOCKET_ERROR)
            {
                int err = WSAGetLastError();
                if (err != WSA_IO_PENDING)
                {
                    mError = err;
                }
            }
        }
        else
        {
            RCF_TRACE("fd closed by external thread")(fd);
        }
    }
    if (mError != 0)
    {
        RCF_TRACE("")(mError);
    }
}
// Wraps operator() as a Filter::WriteFunction so it can terminate a filter
// chain. The returned function refers to this functor and must not outlive it.
Filter::WriteFunction WsaSendFunctor::getWriteFunction()
{
    Filter::WriteFunction writeFunction = boost::bind( &WsaSendFunctor::operator(), this, _1, _2);
    return writeFunction;
}
// Error recorded by the most recent operator() call; 0 means no error.
int WsaSendFunctor::getError()
{
    const int lastError = mError;
    return lastError;
}
// Synchronization:
// SessionState::fd is immutable. Normally the socket is only physically closed (closesocket()) in the SessionState destructor, which is thread-safe
// because SessionState objects are only ever referred to through shared_ptrs. Any thread wishing to close a session must set the
// SessionState::zombie flag (externalCloseSession() closes the socket itself and sets zombie, after which the destructor skips the close).
// Eventually the SessionState will be ejected from the server by a server thread, and the destructor will close the connection.
// Ownership of the fd can be taken from the SessionState by setting ownFd to false. The server-wide read/write mutex rwm is used to
// synchronize reads and writes of ownFd and zombie for all SessionState objects.
// SessionStates with the zombie flag set cannot be summarily removed from the server's session map, since that might delete the SessionState object
// while the iocp is still polling it, causing a crash. They can only be safely removed when it is known that the iocp is not polling them.
// Because SessionStates are only handled through shared_ptrs, the destructor does not need to lock rwm when reading ownFd and zombie.
// Closes the socket if this session still owns it (ownFd) and it has not
// already been closed externally (zombie). Reads ownFd/zombie without rwm:
// as the last owner of a shared_ptr-managed object, no other thread can reach it.
TcpIocpServerTransport::SessionState::~SessionState()
{
    RCF_TRACE("")(sessionPtr.get())(ownFd)(zombie)(fd);
    const bool mustClose = ownFd && !zombie && fd != -1;
    if (mustClose)
    {
        RCF_VERIFY_SOCKETS(0 == Platform::OS::BsdSockets::closesocket(fd), "closesocket() failure")(fd);
    }
    zombie = true;
}
// Installs the given transport filters for this session and wires the chain
// between the raw WSARecv/WSASend functors (bottom) and the transport's
// completion handler TcpIocpServerTransport::onReadWriteCompleted (top).
// NOTE(review): the completion callback binds a strong SessionStatePtr to this
// very session (getWeakThisPtr().lock()). If connectFilters() stores that
// callback inside the filters held by mTransportFilters, the SessionState keeps
// itself alive and its destructor (which closes the socket) never runs --
// confirm, and consider binding a weak pointer instead.
void TcpIocpServerTransport::SessionState::setTransportFilters(const std::vector<FilterPtr> &filters)
{
mTransportFilters.assign(filters.begin(), filters.end());
RCF::connectFilters(
mTransportFilters,
wsaRecvFunctor.getReadFunction(),
wsaSendFunctor.getWriteFunction(),
boost::bind(&TcpIocpServerTransport::onReadWriteCompleted, &transport, getWeakThisPtr().lock(), _1, _2));
}
// Starts an asynchronous read of up to bufferLen bytes into buffer. The read
// goes through the first transport filter when filters are installed, and
// straight to the WSARecv functor otherwise. Returns -1 if the WSARecv functor
// reported an error, 0 otherwise; the data itself arrives via the completion port.
int TcpIocpServerTransport::SessionState::read(char *buffer, std::size_t bufferLen)
{
    if (mTransportFilters.empty())
    {
        wsaRecvFunctor(buffer, bufferLen);
    }
    else
    {
        mTransportFilters.front()->read(buffer, bufferLen);
    }
    return (wsaRecvFunctor.getError() == 0) ? 0 : -1;
}
// Starts an asynchronous write of bufferLen bytes from buffer. The write goes
// through the first transport filter when filters are installed, and straight
// to the WSASend functor otherwise. Returns -1 if the WSASend functor reported
// an error, 0 otherwise; completion arrives via the completion port.
int TcpIocpServerTransport::SessionState::write(char *buffer, std::size_t bufferLen)
{
    if (mTransportFilters.empty())
    {
        wsaSendFunctor(buffer, bufferLen);
    }
    else
    {
        mTransportFilters.front()->write(buffer, bufferLen);
    }
    return (wsaSendFunctor.getError() == 0) ? 0 : -1;
}
void TcpIocpServerTransport::SessionState::onReadWriteCompleted(std::size_t bytesTransferred, int error)
{
mTransportFilters.empty() ?
transport.onReadWriteCompleted(this->getWeakThisPtr().lock(), bytesTransferred, error):
mTransportFilters.back()->onReadWriteCompleted(bytesTransferred, error);
}
void TcpIocpServerTransport::SessionState::clearOverlapped()
{
memset(static_cast<OVERLAPPED *>(this), 0, sizeof(OVERLAPPED));
}
// True once a peer fd has been recorded via setReflectionFd() (see reflect()).
bool TcpIocpServerTransport::SessionState::isReflecting()
{
    const bool paired = (reflectionFd != 0);
    return paired;
}
// Address of the connected peer, as recorded by setRemoteAddress().
const I_RemoteAddress &TcpIocpServerTransport::SessionState::getRemoteAddress()
{
    const I_RemoteAddress &peer = remoteAddress;
    return peer;
}
// The socket this session was created for.
TcpIocpServerTransport::Fd TcpIocpServerTransport::SessionState::getFd() const
{
    const Fd socketFd = fd;
    return socketFd;
}
// Creates the proactor that bridges an RCF session to its SessionState.
// The session state is only observed (the accessors below call lock() on it),
// so the proactor does not keep the SessionState alive.
TcpIocpServerTransport::TcpIocpProactor::TcpIocpProactor(TcpIocpServerTransport &transport, boost::shared_ptr<SessionState> sessionStatePtr) :
    transport(transport),
    sessionStatePtr(sessionStatePtr)
{
}
void TcpIocpServerTransport::TcpIocpProactor::postRead()
{
transport.postRead(sessionStatePtr.lock());
}
void TcpIocpServerTransport::TcpIocpProactor::postWrite()
{
transport.postWrite(sessionStatePtr.lock());
}
// Close requests are ignored by this proactor; the body is intentionally empty.
void TcpIocpServerTransport::TcpIocpProactor::postClose()
{
}
std::vector<char> &TcpIocpServerTransport::TcpIocpProactor::getWriteBuffer()
{
return sessionStatePtr.lock()->getWriteBuffer();
}
// Payload starts 4 bytes into the write buffer: TcpIocpServerTransport::postWrite()
// stores the length prefix in the first 4 bytes.
std::size_t TcpIocpServerTransport::TcpIocpProactor::getWriteOffset()
{
    const std::size_t lengthPrefixSize = 4;
    return lengthPrefixSize;
}
std::vector<char> &TcpIocpServerTransport::TcpIocpProactor::getReadBuffer()
{
return sessionStatePtr.lock()->getReadBuffer();
}
// Payload starts at the beginning of the read buffer (the length prefix has
// already been consumed by the time the message body is read).
std::size_t TcpIocpServerTransport::TcpIocpProactor::getReadOffset()
{
    const std::size_t noOffset = 0;
    return noOffset;
}
// The transport that owns this proactor's session.
I_ServerTransport &TcpIocpServerTransport::TcpIocpProactor::getServerTransport()
{
    I_ServerTransport &owner = transport;
    return owner;
}
// Reference to the session state. The weak reference must still be alive;
// an expired one is dereferenced as null.
TcpIocpServerTransport::SessionState &TcpIocpServerTransport::TcpIocpProactor::getSessionState()
{
    SessionStatePtr sessionState = getSessionStatePtr();
    return *sessionState;
}
// Strong pointer to the session state, or an empty pointer if it has expired.
TcpIocpServerTransport::SessionStatePtr TcpIocpServerTransport::TcpIocpProactor::getSessionStatePtr()
{
    SessionStatePtr sessionState = sessionStatePtr.lock();
    return sessionState;
}
// Peer address of the session. The session state must still be alive.
const I_RemoteAddress &TcpIocpServerTransport::TcpIocpProactor::getRemoteAddress()
{
    SessionStatePtr sessionState = sessionStatePtr.lock();
    return sessionState->getRemoteAddress();
}
void TcpIocpServerTransport::TcpIocpProactor::setTransportFilters(const std::vector<FilterPtr> &filters)
{
return sessionStatePtr.lock()->setTransportFilters(filters);
}
// Constructs a transport that will listen on the given port once open() runs.
// No socket or completion port is created here. Defaults:
//  - maxPendingConnectionCount 100: passed to listen() as the backlog;
//  - fdPartitionCount 10: number of mutex-protected session-map partitions;
//  - mQueuedAcceptsThreshold/mQueuedAcceptsAugment 10/10: whenever fewer than
//    10 AcceptEx() calls are outstanding, generateAccepts() posts 10 more.
// The initializer order below must match the member declaration order in the header.
TcpIocpServerTransport::TcpIocpServerTransport(int port) :
rwm(ReaderPriority),
pSessionManager(),
maxPendingConnectionCount(100),
fdPartitionCount(10),
port(port),
mStopFlag(),
mOpen(),
acceptorFd(-1),
iocp(),
mQueuedAccepts(0),
mQueuedAcceptsThreshold(10),
mQueuedAcceptsAugment(10),
mlpfnAcceptEx(),
mlpfnGetAcceptExSockAddrs()
{}
// Sets the port that the next open() will bind to.
void TcpIocpServerTransport::setPort(int port)
{
    TcpIocpServerTransport::port = port;
}
// The configured listening port.
int TcpIocpServerTransport::getPort()
{
    const int listeningPort = port;
    return listeningPort;
}
// Sets the listen() backlog used by the next open().
void TcpIocpServerTransport::setMaxPendingConnectionCount(unsigned int maxPendingConnectionCount)
{
    TcpIocpServerTransport::maxPendingConnectionCount = maxPendingConnectionCount;
}
// The configured listen() backlog.
unsigned int TcpIocpServerTransport::getMaxPendingConnectionCount()
{
    const unsigned int backlog = maxPendingConnectionCount;
    return backlog;
}
void TcpIocpServerTransport::open()
{
RCF_ASSERT(iocp.get() == NULL);
RCF_ASSERT(acceptorFd == -1);
RCF_ASSERT(sessionMaps.empty());
RCF_ASSERT(port > 0);
RCF_ASSERT(mQueuedAccepts == 0);
// setup synchronized session maps
for (unsigned int i=0; i<fdPartitionCount; i++)
{
sessionMaps.push_back( std::make_pair(MutexPtr(new Mutex), SessionStateMap()) );
}
// create listener socket
int ret = 0;
int err = 0;
acceptorFd = static_cast<int>(socket(PF_INET, SOCK_STREAM, IPPROTO_TCP));
if (acceptorFd == -1)
{
err = Platform::OS::BsdSockets::GetLastError();
RCF_THROW(ServerTransportException, "socket() failed")(acceptorFd)(err)(Platform::OS::GetErrorString(err));
}
// bind listener socket
std::string networkInterface = getNetworkInterface();
unsigned long ul_addr = inet_addr( networkInterface.c_str() );
if (ul_addr == INADDR_NONE)
{
hostent *hostDesc = gethostbyname(networkInterface.c_str());
if (hostDesc)
{
char *szIp = ::inet_ntoa( * (in_addr*) hostDesc->h_addr_list[0]);
ul_addr = ::inet_addr(szIp);
}
}
sockaddr_in serverAddr;
memset(&serverAddr, 0, sizeof(serverAddr));
serverAddr.sin_family = AF_INET;
serverAddr.sin_addr.s_addr = ul_addr;
serverAddr.sin_port = htons(port);
ret = bind(acceptorFd, (struct sockaddr*) &serverAddr, sizeof(serverAddr));
if (ret < 0)
{
err = Platform::OS::BsdSockets::GetLastError();
RCF_THROW(ServerTransportException, "bind() failed")(acceptorFd)(port)(networkInterface)(ret)(err)(Platform::OS::GetErrorString(err));
}
// listen on listener socket
ret = listen(acceptorFd, maxPendingConnectionCount);
if (ret < 0)
{
err = Platform::OS::BsdSockets::GetLastError();
RCF_THROW(ServerTransportException, "bind() failed")(acceptorFd)(ret)(err)(Platform::OS::GetErrorString(err));
}
RCF_ASSERT( acceptorFd != -1 )(acceptorFd);
// create io completion port and associate the listener socket
iocp.reset( new Iocp );
iocp->Create();
iocp->AssociateDevice( (HANDLE) acceptorFd, (ULONG_PTR) acceptorFd);
// load AcceptEx() function
GUID GuidAcceptEx = WSAID_ACCEPTEX;
DWORD dwBytes;
RCF_VERIFY_SOCKETS(
0 == WSAIoctl(
acceptorFd,
SIO_GET_EXTENSION_FUNCTION_POINTER,
&GuidAcceptEx,
sizeof(GuidAcceptEx),
&mlpfnAcceptEx,
sizeof(mlpfnAcceptEx),
&dwBytes,
NULL,
NULL),
"WSAIoctl()");
// load GetAcceptExSockAddrs() function
GUID GuidGetAcceptExSockAddrs = WSAID_GETACCEPTEXSOCKADDRS;
RCF_VERIFY_SOCKETS(
0 == WSAIoctl(
acceptorFd,
SIO_GET_EXTENSION_FUNCTION_POINTER,
&GuidGetAcceptExSockAddrs,
sizeof(GuidGetAcceptExSockAddrs),
&mlpfnGetAcceptExSockAddrs,
sizeof(mlpfnGetAcceptExSockAddrs),
&dwBytes,
NULL,
NULL),
"WsaIoctl()");
}
// Undoes open(): releases the completion port, closes the listening socket,
// resets the queued-accept counter and drops the transport's references to
// every tracked session. A SessionState whose last reference goes away closes
// its socket if it still owns it.
// NOTE(review): sessions may still have AcceptEx/WSARecv/WSASend operations
// outstanding on their embedded OVERLAPPED when the maps are cleared; confirm
// the kernel cannot write to that memory after it has been freed.
void TcpIocpServerTransport::close()
{
    iocp.reset();

    const bool haveListener = (acceptorFd != -1);
    if (haveListener)
    {
        RCF_VERIFY_SOCKETS(0 == closesocket(acceptorFd), "closesocket() failure")(acceptorFd);
        acceptorFd = -1;
    }

    mQueuedAccepts = 0;
    sessionMaps.clear();
}
// Index of the session-map partition responsible for fd.
TcpIocpServerTransport::Fd TcpIocpServerTransport::hash(Fd fd)
{
    const Fd partition = fd % fdPartitionCount;
    return partition;
}
// Creates a SessionState for fd, plus the proactor and RCF session that go
// with it, and links them together. It touches no shared transport state, so
// no locking is needed. The session is not yet registered in the session map
// (see monitorSession()).
TcpIocpServerTransport::SessionStatePtr TcpIocpServerTransport::createSession(int fd)
{
    SessionStatePtr sessionStatePtr( new SessionState(*this, fd) );
    ProactorPtr proactorPtr( new TcpIocpProactor(*this, sessionStatePtr) );

    SessionPtr sessionPtr = getSessionManager().createSession();
    sessionPtr->setProactorPtr(proactorPtr);
    sessionStatePtr->setSessionPtr(sessionPtr);

    SessionStateWeakPtr weakSelf(sessionStatePtr);
    sessionStatePtr->setWeakThisPtr(weakSelf);
    return sessionStatePtr;
}
// Registers sessionStatePtr in the session map under its fd.
// Returns false if the slot is occupied by a zombie session (the old socket is
// still being torn down). Returns true if the session is now, or already was,
// monitored. Lock order: rwm (read, for the zombie flag), then the partition mutex.
bool TcpIocpServerTransport::monitorSession(SessionStatePtr sessionStatePtr)
{
    ReadLock readLock(rwm); // since we're reading sessionStateMap[fd]->zombie
    int fd = sessionStatePtr->getFd();
    Fd partition = hash(fd);
    Lock lock(*sessionMaps[partition].first); RCF_UNUSED_VARIABLE(lock);
    SessionStatePtr &slot = sessionMaps[partition].second[fd];

    if (slot.get())
    {
        if (slot->zombie)
        {
            return false;
        }
        if (slot == sessionStatePtr)
        {
            return true; // we're already monitoring this session
        }
    }

    RCF_ASSERT(slot.get() == NULL);
    slot = sessionStatePtr;
    return true;
}
// Stops tracking sessionStatePtr in the session map. The slot is only cleared
// if it still holds this very session, so a different session that has since
// been registered under the same fd value is left alone. Always returns true.
bool TcpIocpServerTransport::unmonitorSession(SessionStatePtr sessionStatePtr)
{
    int fd = sessionStatePtr->getFd();
    Mutex &sessionMapMutex = *sessionMaps[ hash(fd) ].first;
    Lock lock(sessionMapMutex); RCF_UNUSED_VARIABLE(lock);
    SessionStatePtr &slot = sessionMaps[ hash(fd) ].second[fd];
    if (slot.get() && slot == sessionStatePtr)
    {
        slot.reset();
    }
    return true;
}
// Removes the session registered under sessionStatePtr's fd from the session
// map, so the transport drops its reference. The socket is closed by the
// SessionState destructor once the last reference goes away.
// If the removed session was still Accepting, its queued-accept slot is
// released and the accept thread is woken when below threshold.
// If sessionStatePtr was reflecting, its peer session is closed too.
// The peer is closed only after the partition mutex has been released:
// externalCloseSession() takes rwm for writing, while monitorSession() takes
// rwm before the partition mutex, so doing it under the mutex could deadlock.
void TcpIocpServerTransport::closeSession(SessionStatePtr sessionStatePtr)
{
    RCF_TRACE("")(sessionStatePtr->getSessionPtr().get());
    int fd = sessionStatePtr->getFd();
    SessionStatePtr reflectionSessionStatePtr;
    bool removed = false;
    {
        Mutex &sessionMapMutex = *sessionMaps[ hash(fd) ].first;
        Lock lock(sessionMapMutex); RCF_UNUSED_VARIABLE(lock);
        SessionStateMap &sessionStateMap = sessionMaps[ hash(fd) ].second;
        if (sessionStateMap[fd].get())
        {
            if (sessionStateMap[fd]->getState() == SessionState::Accepting)
            {
                InterlockedDecrement( (LONG *) &mQueuedAccepts);
                if (mQueuedAccepts < mQueuedAcceptsThreshold)
                {
                    mQueuedAcceptsCondition.notify_one();
                }
            }
            sessionStateMap[fd] = SessionStatePtr();
            removed = true;
            // NB: socket isn't closed until the SessionState destructor is executed
            if (sessionStatePtr->isReflecting())
            {
                reflectionSessionStatePtr = sessionStatePtr->getReflectionSessionStateWeakPtr().lock();
            }
        }
    }

    if (!removed)
    {
        RCF_TRACE("server transport cannot close session - foreign fd");
    }
    else if (reflectionSessionStatePtr.get())
    {
        externalCloseSession(reflectionSessionStatePtr);
    }
}
// Advances the per-session state machine after an I/O completion, or after an
// explicit kick from postRead()/postWrite():
//  - Accepting: an AcceptEx() completed. Records the local and remote
//    addresses and rejects disallowed clients. Otherwise it associates the
//    accepted socket with the completion port and fakes a write completion, so
//    the session manager takes over.
//  - ReadingDataCount: reads the 4-byte length prefix (network byte order).
//    Once complete, it sizes the read buffer and moves to ReadingData. A length
//    above getMaxMessageLength() closes the session.
//  - ReadingData: reads the message body. Once complete, the session becomes
//    Ready and the session manager is notified.
//  - WritingData: sends the whole write buffer. Once complete, the session
//    becomes Ready and the session manager is notified.
// Any I/O error closes the session.
// Invariant shared with closeSession(): a session counts toward mQueuedAccepts
// exactly while it is in the Accepting state. Failure paths therefore leave the
// decrement to closeSession(); otherwise it would be decremented twice.
void TcpIocpServerTransport::transition(SessionStatePtr sessionStatePtr)
{
    std::vector<char> &readBuffer = sessionStatePtr->getReadBuffer();
    std::size_t readBufferRemaining = sessionStatePtr->getReadBufferRemaining();
    std::vector<char> &writeBuffer = sessionStatePtr->getWriteBuffer();
    std::size_t writeBufferRemaining = sessionStatePtr->getWriteBufferRemaining();

    switch(sessionStatePtr->getState())
    {
    case SessionState::Accepting:
        {
            // parse the local and remote address info
            SOCKADDR *pLocalAddr = NULL;
            SOCKADDR *pRemoteAddr = NULL;
            int localAddrLen = 0;
            int remoteAddrLen = 0;
            mlpfnGetAcceptExSockAddrs(
                &readBuffer[0],
                0,
                sizeof(sockaddr_in) + 16,
                sizeof(sockaddr_in) + 16,
                &pLocalAddr,
                &localAddrLen,
                &pRemoteAddr,
                &remoteAddrLen);
            sockaddr_in *pLocalSockAddr = reinterpret_cast<sockaddr_in *>(pLocalAddr);
            sessionStatePtr->setLocalAddress( IpAddress(*pLocalSockAddr) );
            sockaddr_in *pRemoteSockAddr = reinterpret_cast<sockaddr_in *>(pRemoteAddr);
            sessionStatePtr->setRemoteAddress( IpAddress(*pRemoteSockAddr) );

            // is this ip allowed?
            if (!isClientAddrAllowed(sessionStatePtr->getRemoteSockAddr()))
            {
                // still Accepting: closeSession() releases the queued-accept slot
                closeSession(sessionStatePtr);
                break;
            }

            // associate fd with iocp, unless the session was closed externally
            bool isZombie = false;
            {
                ReadLock lock(rwm); // since we're reading sessionStatePtr->zombie
                isZombie = sessionStatePtr->zombie;
                if (!isZombie)
                {
                    int fd = sessionStatePtr->getFd();
                    RCF_VERIFY(1 == iocp->AssociateSocket(fd, fd), "AssociateSocket() failed")(fd);
                }
            }
            if (isZombie)
            {
                // still Accepting: closeSession() releases the queued-accept slot.
                // Called without rwm held, since closeSession() may need it for writing.
                closeSession(sessionStatePtr);
                return;
            }

            // the accept is complete: release its queued-accept slot
            InterlockedDecrement( (LONG *) &mQueuedAccepts);
            if (mQueuedAccepts < mQueuedAcceptsThreshold)
            {
                mQueuedAcceptsCondition.notify_one();
            }

            // fake a write completion to get things moving
            sessionStatePtr->setState(SessionState::WritingData);
            sessionStatePtr->setWriteBufferRemaining(0);
            transition(sessionStatePtr);
        }
        break;

    case SessionState::ReadingDataCount:
        {
            RCF_ASSERT(readBufferRemaining <= 4);
            if (readBufferRemaining == 0)
            {
                // memcpy avoids reading the char buffer through an unsigned int lvalue
                unsigned int packetLength = 0;
                memcpy(&packetLength, &readBuffer[0], sizeof(packetLength));
                networkToMachineOrder(&packetLength, 4, 1);
                if (packetLength <= getMaxMessageLength())
                {
                    readBuffer.resize(packetLength);
                    sessionStatePtr->setReadBufferRemaining(packetLength);
                    sessionStatePtr->setState( SessionState::ReadingData );
                    transition(sessionStatePtr);
                }
                else
                {
                    closeSession(sessionStatePtr);
                }
            }
            else if (readBufferRemaining <= 4)
            {
                // read the (rest of the) length prefix
                char *readPos = (readBufferRemaining == 4) ?
                    &readBuffer[0] :
                    &readBuffer[readBuffer.size() - readBufferRemaining];
                int ret = sessionStatePtr->read(readPos, readBufferRemaining);
                if (ret == -1)
                {
                    closeSession(sessionStatePtr);
                }
            }
        }
        break;

    case SessionState::ReadingData:
        {
            if (readBufferRemaining == 0)
            {
                sessionStatePtr->setState( SessionState::Ready );
                getSessionManager().onReadCompleted( sessionStatePtr->getSessionPtr() );
            }
            else
            {
                RCF_ASSERT( readBufferRemaining <= readBuffer.size() );
                char *readPos = & readBuffer[readBuffer.size() - readBufferRemaining];
                int ret = sessionStatePtr->read(readPos, readBufferRemaining);
                if (ret == -1)
                {
                    closeSession(sessionStatePtr);
                }
            }
        }
        break;

    case SessionState::WritingData:
        {
            RCF_ASSERT(writeBufferRemaining == 0 || writeBufferRemaining == writeBuffer.size());
            if (writeBufferRemaining == 0)
            {
                sessionStatePtr->setState( SessionState::Ready );
                getSessionManager().onWriteCompleted( sessionStatePtr->getSessionPtr() );
                return;
            }
            if (writeBufferRemaining == writeBuffer.size())
            {
                sessionStatePtr->clearOverlapped();
                int ret = sessionStatePtr->write(&writeBuffer[0], writeBufferRemaining);
                if (ret == -1)
                {
                    closeSession(sessionStatePtr);
                }
            }
        }
        break;

    default:
        RCF_ASSERT(0)(sessionStatePtr->getState())(sessionStatePtr->getFd());
        break;
    }
}
// Pumps data for one half of a reflecting session pair. Each completion
// toggles the session between two phases, reusing its embedded OVERLAPPED:
//  - a send just completed (WritingData): post a receive of up to 1024 bytes
//    on this session's own socket and switch to ReadingData;
//  - a receive just completed (ReadingData / ReadingDataCount): send the
//    bytesRead bytes just received to the peer socket (reflectionFd) and
//    switch to WritingData.
// If posting fails with anything other than WSA_IO_PENDING, the session is
// closed. completionKey is unused.
void TcpIocpServerTransport::reflectSession(SessionStatePtr sessionStatePtr, DWORD bytesRead, ULONG_PTR completionKey)
{
    RCF_UNUSED_VARIABLE(completionKey);
    int fd = sessionStatePtr->getFd();
    int reflectionFd = sessionStatePtr->getReflectionFd();
    SessionState::State state = sessionStatePtr->getState();
    RCF_ASSERT(state == SessionState::ReadingData || state == SessionState::ReadingDataCount || state == SessionState::WritingData);
    RCF_ASSERT(fd > 0 && reflectionFd > 0);

    const bool sendCompleted = (state == SessionState::WritingData);
    const bool recvCompleted = (state == SessionState::ReadingData || state == SessionState::ReadingDataCount);
    if (!sendCompleted && !recvCompleted)
    {
        return;
    }

    OVERLAPPED *pOverlapped = sessionStatePtr.get();
    WSAOVERLAPPED *pWsaOverlapped = reinterpret_cast<WSAOVERLAPPED *>(pOverlapped);
    std::vector<char> &buffer = sessionStatePtr->getReadBuffer();

    int ret = 0;
    if (sendCompleted)
    {
        sessionStatePtr->setState( SessionState::ReadingData );
        buffer.resize(1024); // decent STL implementation will only allocate memory once
        WSABUF wsabuf = { static_cast<u_long>(buffer.size()), &buffer[0] };
        DWORD dwReceived = 0;
        DWORD dwFlags = 0;
        ret = WSARecv(fd, &wsabuf, 1, &dwReceived, &dwFlags, pWsaOverlapped, NULL);
    }
    else
    {
        sessionStatePtr->setState( SessionState::WritingData );
        WSABUF wsabuf = { bytesRead, &buffer[0] };
        DWORD dwSent = 0;
        DWORD dwFlags = 0;
        ret = WSASend(reflectionFd, &wsabuf, 1, &dwSent, dwFlags, pWsaOverlapped, NULL);
    }

    int err = WSAGetLastError();
    if (ret == -1 && err != WSA_IO_PENDING)
    {
        RCF2_TRACE("")(ret)(err);
        closeSession(sessionStatePtr);
    }
}
// Accept-thread task. With timeoutMs == 0 it tops up the pool of pending
// AcceptEx() calls immediately. Otherwise it blocks on mQueuedAcceptsCondition
// and then tops up, unless a stop was requested in the meantime. The wait has
// no timeout; the value of timeoutMs is only compared with 0.
// Returns true when the task should stop.
bool TcpIocpServerTransport::cycleAccepts(int timeoutMs, const volatile bool &stopFlag)
{
    if (timeoutMs == 0)
    {
        generateAccepts();
        return stopFlag || mStopFlag;
    }

    {
        Lock lock(mQueuedAcceptsMutex);
        if (stopFlag || mStopFlag)
        {
            // fall through to the final check below
        }
        else
        {
            mQueuedAcceptsCondition.wait(lock);
            if (stopFlag || mStopFlag)
            {
                return true;
            }
            generateAccepts();
        }
    }
    return stopFlag || mStopFlag;
}
// Stop functor for the accept task: raises mStopFlag, then wakes the accept
// thread. The notify happens under the mutex so that a thread about to wait
// cannot miss it.
void TcpIocpServerTransport::stopAccepts()
{
    mStopFlag = true;
    Lock wakeLock(mQueuedAcceptsMutex);
    mQueuedAcceptsCondition.notify_one();
}
void TcpIocpServerTransport::generateAccepts()
{
if (mQueuedAccepts < mQueuedAcceptsThreshold)
{
for (unsigned int i=0; i<mQueuedAcceptsAugment; i++)
{
Fd fd = static_cast<Fd>(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
RCF_VERIFY_SOCKETS(fd != -1, "socket() failed");
SessionStatePtr sessionStatePtr = createSession(fd);
if (monitorSession(sessionStatePtr))
{
sessionStatePtr->getReadBuffer().resize(2*(sizeof(sockaddr_in) + 16));
sessionStatePtr->setReadBufferRemaining( 2*(sizeof(sockaddr_in) + 16) );
DWORD dwBytes = 0;
for (unsigned int i=0; i<2*(sizeof(sockaddr_in) + 16); i++)
{
sessionStatePtr->getReadBuffer()[i] = 0;
}
sessionStatePtr->clearOverlapped();
BOOL ret = mlpfnAcceptEx(
acceptorFd,
fd,
&sessionStatePtr->getReadBuffer()[0],
0,
sizeof(sockaddr_in) + 16,
sizeof(sockaddr_in) + 16,
&dwBytes,
static_cast<OVERLAPPED *>(sessionStatePtr.get()));
int err = WSAGetLastError();
if (ret == FALSE && err == ERROR_IO_PENDING)
{
// async accept initiated successfully
}
else if (dwBytes > 0)
{
RCF_ASSERT(0);
transition(sessionStatePtr);
}
else
{
RCF_THROW(ServerTransportException, "AcceptEx failed")(err);
}
BOOST_STATIC_ASSERT( sizeof(LONG) == sizeof(mQueuedAccepts) );
InterlockedIncrement( (LONG *) &mQueuedAccepts);
}
else
{
RCF2_TRACE("");
}
}
}
}
// Processes at most one completion packet, waiting up to timeoutMs
// (negative means wait forever). First wakes the accept thread whenever fewer
// than mQueuedAcceptsThreshold accepts are outstanding. Then dispatches:
//  - failed/aborted operations, sessions marked zombie and zero-byte
//    completions (peer closed) -> closeSession();
//  - completions keyed by the listening socket -> onAcceptCompleted();
//  - reflecting sessions -> reflectSession();
//  - everything else -> SessionState::onReadWriteCompleted().
// stopFlag is unused.
void TcpIocpServerTransport::cycle(int timeoutMs, const volatile bool &stopFlag)
{
    RCF_UNUSED_VARIABLE(stopFlag);
    if (mQueuedAccepts < mQueuedAcceptsThreshold)
    {
        mQueuedAcceptsCondition.notify_one();
    }

    // extract a completed io operation from the iocp
    DWORD dwMilliseconds = timeoutMs < 0 ? INFINITE : timeoutMs;
    DWORD dwNumBytes = 0;
    ULONG_PTR completionKey = 0;
    OVERLAPPED *pOverlapped = 0;
    BOOL ret = iocp->GetStatus(&completionKey, &dwNumBytes, &pOverlapped, dwMilliseconds);
    DWORD dwErr = GetLastError();

    if (pOverlapped == NULL)
    {
        if (dwErr != WAIT_TIMEOUT)
        {
            // Bad call to GetQueuedCompletionStatus, dwErr contains the reason for the bad call
            RCF_ASSERT(0)(dwErr);
        }
        return;
    }

    // NOTE(review): the OVERLAPPED is dereferenced before we know the session is
    // still alive; if its SessionState was already destroyed this reads freed memory.
    SessionState *pSessionState = static_cast<SessionState *>(pOverlapped);
    int fd = pSessionState->getFd(); // if we're completing an AcceptEx call, this won't be the same as completion key
    SessionStatePtr sessionStatePtr = getSessionStatePtr(fd);
    if (sessionStatePtr.get() == NULL)
    {
        // somebody removed the session from the server transport
        return;
    }

    if (ret == 0)
    {
        // the operation failed (e.g. WSA_OPERATION_ABORTED after the socket was closed)
        closeSession(sessionStatePtr);
        return;
    }

    bool isZombie = false;
    {
        ReadLock lock(rwm); // since we're reading sessionStatePtr->zombie
        isZombie = sessionStatePtr->zombie;
    }

    if (isZombie)
    {
        // Must run without rwm held: for a reflecting session, closeSession()
        // calls externalCloseSession(), which takes rwm for writing.
        closeSession(sessionStatePtr);
    }
    else if (completionKey == acceptorFd)
    {
        // accept completed
        onAcceptCompleted(sessionStatePtr);
    }
    else if (dwNumBytes == 0)
    {
        // socket was closed, probably by the remote host
        closeSession(sessionStatePtr);
    }
    else if (sessionStatePtr->isReflecting())
    {
        // read or write completed on a reflection pair
        reflectSession(sessionStatePtr, dwNumBytes, completionKey);
    }
    else
    {
        // read or write completed
        int bytesRead = dwNumBytes;
        sessionStatePtr->onReadWriteCompleted(bytesRead, 0);
    }
}
// Completion handler for reads and writes on a non-reflecting session. It is
// called directly from SessionState, or from the end of the transport filter
// chain. A non-zero error, or a read byte count larger than what was
// outstanding, closes the session. Previously the error was ignored, and an
// oversized count made the unsigned remaining counter wrap. Otherwise the
// outstanding count is updated and transition() advances the session.
void TcpIocpServerTransport::onReadWriteCompleted(SessionStatePtr sessionStatePtr, std::size_t bytesTransferred, int error)
{
    if (error != 0)
    {
        RCF_TRACE("")(error);
        closeSession(sessionStatePtr);
        return;
    }

    SessionState::State state = sessionStatePtr->getState();
    if (state == SessionState::ReadingData || state == SessionState::ReadingDataCount)
    {
        std::size_t remaining = sessionStatePtr->getReadBufferRemaining();
        if (bytesTransferred > remaining)
        {
            RCF_TRACE("")(bytesTransferred)(remaining);
            closeSession(sessionStatePtr);
            return;
        }
        sessionStatePtr->setReadBufferRemaining(remaining - bytesTransferred);
        transition(sessionStatePtr);
    }
    else if (state == SessionState::WritingData)
    {
        // TODO: handle partially completed writes
        RCF_ASSERT(bytesTransferred == sessionStatePtr->getWriteBufferRemaining());
        sessionStatePtr->setWriteBufferRemaining(0);
        transition(sessionStatePtr);
    }
}
// An AcceptEx() completed; the Accepting branch of transition() takes it from here.
void TcpIocpServerTransport::onAcceptCompleted(SessionStatePtr sessionStatePtr)
{
    SessionStatePtr acceptedSession = sessionStatePtr;
    transition(acceptedSession);
}
// Starts sending the session's write buffer. The first 4 bytes of the buffer
// are reserved; they are filled here with the payload length in network byte
// order. This is the same length prefix that transition() parses in the
// ReadingDataCount state. The session must be Ready.
void TcpIocpServerTransport::postWrite(SessionStatePtr sessionStatePtr)
{
    BOOST_STATIC_ASSERT( sizeof(unsigned int) == 4 );
    RCF_ASSERT(sessionStatePtr->getState() == SessionState::Ready);
    std::vector<char> &writeBuffer = sessionStatePtr->getWriteBuffer();
    RCF_ASSERT(writeBuffer.size() > 4);

    sessionStatePtr->setState( SessionState::WritingData );
    sessionStatePtr->setWriteBufferRemaining( static_cast<unsigned int>(writeBuffer.size()) );

    // memcpy avoids writing the char buffer through an unsigned int lvalue
    unsigned int payloadLength = static_cast<unsigned int>(writeBuffer.size()-4);
    memcpy(&writeBuffer[0], &payloadLength, sizeof(payloadLength));
    RCF::machineToNetworkOrder(&writeBuffer[0], 4, 1);

    transition(sessionStatePtr);
}
// Starts reading the next message: first its 4-byte length prefix (the
// ReadingDataCount state), then transition() continues with the body.
// The session must be Ready.
void TcpIocpServerTransport::postRead(SessionStatePtr sessionStatePtr)
{
    RCF_ASSERT(sessionStatePtr->getState() == SessionState::Ready);
    const std::size_t lengthPrefixSize = 4;
    sessionStatePtr->setState(SessionState::ReadingDataCount);
    sessionStatePtr->getReadBuffer().resize(lengthPrefixSize);
    sessionStatePtr->setReadBufferRemaining(lengthPrefixSize);
    transition(sessionStatePtr);
}
// Looks up the session registered under fd, under its partition mutex.
// Returns an empty pointer if there is none. operator[] leaves an empty slot
// behind for unknown fds.
TcpIocpServerTransport::SessionStatePtr TcpIocpServerTransport::getSessionStatePtr(Fd fd)
{
    Fd partition = hash(fd);
    Lock lock( *sessionMaps[partition].first );
    SessionStatePtr found = sessionMaps[partition].second[fd];
    return found;
}
// Convenience overload: closes the session using a scratch copy of its fd,
// so the -1 written back by the two-argument overload is discarded.
void TcpIocpServerTransport::externalCloseSession(SessionStatePtr sessionStatePtr)
{
    int scratchFd = sessionStatePtr->getFd();
    externalCloseSession(sessionStatePtr, scratchFd);
}
// Closes a session's socket from outside the transport's I/O processing.
// Under rwm (write), the socket is closed immediately and the session marked
// zombie, unless that already happened. The caller's fd is set to -1 in every
// case. The SessionState stays in the session map until a completion for it
// is seen by cycle(), which then calls closeSession().
// If sessionStatePtr is empty (not expected; asserts), fd itself is closed.
// Previously this path dereferenced the null sessionStatePtr in its trace arguments.
void TcpIocpServerTransport::externalCloseSession(SessionStatePtr sessionStatePtr, int &fd)
{
    WriteLock lock(rwm); // since we're writing sessionStatePtr->zombie
    if (sessionStatePtr.get())
    {
        if (!sessionStatePtr->zombie)
        {
            RCF_VERIFY_SOCKETS(0 == closesocket(sessionStatePtr->getFd()), "closesocket() failure")(fd)(sessionStatePtr->getFd());
            sessionStatePtr->zombie = true;
        }
        fd = -1;
    }
    else
    {
        RCF_ASSERT(0);
        RCF_VERIFY_SOCKETS(0 == closesocket(fd), "closesocket() failure")(fd);
        fd = -1;
    }
}
// Non-overloaded alias of externalCloseSession(SessionStatePtr, int &), so that
// boost::bind can name it without overload ambiguity (used as a client
// transport's close functor).
void TcpIocpServerTransport::externalCloseSession0(SessionStatePtr sessionStatePtr, int &fd)
{
    int &callerFd = fd;
    externalCloseSession(sessionStatePtr, callerFd);
}
// Creates a server-aware client transport on the connection behind a server
// session, so calls can be made back over the same socket. The fd is handed
// to the client: ownFd is cleared, so the SessionState destructor no longer
// closes it. Closing the client transport goes through externalCloseSession0().
// Only the first call creates a transport; later calls return an empty auto_ptr.
ClientTransportAutoPtr TcpIocpServerTransport::createClientTransport(boost::shared_ptr<I_Session> sessionPtr)
{
    ProactorPtr proactorPtr = sessionPtr->getProactorPtr();
    SessionStatePtr sessionStatePtr = dynamic_cast<TcpIocpProactor &>(*proactorPtr).getSessionStatePtr();
    std::auto_ptr<TcpClientTransport> clientTransport;

    WriteLock lock(rwm); // since we're reading and writing sessionStatePtr->ownFd
    const bool canHandOver = sessionStatePtr->ownFd;
    if (canHandOver)
    {
        int fd = sessionStatePtr->getFd();
        sessionStatePtr->ownFd = false;
        clientTransport.reset( new TcpClientTransport(fd) );
        clientTransport->setRemoteAddr( sessionStatePtr->getRemoteSockAddr() );
        clientTransport->setCloseFunctor( boost::bind(&TcpIocpServerTransport::externalCloseSession0, this, sessionStatePtr, _1) );
    }
    return ClientTransportAutoPtr(clientTransport.release());
}
// Creates a server session on the connection of an existing client transport.
// The fd is released from the client transport and wrapped in a new session,
// which is registered and associated with the completion port (unless it is
// already a zombie). A write completion is then faked, so the session manager
// drives the session like an accepted connection. Returns an empty pointer if
// the fd slot is still held by a zombie session.
boost::shared_ptr<I_Session> TcpIocpServerTransport::createServerSession(ClientTransportAutoPtr clientTransportAutoPtr)
{
    TcpClientTransport &tcpClientTransport = dynamic_cast<TcpClientTransport &>(*clientTransportAutoPtr);
    int fd = tcpClientTransport.releaseFd();
    RCF_ASSERT(fd > 0);

    SessionStatePtr sessionStatePtr = createSession(fd);
    sessionStatePtr->setRemoteAddress( IpAddress(tcpClientTransport.getRemoteAddr()) );
    sessionStatePtr->setState(SessionState::WritingData);
    sessionStatePtr->setWriteBufferRemaining(0);

    if (!monitorSession(sessionStatePtr))
    {
        return boost::shared_ptr<I_Session>();
    }

    {
        ReadLock lock(rwm); // since we're reading sessionStatePtr->zombie
        if (!sessionStatePtr->zombie)
        {
            RCF_VERIFY_SOCKETS(1 == iocp->AssociateSocket(fd, fd), "Iocp::AssociateSocket() failed()");
        }
    }
    transition(sessionStatePtr);
    return sessionStatePtr->getSessionPtr();
}
// Starts reflecting data between two RCF sessions served by this transport:
// resolves their SessionStates and delegates to the SessionState overload.
bool TcpIocpServerTransport::reflect(boost::shared_ptr<I_Session> sessionPtr1, boost::shared_ptr<I_Session> sessionPtr2)
{
    ProactorPtr proactorPtr1 = sessionPtr1->getProactorPtr();
    ProactorPtr proactorPtr2 = sessionPtr2->getProactorPtr();
    SessionStatePtr sessionStatePtr1 = dynamic_cast<TcpIocpProactor &>(*proactorPtr1).getSessionStatePtr();
    SessionStatePtr sessionStatePtr2 = dynamic_cast<TcpIocpProactor &>(*proactorPtr2).getSessionStatePtr();
    return reflect(sessionStatePtr1, sessionStatePtr2);
}
// Pairs two sessions so that pending and future completions on either one are
// reflected to the other: each records the other's weak pointer and fd.
// Returns false if either session could not be monitored.
bool TcpIocpServerTransport::reflect(SessionStatePtr sessionStatePtr1, SessionStatePtr sessionStatePtr2)
{
    RCF_ASSERT(sessionStatePtr1.get() && sessionStatePtr2.get())(sessionStatePtr1.get())(sessionStatePtr2.get());
    if (!monitorSession(sessionStatePtr1) || !monitorSession(sessionStatePtr2))
    {
        return false;
    }

    // NOTE(review): the pair is linked under rwm (write), but isReflecting()
    // readers elsewhere do not take rwm -- confirm what synchronization is needed.
    WriteLock lock(rwm);
    sessionStatePtr1->setReflectionSessionStateWeakPtr( SessionStateWeakPtr(sessionStatePtr2) );
    sessionStatePtr1->setReflectionFd( sessionStatePtr2->getFd() );
    sessionStatePtr2->setReflectionSessionStateWeakPtr( SessionStateWeakPtr(sessionStatePtr1) );
    sessionStatePtr2->setReflectionFd( sessionStatePtr1->getFd() );
    return true;
}
// Reports whether a server session still has a usable socket: a positive fd
// that has not been closed externally (zombie).
bool TcpIocpServerTransport::isConnected(boost::shared_ptr<I_Session> sessionPtr)
{
    ProactorPtr proactorPtr = sessionPtr->getProactorPtr();
    SessionStatePtr sessionStatePtr = dynamic_cast<TcpIocpProactor &>(*proactorPtr).getSessionStatePtr();
    ReadLock lock(rwm); // since we're reading sessionStatePtr->zombie
    const bool hasFd = sessionStatePtr->getFd() > 0;
    return hasFd && !sessionStatePtr->zombie;
}
// Creates a new client transport connecting to a TCP endpoint. The endpoint
// must be a TcpEndpoint; otherwise dynamic_cast throws std::bad_cast.
ClientTransportAutoPtr TcpIocpServerTransport::createClientTransport(const I_Endpoint &endpoint)
{
    const TcpEndpoint &target = dynamic_cast<const TcpEndpoint &>(endpoint);
    return std::auto_ptr<I_ClientTransport>( new TcpClientTransport(target.getIp(), target.getPort()) );
}
// Records the session manager (not owned) that receives session events.
void TcpIocpServerTransport::setSessionManager(I_SessionManager &sessionManager)
{
    I_SessionManager *manager = &sessionManager;
    pSessionManager = manager;
}
// The session manager set by setSessionManager(); must have been set first.
I_SessionManager &TcpIocpServerTransport::getSessionManager()
{
    I_SessionManager &manager = *pSessionManager;
    return manager;
}
// Transport-thread task. Unless a stop was requested, it polls the completion
// port and then lets the server cycle its sessions, giving each half of timeoutMs.
// NOTE(review): integer halving turns timeoutMs == -1 (which cycle() treats
// as "wait forever") into 0, i.e. a non-blocking poll -- confirm intended.
// Returns true when the task should stop.
bool TcpIocpServerTransport::cycleTransportAndServer(RcfServer &server, int timeoutMs, const volatile bool &stopFlag)
{
    const bool keepRunning = !stopFlag && !mStopFlag;
    if (keepRunning)
    {
        const int halfTimeoutMs = timeoutMs/2;
        cycle(halfTimeoutMs, stopFlag);
        server.cycleSessions(halfTimeoutMs, stopFlag);
    }
    return stopFlag || mStopFlag;
}
// Hooks the transport into the server. The server becomes the session
// manager, and any previously registered tasks are replaced by two new ones:
//  1. cycleTransportAndServer(): completion-port polling plus server session
//     cycling (no stop functor);
//  2. cycleAccepts(): keeps AcceptEx() calls queued; stopped via stopAccepts().
// Finally clears mStopFlag (still under the task-entries lock) so the tasks can run.
void TcpIocpServerTransport::onServiceAdded(RcfServer &server)
{
    setSessionManager(server);

    WriteLock writeLock( getTaskEntriesMutex() );
    getTaskEntries().clear();

    TaskEntry transportTask(
        boost::bind(&TcpIocpServerTransport::cycleTransportAndServer, this, boost::ref(server), _1, _2),
        StopFunctor());
    getTaskEntries().push_back(transportTask);

    TaskEntry acceptTask(
        boost::bind(&TcpIocpServerTransport::cycleAccepts, this, _1, _2),
        boost::bind(&TcpIocpServerTransport::stopAccepts, this));
    getTaskEntries().push_back(acceptTask);

    mStopFlag = false;
}
// Nothing to undo on removal; the task entries registered by onServiceAdded()
// are left in place.
void TcpIocpServerTransport::onServiceRemoved(RcfServer &)
{
}
#ifdef _MSC_VER
#pragma warning( push )
#pragma warning( disable : 4355 ) // warning C4355: 'this' : used in base member initializer list
#endif
// Creates the per-connection state for fd. The session starts out Accepting,
// owning its fd (ownFd) and not zombie.
// The receive/send functors are handed `this` as their WSAOVERLAPPED (the
// SessionState is-an OVERLAPPED), together with this->fd, this->zombie and
// the transport-wide rwm. That use of `this` in the initializer list is why
// warning C4355 is silenced around this constructor.
TcpIocpServerTransport::SessionState::SessionState(TcpIocpServerTransport &transport, Fd fd) :
state(SessionState::Accepting),
reflectionFd(),
readBufferRemaining(),
writeBufferRemaining(),
fd(fd),
ownFd(true),
zombie(),
wsaRecvFunctor(this, this->fd, zombie, transport.rwm),
wsaSendFunctor(this, this->fd, zombie, transport.rwm),
transport(transport)
{
// blank the OVERLAPPED base so it is ready for the first overlapped call
clearOverlapped();
}
#ifdef _MSC_VER
#pragma warning( pop )
#endif
// Opens the transport when the server starts, unless it is already open.
void TcpIocpServerTransport::onServerStart(RcfServer &)
{
    if (mOpen)
    {
        return;
    }
    open();
    mOpen = true;
}
// Closes the transport when the server stops (if open) and clears mStopFlag
// so a later restart can run its tasks again.
void TcpIocpServerTransport::onServerStop(RcfServer &)
{
    if (!mOpen)
    {
        return;
    }
    close();
    mOpen = false;
    mStopFlag = false;
}
// Opens the transport when the server is opened, unless it is already open.
void TcpIocpServerTransport::onServerOpen(RcfServer &)
{
    if (mOpen)
    {
        return;
    }
    open();
    mOpen = true;
}
// Closes the transport when the server is closed, if it is open. Unlike
// onServerStop(), this leaves mStopFlag untouched.
void TcpIocpServerTransport::onServerClose(RcfServer &)
{
    if (!mOpen)
    {
        return;
    }
    close();
    mOpen = false;
}
} // namespace RCF