//*****************************************************************************
// RCF - Remote Call Framework
// Copyright (c) 2005. All rights reserved.
// Developed by Jarl Lindrud.
// Contact: jlindrud@hotmail.com .
//*****************************************************************************
#include <RCF/TcpClientTransport.hpp>
#include <RCF/TcpEndpoint.hpp>
#include <RCF/TimedBsdSockets.hpp>
namespace RCF {
BsdRecvFunctor::BsdRecvFunctor() : fd(), endTimeMs()
{}
void BsdRecvFunctor::setFd(int fd)
{
this->fd = fd;
}
int BsdRecvFunctor::getFd()
{
return fd;
}
void BsdRecvFunctor::setEndTimeMs(unsigned int endTimeMs)
{
this->endTimeMs = endTimeMs;
}
unsigned int BsdRecvFunctor::getEndTimeMs()
{
return endTimeMs;
}
BsdRecvFunctor::Status BsdRecvFunctor::getStatus()
{
return status;
}
// Performs one timedRecv() on fd, bounded by the time remaining until
// endTimeMs, and reports the outcome through readWriteCompletionCallback:
//   timedRecv() == -2  -> status = TimeOut,         callback(0, -1)
//   timedRecv() == -1  -> status = SocketError,     callback(0, -1)
//   timedRecv() ==  0  -> status = ConnectionReset, callback(0, -1)
//   otherwise          -> callback(bytesReceived, 0)
void BsdRecvFunctor::operator()(char *buffer, std::size_t bufferLen)
{
    const unsigned int remainingMs = generateTimeoutMs(endTimeMs);
    const int received = timedRecv(remainingMs, fd, buffer, bufferLen, 0);

    switch (received)
    {
    case -2:
        status = TimeOut;
        break;
    case -1:
        status = SocketError;
        break;
    case 0:
        // A zero-byte recv() means the peer closed the connection.
        status = ConnectionReset;
        break;
    default:
        RCF_ASSERT(0 < received && received <= static_cast<int>(bufferLen));
        readWriteCompletionCallback(received, 0);
        return;
    }

    // Every failure case is reported identically; details live in `status`.
    readWriteCompletionCallback(0, -1);
}
// Wraps operator() as a Filter::ReadFunction. The bound object holds a raw
// `this`, so it must not be invoked after this functor is destroyed.
Filter::ReadFunction BsdRecvFunctor::getReadFunction()
{
    Filter::ReadFunction readFn = boost::bind(&BsdRecvFunctor::operator(), this, _1, _2);
    return readFn;
}

// Installs the callback that operator() uses to report each completion.
void BsdRecvFunctor::setReadWriteCompletionCallback(Filter::ReadWriteCompletionCallback callback)
{
    readWriteCompletionCallback = callback;
}
BsdSendFunctor::BsdSendFunctor() : fd(), endTimeMs()
{}
void BsdSendFunctor::setFd(int fd)
{
this->fd = fd;
}
int BsdSendFunctor::getFd()
{
return fd;
}
void BsdSendFunctor::setEndTimeMs(unsigned int endTimeMs)
{
this->endTimeMs = endTimeMs;
}
unsigned int BsdSendFunctor::getEndTimeMs()
{
return endTimeMs;
}
// Performs one timedSend() on fd, bounded by the time remaining until
// endTimeMs, and reports the outcome through readWriteCompletionCallback:
//   timedSend() == -2  -> status = TimeOut,     callback(0, -1)
//   timedSend() == -1  -> status = SocketError, callback(0, -1)
//   otherwise          -> callback(bytesSent, 0)
void BsdSendFunctor::operator()(const char *buffer, std::size_t bufferLen)
{
    const unsigned int remainingMs = generateTimeoutMs(endTimeMs);
    const int sent = timedSend(remainingMs, fd, buffer, bufferLen, 0);

    switch (sent)
    {
    case -2:
        status = TimeOut;
        break;
    case -1:
        status = SocketError;
        break;
    default:
        RCF_ASSERT(0 < sent && sent <= static_cast<int>(bufferLen));
        readWriteCompletionCallback(sent, 0);
        return;
    }

    readWriteCompletionCallback(0, -1);
}
// Wraps operator() as a Filter::WriteFunction. The bound object holds a raw
// `this`, so it must not be invoked after this functor is destroyed.
Filter::WriteFunction BsdSendFunctor::getWriteFunction()
{
    Filter::WriteFunction writeFn = boost::bind(&BsdSendFunctor::operator(), this, _1, _2);
    return writeFn;
}

// Installs the callback that operator() uses to report each completion.
void BsdSendFunctor::setReadWriteCompletionCallback(Filter::ReadWriteCompletionCallback callback)
{
    readWriteCompletionCallback = callback;
}
// Unconnected transport targeting hostIp:hostPort. mRemoteAddr is zeroed; a
// zero sin_addr makes connect() resolve hostIp on first use.
TcpClientTransport::TcpClientTransport(const std::string &hostIp, int hostPort) :
    mRemoteAddr(),
    ip(hostIp),
    port(hostPort),
    fd(-1),
    own(true),
    mError(),
    mBytesTransferred()
{
    memset(&mRemoteAddr, 0, sizeof mRemoteAddr);
    const std::vector<FilterPtr> noFilters;
    setTransportFilters(noFilters);
}

// Unconnected transport targeting an already-resolved address. ip/port stay
// empty/zero; connect() uses remoteAddr as-is whenever its sin_addr is non-zero.
TcpClientTransport::TcpClientTransport(sockaddr_in remoteAddr) :
    mRemoteAddr(remoteAddr),
    ip(),
    port(),
    fd(-1),
    own(true),
    mError(),
    mBytesTransferred()
{
    const std::vector<FilterPtr> noFilters;
    setTransportFilters(noFilters);
}

// Copies only the target address (mRemoteAddr, ip, port). The copy starts
// unconnected, owns its future socket, and has no transport filters.
TcpClientTransport::TcpClientTransport(const TcpClientTransport &rhs) :
    mRemoteAddr(rhs.mRemoteAddr),
    ip(rhs.ip),
    port(rhs.port),
    fd(-1),
    own(true),
    mError(),
    mBytesTransferred()
{
    const std::vector<FilterPtr> noFilters;
    setTransportFilters(noFilters);
}

// Adopts an existing socket descriptor (own == true, so the destructor will
// close it) and wires it into both the send and receive functors.
TcpClientTransport::TcpClientTransport(int socketFd) :
    mRemoteAddr(),
    ip(),
    port(),
    fd(socketFd),
    own(true),
    mError(),
    mBytesTransferred()
{
    memset(&mRemoteAddr, 0, sizeof mRemoteAddr);
    const std::vector<FilterPtr> noFilters;
    setTransportFilters(noFilters);
    recvFunctor.setFd(socketFd);
    sendFunctor.setFd(socketFd);
}
TcpClientTransport::~TcpClientTransport()
{
if (own)
{
close();
}
}
std::auto_ptr<I_ClientTransport> TcpClientTransport::clone() const
{
return ClientTransportAutoPtr( new TcpClientTransport(*this) );
}
EndpointPtr TcpClientTransport::getEndpointPtr() const
{
return EndpointPtr( new TcpEndpoint(ip, port) );
}
int TcpClientTransport::connect(unsigned int timeoutMs)
{
// TODO: replace throw with return, where possible
//if (fd == -1)
if (!isConnected())
{
// close the current connection
close();
fd = static_cast<int>( ::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP) );
RCF_VERIFY(fd != -1, "socket()");
Platform::OS::BsdSockets::setblocking(fd, false);
if (mRemoteAddr.sin_addr.s_addr == 0)
{
unsigned long ul_addr = ::inet_addr(ip.c_str());
if (ul_addr == INADDR_NONE)
{
hostent *hostDesc = ::gethostbyname( ip.c_str() );
if (hostDesc)
{
char *szIp = ::inet_ntoa( * (in_addr*) hostDesc->h_addr_list[0]);
ul_addr = ::inet_addr(szIp);
}
}
memset(&mRemoteAddr, 0, sizeof(mRemoteAddr));
mRemoteAddr.sin_family = AF_INET;
mRemoteAddr.sin_addr.s_addr = ul_addr;
//remoteAddr.sin_port = ::htons(port); // the :: seems to screw up gcc (!?!)
mRemoteAddr.sin_port = htons(port); // the :: seems to screw up gcc (!?!)
}
if (0 != timedConnect(timeoutMs, fd, (sockaddr*) &mRemoteAddr, sizeof(mRemoteAddr)))
{
fd = -1;
RCF_THROW(ClientTransportException, "Socket failure (timedConnect())")(ip)(port);
}
}
sendFunctor.setFd(fd);
recvFunctor.setFd(fd);
return 1;
}
// Sends `data` as one framed message: a 4-byte length header (converted by
// machineToNetworkOrder) followed by the payload bytes. totalTimeoutMs bounds
// the whole call, including any connect() it triggers; the same deadline is
// also installed on the receive functor.
//
// return -1 for error (including timeout), 1 for ok
int TcpClientTransport::send(const std::string &data, unsigned int totalTimeoutMs)
{
    unsigned int startTimeMs = getCurrentTimeMs();
    unsigned int endTimeMs = startTimeMs + totalTimeoutMs;
    recvFunctor.setEndTimeMs(endTimeMs);
    sendFunctor.setEndTimeMs(endTimeMs);
    unsigned int timeoutMs = generateTimeoutMs(endTimeMs);
    if (connect(timeoutMs) != 1)
    {
        return -1;
    }

    BOOST_STATIC_ASSERT(sizeof(unsigned int) == 4);
    const unsigned int payloadLength = static_cast<unsigned int>(data.length());
    if (payloadLength != data.length())
    {
        // The payload does not fit the 32-bit length header.
        return -1;
    }

    std::vector<char> v(4 + data.length());
    // memcpy instead of an unsigned int* pun into char storage.
    memcpy(&v[0], &payloadLength, sizeof(payloadLength));
    RCF::machineToNetworkOrder(&v[0], 4, 1);
    if (!data.empty())
    {
        // Guarded: for empty data v.size() == 4, and &v[4] would be an
        // out-of-range operator[] access.
        memcpy(&v[4], data.data(), data.length());
    }

    // onReadWriteCompleted() records the outcome in mError/mBytesTransferred.
    mError = 0;
    mBytesTransferred = 0;
    if (mTransportFilters.empty())
    {
        sendFunctor.getWriteFunction()(&v[0], static_cast<int>(v.size()));
    }
    else
    {
        mTransportFilters.front()->write(&v[0], static_cast<int>(v.size()));
    }

    // A single write is issued; anything short of the full frame is an error.
    if (mError || mBytesTransferred != v.size())
    {
        return -1;
    }
    return 1;
}
// Reads exactly bufferLen bytes into buffer, issuing reads through the filter
// chain (or directly through recvFunctor when there are no filters) until the
// buffer is full. The deadline is whatever was last set on recvFunctor.
//
// Returns the number of bytes read (the original bufferLen) on success, and
// -1 on any error, including timeout. A zero-length request returns 0 at once.
int TcpClientTransport::timedReceive(char *buffer, std::size_t bufferLen)
{
    std::size_t bytesRead = 0;
    while (bufferLen > 0)
    {
        // onReadWriteCompleted() records the outcome in mError/mBytesTransferred.
        mError = 0;
        mBytesTransferred = 0;
        if (mTransportFilters.empty())
        {
            recvFunctor.getReadFunction()(buffer, bufferLen);
        }
        else
        {
            mTransportFilters.front()->read(buffer, bufferLen);
        }

        // Any non-zero error code is a failure, matching send().
        if (mError != 0)
        {
            return -1;
        }

        RCF_ASSERT(0 < mBytesTransferred && mBytesTransferred <= bufferLen);
        // Also enforced when asserts are compiled out. A zero-byte completion
        // would otherwise loop forever, and an oversized one would underflow
        // bufferLen.
        if (mBytesTransferred == 0 || mBytesTransferred > bufferLen)
        {
            return -1;
        }

        buffer += mBytesTransferred;
        bufferLen -= mBytesTransferred;
        bytesRead += mBytesTransferred;
    }
    return static_cast<int>(bytesRead);
}
// Receives one framed message: a 4-byte length header (converted with
// networkToMachineOrder) followed by that many payload bytes, which replace
// the contents of `data`. Header and payload share one deadline of
// getCurrentTimeMs() + timeoutMs.
//
// returns -2 for timeout, -1 for error, 0 for peer closure, 1 for ok
// Throws ClientTransportException when the header length is 0 or exceeds
// getMaxMessageLength().
int TcpClientTransport::receive(std::string &data, unsigned int timeoutMs)
{
    const unsigned int deadlineMs = getCurrentTimeMs() + timeoutMs;
    recvFunctor.setEndTimeMs(deadlineMs);
    sendFunctor.setEndTimeMs(deadlineMs);

    unsigned int length = 0;
    const int headerLen = 4;
    RCF_ASSERT(headerLen == sizeof(length))(headerLen)(sizeof(length));

    int ret = timedReceive(reinterpret_cast<char *>(&length), headerLen);
    if (ret == headerLen)
    {
        networkToMachineOrder(&length, sizeof(length), 1);
        if (length == 0 || length > getMaxMessageLength())
        {
            RCF_THROW(ClientTransportException, "bad message length")(length)(getMaxMessageLength());
        }

        std::vector<char> payload(length);
        ret = timedReceive(&payload[0], static_cast<int>(payload.size()));
        if (ret == length)
        {
            data.assign(&payload[0], length);
            return 1;
        }
    }

    // Either read failed; timedReceive() reports failure only as -1, so the
    // cause is taken from the receive functor's last recorded status.
    RCF_ASSERT(ret == -1);
    switch (recvFunctor.getStatus())
    {
    case BsdRecvFunctor::ConnectionReset:
        return 0;   // connection closed by peer
    case BsdRecvFunctor::TimeOut:
        return -2;  // time out
    default:
        return -1;  // some other error
    }
}
void TcpClientTransport::setCloseFunctor(CloseFunctor closeFunctor)
{
mCloseFunctor.reset( new CloseFunctor(closeFunctor) );
}
void TcpClientTransport::setRemoteAddr(const sockaddr_in &remoteAddr)
{
mRemoteAddr = remoteAddr;
}
const sockaddr_in &TcpClientTransport::getRemoteAddr()
{
return mRemoteAddr;
}
// Releases the current socket.
//
// If a close functor is installed, it is called once with fd and then
// discarded. Otherwise fd, if valid, is closed with closesocket().
// In both cases fd becomes -1 afterwards.
void TcpClientTransport::close()
{
    if (mCloseFunctor.get())
    {
        (*mCloseFunctor)(fd);
        mCloseFunctor.reset();
        // The functor has taken responsibility for the descriptor. Forget it
        // so a later close() (e.g. from the destructor) does not call
        // closesocket() on the same, possibly reused, descriptor.
        fd = -1;
    }
    else if (fd != -1)
    {
        // Clear fd before verifying, so a failed closesocket() is not
        // retried on the same descriptor by a later close().
        const int closingFd = fd;
        fd = -1;
        RCF_VERIFY_SOCKETS(0 == Platform::OS::BsdSockets::closesocket(closingFd), "closesocket() failure")(closingFd);
    }
}
// Non-blocking liveness probe for the current socket.
//
// Polls readability with a zero-timeout select(). If the socket is readable,
// peeks one byte with MSG_PEEK:
//   - a peek returning 0 means the peer closed the connection;
//   - a peek failing with ECONNRESET/ECONNABORTED/ECONNREFUSED means a dead
//     connection.
// Any other outcome, including a select() error, counts as connected.
// A dead socket is close()d before false is returned.
bool TcpClientTransport::isConnected()
{
    if (fd == -1)
    {
        return false;
    }

    bool connected = true;

    timeval zeroTimeout = {0,0};
    fd_set readFds;
    FD_ZERO(&readFds);
    FD_SET(fd, &readFds);
    const int selectResult = Platform::OS::BsdSockets::select(fd+1, &readFds, NULL, NULL, &zeroTimeout);
    if (selectResult == 1)
    {
        const int peekLen = 1;
        char peekByte[peekLen];
        const int peekResult = Platform::OS::BsdSockets::recv(fd, peekByte, peekLen, MSG_PEEK);
        if (peekResult == 0)
        {
            connected = false;
        }
        else if (peekResult == -1)
        {
            const int lastError = Platform::OS::BsdSockets::GetLastError();
            const bool deadConnection =
                lastError == Platform::OS::BsdSockets::ERR_ECONNRESET ||
                lastError == Platform::OS::BsdSockets::ERR_ECONNABORTED ||
                lastError == Platform::OS::BsdSockets::ERR_ECONNREFUSED;
            connected = !deadConnection;
        }
    }

    if (!connected)
    {
        close();
    }
    return connected;
}
// Detaches the socket descriptor without closing it. Afterwards this
// transport and both functors hold -1; the caller now owns the returned fd.
int TcpClientTransport::releaseFd()
{
    const int releasedFd = fd;
    fd = -1;
    recvFunctor.setFd(-1);
    sendFunctor.setFd(-1);
    return releasedFd;
}

// Returns the current socket descriptor (-1 when there is none).
int TcpClientTransport::getFd()
{
    return fd;
}
void TcpClientTransport::setTransportFilters(const std::vector<FilterPtr> &filters)
{
mTransportFilters.assign(filters.begin(), filters.end());
if (mTransportFilters.empty())
{
//sendFunctor.setReadWriteCompletionCallback( Filter::ReadWriteCompletionCallback() );
//recvFunctor.setReadWriteCompletionCallback( Filter::ReadWriteCompletionCallback() );
sendFunctor.setReadWriteCompletionCallback( boost::bind( &TcpClientTransport::onReadWriteCompleted, this, _1, _2));
recvFunctor.setReadWriteCompletionCallback( boost::bind( &TcpClientTransport::onReadWriteCompleted, this, _1, _2));
}
else
{
RCF::connectFilters(
mTransportFilters,
recvFunctor.getReadFunction(),
sendFunctor.getWriteFunction(),
boost::bind( &TcpClientTransport::onReadWriteCompleted, this, _1, _2));
sendFunctor.setReadWriteCompletionCallback( mTransportFilters.back()->getReadWriteCompletionCallback());
recvFunctor.setReadWriteCompletionCallback( mTransportFilters.back()->getReadWriteCompletionCallback());
}
}
void TcpClientTransport::onReadWriteCompleted(std::size_t bytesTransferred, int error)
{
mBytesTransferred = bytesTransferred;
mError = error;
}
} // namespace RCF