//*****************************************************************************
// RCF - Remote Call Framework
// Copyright (c) 2005. All rights reserved.
// Developed by Jarl Lindrud.
// Contact: jlindrud@hotmail.com .
//*****************************************************************************
#include <RCF/RcfServer.hpp>
#include <algorithm>
#include <boost/bind.hpp>
#include <boost/lexical_cast.hpp>
#include <RCF/AsyncFilter.hpp>
#include <RCF/CurrentSession.hpp>
#include <RCF/EncodeMessage.hpp>
#include <RCF/Endpoint.hpp>
#include <RCF/MethodInvocation.hpp>
#include <RCF/RcfClient.hpp>
#include <RCF/ServerTask.hpp>
#include <RCF/Service.hpp>
#include <RCF/Session.hpp>
#include <RCF/StubEntry.hpp>
#include <RCF/Token.hpp>
#include <RCF/Tools.hpp>
namespace RCF {
// Worker-thread body: calls task(timeoutMs, stopFlag) over and over until
// either stopFlag is set or the task returns true. stopFlag is re-read before
// every call, so raising it from another thread ends the loop after the
// current call returns.
void repeatTask(Task task, int timeoutMs, const volatile bool &stopFlag)
{
    RCF_TRACE("");
    for (;;)
    {
        if (stopFlag)
        {
            break;
        }
        const bool finished = task(timeoutMs, stopFlag);
        if (finished)
        {
            break;
        }
    }
    RCF_TRACE("")(stopFlag);
}
// Drives the server from the calling thread: calls server.cycle(timeoutMs)
// repeatedly until it returns true, i.e. until the server's stop flag has
// been raised.
void repeatCycleServer(RcfServer &server, int timeoutMs)
{
    RCF_TRACE("");
    bool stopRequested = false;
    do
    {
        stopRequested = server.cycle(timeoutMs);
    } while (!stopRequested);
    RCF_TRACE("");
}
// RcfServer
// Constructs a server whose single initial service is the server transport
// created by the given endpoint. The transport is registered through
// addService() as an I_Service.
RcfServer::RcfServer(const I_Endpoint &endpoint) :
    mOpened(),
    mStarted(),
    mServerThreadsStopFlag(),
    mStubMapMutex(WriterPriority),
    mServicesMutex(WriterPriority)
{
    RCF_TRACE("");
    // Take ownership of the freshly created transport, then register it.
    ServerTransportPtr transportPtr( endpoint.createServerTransport().release() );
    addService( boost::dynamic_pointer_cast<I_Service>(transportPtr) );
}
// Constructs a server with the given service as its single initial service.
RcfServer::RcfServer(ServicePtr servicePtr) :
    mOpened(),
    mStarted(),
    mServerThreadsStopFlag(),
    mStubMapMutex(WriterPriority),
    mServicesMutex(WriterPriority)
{
    RCF_TRACE("");
    ServicePtr initialServicePtr = servicePtr;
    addService(initialServicePtr);
}
// Constructs a server from an existing server transport. The transport is
// registered as the server's initial service (via its I_Service interface).
RcfServer::RcfServer(ServerTransportPtr serverTransportPtr) :
    mOpened(),
    mStarted(),
    mServerThreadsStopFlag(),
    mStubMapMutex(WriterPriority),
    mServicesMutex(WriterPriority)
{
    RCF_TRACE("");
    ServicePtr transportAsServicePtr = boost::dynamic_pointer_cast<I_Service>(serverTransportPtr);
    addService(transportAsServicePtr);
}
// Destroys the server, closing it first (close() stops the server if it is
// running). close() can run user-supplied join functors and service
// callbacks that may throw. Such exceptions are caught and traced here,
// because an exception escaping a destructor terminates the program (always
// for implicitly-noexcept destructors, and during stack unwinding in any case).
RcfServer::~RcfServer()
{
    RCF_TRACE("");
    try
    {
        close();
    }
    catch (...)
    {
        RCF_TRACE("exception thrown by close() in destructor, ignoring");
    }
}
// Registers a service with the server.
//
// The service is appended to the service table. It is also recorded in the
// stub-entry-lookup, filter-factory-lookup and server-transport tables when it
// implements the corresponding interface. A newly added service is notified
// via onServiceAdded(). If the server is currently started, the service's
// tasks are started right away.
//
// Returns true if the service was newly added. Returns false if servicePtr is
// null or the service is already registered; in that case nothing else
// happens.
bool RcfServer::addService(ServicePtr servicePtr)
{
    // Nothing to register. This also keeps the typeid(*servicePtr) trace below
    // from dereferencing a null pointer.
    if (!servicePtr)
    {
        return false;
    }
    RCF_TRACE("")(typeid(*servicePtr).name());
    bool added = false;
    {
        WriteLock writeLock(mServicesMutex);
        if (std::find(mServices.begin(), mServices.end(), servicePtr) == mServices.end())
        {
            mServices.push_back(servicePtr);
            added = true;
            StubEntryLookupProviderPtr stubEntryLookupProviderPtr = boost::dynamic_pointer_cast<I_StubEntryLookupProvider>(servicePtr);
            if (stubEntryLookupProviderPtr)
            {
                mStubEntryLookupProviders.push_back(stubEntryLookupProviderPtr);
            }
            FilterFactoryLookupProviderPtr filterFactoryLookupProviderPtr = boost::dynamic_pointer_cast<I_FilterFactoryLookupProvider>(servicePtr);
            if (filterFactoryLookupProviderPtr)
            {
                mFilterFactoryLookupProviders.push_back(filterFactoryLookupProviderPtr);
            }
            ServerTransportPtr serverTransportPtr = boost::dynamic_pointer_cast<I_ServerTransport>(servicePtr);
            if (serverTransportPtr)
            {
                mServerTransports.push_back(serverTransportPtr);
            }
        }
    }
    if (!added)
    {
        // Already registered. Starting it again would spawn a second set of
        // task threads and overwrite the handles (mThreadPtr) of the threads
        // that may already be running.
        return false;
    }
    servicePtr->onServiceAdded(*this);
    {
        Lock lock(mStartedMutex);
        if (mStarted)
        {
            startService(servicePtr);
        }
    }
    return true;
}
// Unregisters a service from the server.
//
// The service is removed from the service table. Every registration
// addService() made for it is undone: stub-entry lookup provider,
// filter-factory lookup provider and server transport. Its tasks are then
// stopped (joining their threads), and it is notified via onServiceRemoved().
//
// Returns true if the service was registered and has been removed. Returns
// false if servicePtr is null or was not registered.
bool RcfServer::removeService(ServicePtr servicePtr)
{
    if (!servicePtr)
    {
        return false;
    }
    RCF_TRACE("")(typeid(*servicePtr).name());
    bool found = false;
    {
        WriteLock writeLock(mServicesMutex);
        std::vector<ServicePtr>::iterator iter = std::find(mServices.begin(), mServices.end(), servicePtr);
        if (iter != mServices.end())
        {
            mServices.erase(iter);
            found = true;

            // Undo the side registrations made by addService(). Otherwise
            // stub lookups, filter creation and getServerTransportPtr() would
            // keep using the removed service.
            StubEntryLookupProviderPtr stubEntryLookupProviderPtr = boost::dynamic_pointer_cast<I_StubEntryLookupProvider>(servicePtr);
            if (stubEntryLookupProviderPtr)
            {
                mStubEntryLookupProviders.erase(
                    std::remove(mStubEntryLookupProviders.begin(), mStubEntryLookupProviders.end(), stubEntryLookupProviderPtr),
                    mStubEntryLookupProviders.end());
            }
            FilterFactoryLookupProviderPtr filterFactoryLookupProviderPtr = boost::dynamic_pointer_cast<I_FilterFactoryLookupProvider>(servicePtr);
            if (filterFactoryLookupProviderPtr)
            {
                mFilterFactoryLookupProviders.erase(
                    std::remove(mFilterFactoryLookupProviders.begin(), mFilterFactoryLookupProviders.end(), filterFactoryLookupProviderPtr),
                    mFilterFactoryLookupProviders.end());
            }
            ServerTransportPtr serverTransportPtr = boost::dynamic_pointer_cast<I_ServerTransport>(servicePtr);
            if (serverTransportPtr)
            {
                mServerTransports.erase(
                    std::remove(mServerTransports.begin(), mServerTransports.end(), serverTransportPtr),
                    mServerTransports.end());
            }
        }
    }
    if (!found)
    {
        return false;
    }
    // Stop the service's tasks only after mServicesMutex has been released.
    // stopService() joins worker threads, and joining them while holding the
    // write lock would deadlock if any of those threads tries to acquire it
    // (e.g. via getServerTransportPtr()).
    stopService(servicePtr);
    servicePtr->onServiceRemoved(*this);
    return true;
}
void RcfServer::open()
{
RCF_TRACE("");
Lock lock(mOpenedMutex);
if (!mOpened)
{
std::vector<ServicePtr> services;
{
ReadLock readLock(mServicesMutex);
std::copy(mServices.begin(), mServices.end(), std::back_inserter(services));
}
for (unsigned int i=0; i<services.size(); ++i)
{
services[i]->onServerOpen(*this);
}
mOpened = true;
}
}
void RcfServer::start(bool spawnThreads /*= true*/)
{
RCF_TRACE("");
Lock lock(mStartedMutex);
if (!mStarted)
{
mServerThreadsStopFlag = false;
// open the server
open();
// make a local copy of the service table
std::vector<ServicePtr> services;
{
ReadLock readLock(mServicesMutex);
std::copy(mServices.begin(), mServices.end(), std::back_inserter(services));
}
// notify all services
for (unsigned int i=0; i<services.size(); ++i)
{
services[i]->onServerStart(*this);
}
// spawn internal worker threads
if (spawnThreads)
{
for (unsigned int i=0; i<services.size(); ++i)
{
startService(services[i]);
}
}
mStarted = true;
// call the start notification callback, if there is one
invokeStartCallback();
// notify anyone who was waiting on the stop event
mStartEvent.notify_all();
}
}
// Registers a functor that stop(true) will invoke to join an externally
// managed worker thread. Empty functors are ignored.
void RcfServer::addJoinFunctor(JoinFunctor joinFunctor)
{
    if (!joinFunctor)
    {
        return;
    }
    mJoinFunctors.push_back(joinFunctor);
}
void RcfServer::startInThisThread()
{
startInThisThread(JoinFunctor());
}
// Starts the server and then runs all service tasks sequentially in the
// calling thread (via repeatCycleServer()) until the server's stop flag is
// raised. Does not return before that.
//
// joinFunctor (if non-empty) is registered so that stop(true) invokes it.
// Presumably it lets stop() wait for this thread to leave the loop; verify
// against callers.
void RcfServer::startInThisThread(JoinFunctor joinFunctor)
{
    RCF_TRACE("");
    // The tasks are run by this thread below, so no worker threads are
    // spawned. Spawning them as well would run every task concurrently from
    // two places.
    start(false);
    addJoinFunctor(joinFunctor);
    repeatCycleServer(*this, 500);
}
bool RcfServer::cycle(int timeoutMs)
{
RCF_TRACE("")(timeoutMs);
// sequentially run each task
// only first task is allowed to use the timeout
// if tasks are being dynamically added or removed, a given task might be cycled twice or not at all
unsigned int i=0;
while (true)
{
ServicePtr servicePtr;
{
ReadLock readLock(mServicesMutex);
if (i < mServices.size())
{
servicePtr = mServices[i];
}
}
if (servicePtr)
{
unsigned int j=0;
while (true)
{
Task task;
bool ok = false;
{
ReadLock readLock(servicePtr->getTaskEntriesMutex());
TaskEntries &taskEntries = servicePtr->getTaskEntries();
if (j < taskEntries.size())
{
task = taskEntries[j].mTask;
ok = true;
}
}
if (ok)
{
task(
i == 0 && j == 0 ? timeoutMs : 0,
mServerThreadsStopFlag);
++j;
}
else
{
break;
}
}
++i;
}
else
{
break;
}
}
return mServerThreadsStopFlag;
}
// Spawns one worker thread per task of the given service and stores the
// thread handle in the task entry. Each thread runs repeatTask() with a
// 1000 ms timeout. It runs until the task reports completion or the server
// stop flag is raised.
void RcfServer::startService(ServicePtr servicePtr)
{
    RCF_TRACE("")(typeid(*servicePtr));
    WriteLock tasksLock(servicePtr->getTaskEntriesMutex());
    TaskEntries &entries = servicePtr->getTaskEntries();
    for (std::size_t k = 0; k < entries.size(); ++k)
    {
        TaskEntry &entry = entries[k];
        ThreadPtr workerPtr( new Thread( boost::bind( repeatTask, entry.mTask, 1000, boost::ref(mServerThreadsStopFlag)) ) );
        entry.mThreadPtr = workerPtr;
    }
}
// Signals every task of the given service through its stop functor (if it
// has one). When wait is true, the task's worker thread (if any) is then
// joined and its handle released. When wait is false, the threads are left
// running.
void RcfServer::stopService(ServicePtr servicePtr, bool wait /*= true*/)
{
    RCF_TRACE("")(typeid(*servicePtr))(wait);
    WriteLock tasksLock(servicePtr->getTaskEntriesMutex());
    TaskEntries &entries = servicePtr->getTaskEntries();
    for (std::size_t k = 0; k < entries.size(); ++k)
    {
        TaskEntry &entry = entries[k];
        if (entry.mStopFunctor)
        {
            entry.mStopFunctor();
        }
        if (!wait || !entry.mThreadPtr.get())
        {
            continue;
        }
        entry.mThreadPtr->join();
        entry.mThreadPtr.reset();
    }
}
void RcfServer::stop(bool wait /*= true*/)
{
RCF_TRACE("")(wait);
Lock lock(mStartedMutex);
if (mStarted)
{
// set stop flag
mServerThreadsStopFlag = true;
// make a local copy of the service table
std::vector<ServicePtr> services;
{
ReadLock readLock(mServicesMutex);
std::copy(mServices.begin(), mServices.end(), std::back_inserter(services));
}
// notify and optionally join all internal worker threads
for (unsigned int i=0; i<services.size(); ++i)
{
stopService(services[i], wait);
}
if (wait)
{
// join all external worker threads
for (unsigned int i=0; i<mJoinFunctors.size(); ++i)
{
if (mJoinFunctors[i])
{
mJoinFunctors[i]();
}
}
mJoinFunctors.clear();
// notify all services
for (unsigned int i=0; i<services.size(); ++i)
{
services[i]->onServerStop(*this);
}
// clear stop flag, since all the threads have been joined
mServerThreadsStopFlag = false;
mStarted = false;
// notify anyone who was waiting on the stop event
mStopEvent.notify_all();
}
}
}
void RcfServer::close()
{
RCF_TRACE("");
Lock lock(mOpenedMutex);
if (mOpened)
{
// stop the server
stop();
std::vector<ServicePtr> services;
{
ReadLock readLock(mServicesMutex);
std::copy(mServices.begin(), mServices.end(), std::back_inserter(services));
}
for (unsigned int i=0; i<services.size(); ++i)
{
services[i]->onServerClose(*this);
}
// set status
mOpened = false;
}
}
void RcfServer::waitForStopEvent()
{
RCF_TRACE("");
Lock lock(mStartedMutex);
mStopEvent.wait(lock);
}
void RcfServer::waitForStartEvent()
{
RCF_TRACE("");
Lock lock(mStartedMutex);
mStartEvent.wait(lock);
}
// Creates a new, default-constructed Session and hands it out through the
// I_Session interface.
boost::shared_ptr<I_Session> RcfServer::createSession()
{
    RCF_TRACE("");
    boost::shared_ptr<I_Session> newSessionPtr(new Session());
    return newSessionPtr;
}
// Predicate that returns true for an empty smart pointer. It is used with
// std::find_if to detect filters that could not be created.
//
// It accepts any pointer type exposing get() (boost::shared_ptr,
// std::shared_ptr, ...). The argument is taken by const reference, so no
// shared_ptr copy is made and no reference count is touched. The call
// operator is const so the predicate can also be used through a const object.
struct SharedPtrIsNull
{
    template<typename Ptr>
    bool operator()(const Ptr &ptr) const
    {
        return ptr.get() == NULL;
    }
};
// Called when a request message has been read into the session's read
// buffer (starting at the proactor's read offset).
//
// Steps:
//   1. Decode the message.
//   2. If the message lists filter ids, pass it through the session's filter
//      chain. The chain is rebuilt only when the ids differ from the current
//      chain.
//   3. Deserialize the request into the session.
//   4. Push the session onto the calling thread's session queue, where
//      cycleSessions() hands it to handleSession().
//
// Throws ServerException if a requested filter cannot be created.
void RcfServer::onReadCompleted(boost::shared_ptr<I_Session> sessionPtr)
{
    RCF_TRACE("")(sessionPtr.get());
    Session &session = static_cast<Session &>(*sessionPtr);
    std::vector<char> &readBuffer = session.getProactorPtr()->getReadBuffer();
    std::size_t readOffset = session.getProactorPtr()->getReadOffset();

    // Validate the offset before indexing into the buffer, since
    // readBuffer[readOffset] at or past size() is undefined. The length is
    // computed in std::size_t with no narrowing cast, so neither a large
    // buffer nor an out-of-range offset can wrap it around.
    RCF_ASSERT( readOffset < readBuffer.size() );
    std::size_t dataLength = readBuffer.size() - readOffset;
    RCF_ASSERT( dataLength > 1 );
    const char *pData = &readBuffer[readOffset];

    std::string message(pData, dataLength);
    std::string filteredData;
    std::string unfilteredData;
    unsigned int unfilteredDataLen = 0;
    int protocol = 0;
    std::vector<int> filterIds;
    RCF_VERIFY(true == decodeMessage(message, filteredData, unfilteredDataLen, protocol, filterIds), "decodeMessage()");
    if (filterIds.empty())
    {
        // No filters: the decoded payload is the serialized request itself.
        session.filtered = false;
        session.in.reset(filteredData, protocol);
    }
    else
    {
        session.filtered = true;
        std::vector<FilterPtr> &filters = session.filters;
        // Reuse the session's existing chain when it already matches the
        // requested filter ids; otherwise build a fresh one.
        if (filters.size() != filterIds.size() || !std::equal(filters.begin(), filters.end(), filterIds.begin(), FilterIdComparison()))
        {
            filters.clear();
            std::transform(filterIds.begin(), filterIds.end(), std::back_inserter(filters), boost::bind( &RcfServer::createFilter, this, _1) );
            if (std::find_if(filters.begin(), filters.end(), SharedPtrIsNull()) == filters.end())
            {
                connectFilters(filters);
            }
            else
            {
                RCF_THROW(ServerException, "could not create filter"); // TODO: better not to throw exceptions here
            }
        }
        RCF_VERIFY( true == unfilterData(filteredData, unfilteredData, unfilteredDataLen, filters), "unfilterData()");
        session.in.reset(unfilteredData, protocol);
    }
    session.out.reset(protocol);
    session.in >> session.getRequest();

    // Hand the session over for processing by handleSession().
    RCF_TRACE("pushing session onto queue")(sessionPtr.get());
    getSessionQueue(session).push_back( boost::static_pointer_cast<Session>(sessionPtr) );
}
// Called once a response has been written (or directly, for one-way
// requests). Lets the session react to the completed write, then posts the
// next read on its proactor.
void RcfServer::onWriteCompleted(boost::shared_ptr<I_Session> sessionPtr)
{
    RCF_TRACE("")(sessionPtr.get());
    // The by-value parameter keeps the session alive for the duration of the call.
    Session &session = static_cast<Session &>(*sessionPtr);
    session.onWriteCompleted();
    session.getProactorPtr()->postRead();
}
// Sends the response for a processed request.
//
// One-way requests get no response: the session goes straight to
// onWriteCompleted(), which posts the next read.
//
// Otherwise:
//   - the serialized response in session.out is encoded, through the
//     session's filter chain if the request was filtered;
//   - it is placed in the proactor's write buffer, starting at the write
//     offset;
//   - a write is posted.
void RcfServer::sendSessionResponse(SessionPtr sessionPtr)
{
    RCF_TRACE("")(sessionPtr.get());
    Session &session = *sessionPtr;
    if (session.getRequest().getOneway())
    {
        onWriteCompleted(sessionPtr);
        return;
    }

    std::string message;
    std::string unfilteredData = session.out.str();
    int protocol = session.out.getSerializationProtocol();
    if (session.filtered)
    {
        const std::vector<FilterPtr> &filters = session.filters;
        RCF_VERIFY(true == encodeMessage(message, unfilteredData, protocol, filters), "encodeMessage()");
    }
    else
    {
        RCF_VERIFY(true == encodeMessage(message, unfilteredData, protocol, std::vector<FilterPtr>()), "encodeMessage()");
    }

    std::vector<char> &writeBuffer = session.getProactorPtr()->getWriteBuffer();
    std::size_t writeOffset = session.getProactorPtr()->getWriteOffset();
    writeBuffer.resize(writeOffset + message.size());
    // Copy through iterators. Taking &writeBuffer[writeOffset] would be
    // undefined for an empty message, because the index would then equal
    // the buffer size.
    std::copy(message.begin(), message.end(), writeBuffer.begin() + writeOffset);
    session.getProactorPtr()->postWrite();
}
// Posts a close request to the session's proactor.
void RcfServer::closeSession(SessionPtr sessionPtr)
{
    Session &session = *sessionPtr;
    session.getProactorPtr()->postClose();
}
// Replaces whatever is in the session's output stream with a generic
// exception response. It is used when the invoked server object threw
// something that is not derived from std::exception. The response consists
// of an exception-flagged MethodInvocationResponse, the reported type name,
// and a message.
void RcfServer::serializeSessionExceptionResponse(SessionPtr sessionPtr)
{
    RCF_TRACE("")(sessionPtr.get());
    Session &session = *sessionPtr;
    session.out.reset();
    session.out << RCF::MethodInvocationResponse(true);
    session.out << std::string("std::runtime_error");
    session.out << std::string("non-std::exception derived exception was thrown upon invoking server object");
}
// Scope guard that installs sessionPtr as the current session
// (setCurrentSession(sessionPtr)) for the guard's lifetime. On destruction,
// including during stack unwinding, it calls setCurrentSession() with no
// argument, presumably clearing the current session.
class SetCurrentSessionGuard
{
public:
    explicit SetCurrentSessionGuard(SessionPtr sessionPtr)
    {
        setCurrentSession(sessionPtr);
    }
    ~SetCurrentSessionGuard()
    {
        setCurrentSession();
    }
private:
    // Non-copyable: a copy would reset the current session a second time
    // when it is destroyed. (Declared but intentionally not defined.)
    SetCurrentSessionGuard(const SetCurrentSessionGuard &);
    SetCurrentSessionGuard &operator=(const SetCurrentSessionGuard &);
};
// Processes one request that onReadCompleted() has already deserialized into
// the session. It invokes the target server stub, then either sends a
// response or closes the session, depending on how the invocation ends.
//
// A close request posts a close on the proactor and returns immediately.
//
// Stub resolution order:
//   1. the stub already bound to the session (hasServerStub());
//   2. for a default-constructed Token: the stub registered under the
//      request's service name in mStubMap (see bindShared());
//   3. otherwise: the first registered stub-entry lookup provider, keyed by
//      the token.
//
// Outcomes (implemented with the nested scope guards below; the innermost
// guard runs first on unwind):
//   - invoke() returns normally: only the response is sent;
//   - SerializationException: the session is closed; no response is sent;
//   - any other std::exception: a response carrying typeid(e).name() and
//     e.what() is serialized and sent;
//   - anything else: a close is posted, then a generic exception response is
//     serialized and then sent.
// A missing stub raises ServerException via RCF_THROW. It ends up in
// whichever of the above catch clauses matches its type.
void RcfServer::handleSession(SessionPtr sessionPtr)
{
    RCF_TRACE("")(sessionPtr.get());
    // Makes sessionPtr the current session for the duration of this call.
    SetCurrentSessionGuard setCurrentSessionGuard(sessionPtr);
    StubEntryAddRef &stubEntryAddRef = sessionPtr->getStubEntryAddRef();
    MethodInvocationRequest &request = sessionPtr->getRequest();
    if (request.getClose())
    {
        sessionPtr->getProactorPtr()->postClose();
        return;
    }
    Token token = request.getToken();
    StubEntryPtr stubEntryPtr;
    if (sessionPtr->hasServerStub())
    {
        stubEntryPtr = sessionPtr->getStubEntryPtr();
    }
    else if (token == Token())
    {
        // No token: look the stub up by service name.
        ReadLock readLock(mStubMapMutex);
        std::string servantName = request.getService();
        StubMap::iterator iter = mStubMap.find(servantName);
        if (iter != mStubMap.end())
        {
            stubEntryPtr = (*iter).second;
        }
    }
    else
    {
        // NOTE(review): mStubEntryLookupProviders is read here without
        // mServicesMutex, while addService() modifies it under a write lock.
        // Confirm that services are not added concurrently with request
        // handling.
        if (!mStubEntryLookupProviders.empty())
        {
            stubEntryPtr = mStubEntryLookupProviders[0]->getStubEntryPtr(token);
        }
    }
    {
        // NB: the following scope guards are apparently not triggered by Borland C++, when throwing non
        // std::exception derived exceptions.
        // Separate scopes for the guards to guarantee order of destruction (overkill?)
        ::ScopeGuard sendResponseGuard =
            MakeObjGuard(*this, &RcfServer::sendSessionResponse, sessionPtr);
        {
            ::ScopeGuard serializeExceptionResponseGuard =
                MakeObjGuard(*this, &RcfServer::serializeSessionExceptionResponse, sessionPtr) ;
            {
                ::ScopeGuard closeSessionGuard =
                    MakeObjGuard(*this, &RcfServer::closeSession, sessionPtr);
                try
                {
                    if (NULL == stubEntryPtr.get())
                    {
                        RCF_THROW(ServerException, "no server stub entry")(request.getService())(request.getSubInterface())(request.getFnId()); // TODO: exception type
                    }
                    else
                    {
                        stubEntryAddRef( stubEntryPtr );
                        stubEntryPtr->getRcfClientPtr()->getServerStub().invoke(request.getSubInterface(), request.getFnId(), sessionPtr->in, sessionPtr->out);
                        // Success: keep only the send-response guard armed.
                        serializeExceptionResponseGuard.Dismiss();
                        closeSessionGuard.Dismiss();
                    }
                }
                catch(const SerializationException &e)
                {
                    // Leave only the close-session guard armed: no response is sent.
                    RCF_TRACE(": Serialization exception")(typeid(e))(e);
                    serializeExceptionResponseGuard.Dismiss();
                    sendResponseGuard.Dismiss();
                }
                catch(const std::exception &e)
                {
                    // Report the exception to the client; keep the session open.
                    RCF_TRACE(": User exception")(typeid(e))(e);
                    serializeExceptionResponseGuard.Dismiss();
                    closeSessionGuard.Dismiss();
                    sessionPtr->out.reset();
                    sessionPtr->out << RCF::MethodInvocationResponse(true);
                    sessionPtr->out << std::string(typeid(e).name());
                    sessionPtr->out << std::string(e.what());
                }
            }
        }
    }
}
// Returns the queue a session should be placed in. Currently that is always
// the thread-specific session queue of the calling thread, which is created
// on first use. The session argument is not consulted.
SessionQueue &RcfServer::getSessionQueue(Session &session)
{
    RCF_UNUSED_VARIABLE(session);
    if (!mThreadSpecificSessionQueuePtr.get())
    {
        mThreadSpecificSessionQueuePtr.reset(new SessionQueue);
    }
    return *mThreadSpecificSessionQueuePtr;
}
void RcfServer::cycleSessions(int timeoutMs, const volatile bool &stopFlag)
{
RCF_TRACE("")(timeoutMs);
if (mThreadSpecificSessionQueuePtr.get() == NULL)
{
mThreadSpecificSessionQueuePtr.reset(new SessionQueue);
}
while (!stopFlag && !mThreadSpecificSessionQueuePtr->empty())
{
SessionPtr sessionPtr = mThreadSpecificSessionQueuePtr->back();
mThreadSpecificSessionQueuePtr->pop_back();
handleSession(sessionPtr);
}
}
// Reference form of getServerTransportPtr(): returns the first registered
// server transport. The referenced object stays owned by the server's
// transport table.
I_ServerTransport &RcfServer::getServerTransport()
{
    boost::shared_ptr<I_ServerTransport> transportPtr = getServerTransportPtr();
    return *transportPtr;
}
// Returns the first registered server transport. Asserts that at least one
// transport is registered. The table is read under mServicesMutex, the same
// lock addService() holds while appending to it.
boost::shared_ptr<I_ServerTransport> RcfServer::getServerTransportPtr()
{
    ReadLock servicesLock( mServicesMutex );
    RCF_ASSERT( ! mServerTransports.empty() );
    return mServerTransports.front();
}
bool RcfServer::bindShared(const std::string &name, RcfClientPtr rcfClientPtr)
{
WriteLock writeLock(mStubMapMutex);
if (NULL == mStubMap[name].get())
{
StubEntryPtr stubEntryPtr( new StubEntry );
stubEntryPtr->setRcfClientPtr(rcfClientPtr);
mStubMap[name] = stubEntryPtr;
return true;
}
return false;
}
// Creates a new filter instance for filterId, using the first registered
// filter-factory lookup provider.
//
// Returns an empty FilterPtr if no provider is registered, or if the provider
// yields no factory for filterId (presumably how it signals an unknown id;
// confirm against I_FilterFactoryLookupProvider). onReadCompleted() turns an
// empty result into a "could not create filter" error.
FilterPtr RcfServer::createFilter(int filterId)
{
    if (mFilterFactoryLookupProviders.empty())
    {
        return FilterPtr();
    }
    // filterId comes from the incoming message, so an unknown id must not
    // lead to dereferencing a null factory.
    FilterFactoryPtr filterFactoryPtr = mFilterFactoryLookupProviders[0]->getFilterFactoryPtr(filterId);
    if (!filterFactoryPtr)
    {
        return FilterPtr();
    }
    return filterFactoryPtr->createFilter();
}
// Installs the callback that start() invokes once the server has been
// started. It replaces any previously installed callback.
void RcfServer::setStartCallback(StartCallback startCallback)
{
    this->startCallback_ = startCallback;
}
// Calls the start callback with this server, if one has been installed.
void RcfServer::invokeStartCallback()
{
    if (!startCallback_)
    {
        return;
    }
    startCallback_(*this);
}
// Returns the current value of the flag that tells server task loops
// (repeatTask(), cycle()) to stop.
bool RcfServer::getStopFlag()
{
    const bool stopRequested = mServerThreadsStopFlag;
    return stopRequested;
}
} // namespace RCF