// CeSocket.cpp (CeSocketException, CeSockAddr, CeSocket, CHttpBlockingSocket)
#include "stdafx.h"
#include "CeSocket.h"
#include "CeThread.h"
#pragma warning(disable: 4390) // empty statement
///////////////////////////////////////////////////////////////////////////////
//
class CeSocketConnectThread: public CeWorkerThread
{
public:
    // Socket handle to connect (CeSocket::Connect passes &m_hSocket).
    SOCKET* m_pSocket;
    // Remote address to connect to; owned by the caller.
    LPCSOCKADDR m_pSA;
    // WSAGetLastError() value if connect() failed, otherwise NOERROR.
    DWORD m_dwErr;

    CeSocketConnectThread(SOCKET* pSocket, LPCSOCKADDR pSA)
        : m_pSocket(pSocket), m_pSA(pSA), m_dwErr(NOERROR)
    {
    }

    // Worker body: a single blocking connect().
    // Returns 0 on success, 1 on failure (error saved in m_dwErr).
    virtual UINT ThreadProc()
    {
        const int nResult = connect(*m_pSocket, m_pSA, sizeof(SOCKADDR));
        if (nResult != SOCKET_ERROR)
        {
            return 0;
        }
        m_dwErr = WSAGetLastError();
        TRACE(_T("Connect: failed %d\n"), m_dwErr);
        return 1;
    }

    // Stop hook: closing the socket aborts a connect() still blocked in
    // ThreadProc so the thread can finish.
    // NOTE(review): presumably invoked by CeWorkerThread::Stop (not visible
    // here) — confirm.
    virtual void OnStop()
    {
        closesocket(*m_pSocket);
    }
};
///////////////////////////////////////////////////////////////////////////////
// class CeSocket
void CeSocket::Cleanup()
{
    // Best-effort close for error-handling paths (e.g. inside a catch block).
    // It never raises, and it ignores the closesocket() result.
    // It is a no-op when no socket is held.
    if (m_hSocket != INVALID_SOCKET)
    {
        closesocket(m_hSocket);
        m_hSocket = INVALID_SOCKET;
    }
}
BOOL CeSocket::Create(int nType /* = SOCK_STREAM */)
{
    // Creates an AF_INET socket of type nType and stores it in m_hSocket.
    // The object must not already hold a socket.
    // Returns TRUE on success, FALSE if socket() failed.
    _ASSERTE(m_hSocket == INVALID_SOCKET);
    m_hSocket = socket(AF_INET, nType, 0);
    const BOOL bCreated = (m_hSocket != INVALID_SOCKET) ? TRUE : FALSE;
    if (!bCreated)
    {
        TRACE(_T("Socket: failed %d\n"), GetLastError());
    }
    return bCreated;
}
BOOL CeSocket::Bind(LPCSOCKADDR psa)
{
    // Binds this socket to the local address *psa.
    // The length passed is sizeof(SOCKADDR), so an IPv4-sized address is assumed.
    // Returns TRUE on success, FALSE if bind() failed.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    _ASSERTE(psa != NULL);
    const int nResult = bind(m_hSocket, psa, sizeof(SOCKADDR));
    if (nResult != SOCKET_ERROR)
    {
        return TRUE;
    }
    TRACE(_T("Bind: failed %d\n"), GetLastError());
    return FALSE;
}
BOOL CeSocket::Listen()
{
    // Puts the socket into the listening state with a fixed backlog.
    // Returns TRUE on success, FALSE if listen() failed.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    const int nBacklog = 5;
    if (listen(m_hSocket, nBacklog) != SOCKET_ERROR)
    {
        return TRUE;
    }
    TRACE(_T("Listen: failed %d\n"), GetLastError());
    return FALSE;
}
BOOL CeSocket::Accept(CeSocket& sConnect, LPSOCKADDR psa)
{
    // Accepts one incoming connection on this listening socket.
    // The new connection's handle is stored in sConnect (which must not
    // already hold a socket). The peer's address is written to *psa.
    // Returns TRUE on success.
    // On failure it returns FALSE and sConnect keeps INVALID_SOCKET, so the
    // object can be reused. The Winsock error is left in
    // WSAGetLastError(); WSAEINTR means the wait was cancelled.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    _ASSERTE(sConnect.m_hSocket == INVALID_SOCKET);
    _ASSERTE(psa != NULL);
    int nLengthAddr = sizeof(SOCKADDR);
    sConnect.m_hSocket = accept(m_hSocket, psa, &nLengthAddr);
    if (sConnect.m_hSocket == INVALID_SOCKET)
    {
        // Capture the error before TRACE can overwrite it.
        const int nErr = WSAGetLastError();
        // A cancelled listen (WSAEINTR) is expected, not a failure.
        if (nErr == WSAEINTR)
        {
            TRACE(_T("Accept: Cancelled\n"));
        }
        else
        {
            TRACE(_T("Accept: failed %d\n"), nErr);
        }
        // Re-post the error so the caller can still inspect it.
        WSASetLastError(nErr);
        return FALSE;
    }
    return TRUE;
}
void CeSocket::Close()
{
    // Closes the owned socket and marks the object as holding no socket.
    // A closesocket() failure is only traced, never reported.
    // Closing an already-closed handle is tolerated.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    const int nResult = closesocket(m_hSocket);
    m_hSocket = INVALID_SOCKET;
    if (nResult == SOCKET_ERROR)
    {
        TRACE(_T("Close: failed %d\n"), GetLastError());
    }
}
// Connects this socket to the remote address *psa.
// - dwMilliSecs == INFINITE: plain blocking connect() on the calling thread.
// - otherwise: connect() runs on a CeSocketConnectThread, and this call waits
//   at most dwMilliSecs for it. If the wait does not end with the thread
//   finishing, the worker is stopped and the last error is set to
//   WSAETIMEDOUT.
// Returns TRUE on success, FALSE on failure. In the threaded path, a connect()
// failure re-posts the worker's error with WSASetLastError().
BOOL CeSocket::Connect(LPCSOCKADDR psa, const DWORD dwMilliSecs /*= INFINITE*/)
{
_ASSERTE(m_hSocket != INVALID_SOCKET);
_ASSERTE(psa != NULL);
if (dwMilliSecs == INFINITE)
{
// no caller-imposed limit: block in connect() directly
if (connect(m_hSocket, psa, sizeof(SOCKADDR)) == SOCKET_ERROR)
{
TRACE(_T("Connect: failed %d\n"), GetLastError());
return FALSE;
}
}
else
{
// connect() would eventually time out by itself, but the machine-dependent
// timeout lengths aren't available, so the caller's timeout is implemented
// by running connect() on a worker thread
CeSocketConnectThread connectThread(&m_hSocket, psa);
// initiate the connect
connectThread.Start();
// wait for it to complete
DWORD dwRet = connectThread.WaitFor(dwMilliSecs);
// check the result
if (WAIT_OBJECT_0 != dwRet)
{
// WAIT_TIMEOUT, WAIT_FAILED, WAIT_ABANDONED_0: all reported as a timeout.
// Give the worker up to 5 s to finish; Stop() is expected to kill it otherwise.
// NOTE(review): CeSocketConnectThread::OnStop closes m_hSocket, but
// m_hSocket is not reset to INVALID_SOCKET here. A later Close()/Cleanup()
// would then call closesocket() on a stale handle value. Confirm whether
// CeWorkerThread::Stop invokes OnStop, and reset m_hSocket if so.
// NOTE(review): if the worker is still running when Stop() returns,
// connectThread and psa go out of scope under it. Confirm Stop's guarantees.
connectThread.Stop(5000);
TRACE(_T("Connect: failed %d\n"), WSAETIMEDOUT);
// set the error to match what we've actually implemented
WSASetLastError(WSAETIMEDOUT);
return FALSE;
}
// thread exited; m_dwErr holds its WSAGetLastError() value if connect() failed
if (connectThread.m_dwErr != 0)
{
WSASetLastError(connectThread.m_dwErr);
return FALSE;
}
}
return TRUE;
}
int CeSocket::Write(const BYTE* pch, const int nSize, const DWORD dwMilliSecs)
{
    // Sends all nSize bytes at pch, calling Send() repeatedly until the whole
    // buffer has gone out. Each Send() waits up to dwMilliSecs for the socket
    // to become writable.
    // Returns nSize on success. Returns -1 as soon as any Send() fails or
    // times out; the count already sent is not reported in that case.
    // nSize == 0 sends nothing and returns 0.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    _ASSERTE(pch != NULL);
    _ASSERTE(nSize >= 0);
    int nBytesSent = 0;
    while (nBytesSent < nSize)
    {
        const int nBytesThisTime = Send(pch + nBytesSent, nSize - nBytesSent, dwMilliSecs);
        // Send() signals failure/timeout with -1. Accumulating it would move
        // the buffer pointer backwards and loop forever, so stop here.
        if (nBytesThisTime <= 0)
        {
            return -1;
        }
        nBytesSent += nBytesThisTime;
    }
    return nBytesSent;
}
int CeSocket::Send(const BYTE* pch, const int nSize, const DWORD dwMilliSecs)
{
_ASSERTE(m_hSocket != INVALID_SOCKET);
_ASSERTE(pch != NULL);
_ASSERTE(nSize >= 0);
// returned value will be less than nSize if client cancels the reading
FD_SET fd = {1, m_hSocket};
TIMEVAL tv = {dwMilliSecs/1000, (dwMilliSecs % 1000)*1000};
if (select(0, NULL, &fd, NULL, &tv) == 0)
{
TRACE(_T("select: failed %d\n"), GetLastError());
return -1;
}
int nBytesSent;
if ((nBytesSent = send(m_hSocket, (char *)pch, nSize, 0)) == SOCKET_ERROR)
{
TRACE(_T("send: failed %d\n"), GetLastError());
return -1;
}
return nBytesSent;
}
int CeSocket::Receive(BYTE* pch, const int nSize, const DWORD dwMilliSecs)
{
    // Waits up to dwMilliSecs for data, then makes a single recv() of at most
    // nSize bytes into pch.
    // Returns recv()'s byte count, or -1 if select() or recv() failed.
    // Returns 0 on timeout; callers cannot tell this apart from recv()
    // itself returning 0.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    _ASSERTE(pch != NULL);
    _ASSERTE(nSize >= 0);

    FD_SET readSet;
    readSet.fd_count = 1;
    readSet.fd_array[0] = m_hSocket;
    TIMEVAL timeout;
    timeout.tv_sec = dwMilliSecs / 1000;
    timeout.tv_usec = (dwMilliSecs % 1000) * 1000;

    // only one socket is watched, so any positive result means it is readable
    const int nReady = select(0, &readSet, NULL, NULL, &timeout);
    if (nReady == SOCKET_ERROR)
    {
        TRACE(_T("select: failed (%d)\n"), GetLastError());
        return -1;
    }
    if (nReady == 0)
    {
        TRACE(_T("select: timeout (%d)\n"), dwMilliSecs);
        return 0;
    }

    const int nReceived = recv(m_hSocket, (char *)pch, nSize, 0);
    if (nReceived == SOCKET_ERROR)
    {
        TRACE(_T("recv: failed %d\n"), GetLastError());
        return -1;
    }
    return nReceived;
}
int CeSocket::ReceiveFrom(BYTE* pch, const int nSize, LPSOCKADDR psa, const DWORD dwMilliSecs)
{
    // Waits up to dwMilliSecs for a datagram, then reads it with a single
    // recvfrom() into pch (at most nSize bytes). pch should be big enough
    // for the entire datagram. The sender's address is written to *psa.
    // Returns the byte count, 0 on timeout, or -1 if select()/recvfrom()
    // failed.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    _ASSERTE(pch != NULL);
    _ASSERTE(nSize >= 0);
    _ASSERTE(psa != NULL);

    FD_SET readSet;
    readSet.fd_count = 1;
    readSet.fd_array[0] = m_hSocket;
    TIMEVAL timeout;
    timeout.tv_sec = dwMilliSecs / 1000;
    timeout.tv_usec = (dwMilliSecs % 1000) * 1000;

    const int nReady = select(0, &readSet, NULL, NULL, &timeout);
    if (nReady == 0)
    {
        return 0; // timed out
    }
    if (nReady == SOCKET_ERROR)
    {
        TRACE(_T("select: failed %d\n"), GetLastError());
        return -1;
    }

    int nFromLen = sizeof(SOCKADDR);
    const int nReceived = recvfrom(m_hSocket, (char *)pch, nSize, 0, psa, &nFromLen);
    if (nReceived != SOCKET_ERROR)
    {
        return nReceived;
    }
    TRACE(_T("recvfrom: failed %d\n"), GetLastError());
    return -1;
}
int CeSocket::SendTo(const BYTE* pch, const int nSize, LPCSOCKADDR psa, const DWORD dwMilliSecs)
{
_ASSERTE(m_hSocket != INVALID_SOCKET);
_ASSERTE(pch != NULL);
_ASSERTE(nSize >= 0);
_ASSERTE(psa != NULL);
FD_SET fd = {1, m_hSocket};
TIMEVAL tv = {dwMilliSecs/1000, (dwMilliSecs % 1000)*1000};
if (select(0, NULL, &fd, NULL, &tv) == 0)
{
TRACE(_T("select: failed %d\n"), GetLastError());
return -1;
}
int nBytesSent = sendto(m_hSocket, (char *)pch, nSize, 0, psa, sizeof(SOCKADDR));
if (nBytesSent == SOCKET_ERROR)
{
TRACE(_T("sendto: failed %d\n"), GetLastError());
return -1;
}
return nBytesSent;
}
BOOL CeSocket::GetPeerAddr(LPSOCKADDR psa) const
{
    // Writes the address of the socket at the other end to *psa.
    // Returns TRUE on success, FALSE if getpeername() failed.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    _ASSERTE(psa != NULL);
    int nAddrLen = sizeof(SOCKADDR);
    const BOOL bOk = (getpeername(m_hSocket, psa, &nAddrLen) != SOCKET_ERROR) ? TRUE : FALSE;
    if (!bOk)
    {
        TRACE(_T("getpeername: failed %d\n"), GetLastError());
    }
    return bOk;
}
BOOL CeSocket::GetSockAddr(LPSOCKADDR psa) const
{
    // Writes the local address of this socket to *psa.
    // Returns TRUE on success, FALSE if getsockname() failed.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    _ASSERTE(psa != NULL);
    int nAddrLen = sizeof(SOCKADDR);
    const BOOL bOk = (getsockname(m_hSocket, psa, &nAddrLen) != SOCKET_ERROR) ? TRUE : FALSE;
    if (!bOk)
    {
        TRACE(_T("getsockname: failed %d\n"), GetLastError());
    }
    return bOk;
}
BOOL CeSocket::SetOption(int nLevel, int nOption, const void* pVal, int nValLen)
{
    // Thin wrapper over setsockopt(): sets option nOption at level nLevel
    // from the nValLen bytes at pVal.
    // Returns TRUE on success, FALSE on failure.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    const int nResult = setsockopt(m_hSocket, nLevel, nOption, (char *) pVal, nValLen);
    if (nResult != SOCKET_ERROR)
    {
        return TRUE;
    }
    TRACE(_T("setsockopt: failed %d\n"), GetLastError());
    return FALSE;
}
BOOL CeSocket::GetOption(int nLevel, int nOption, void* pVal, int* pValLen) const
{
    // Thin wrapper over getsockopt(): reads option nOption at level nLevel
    // into pVal. pValLen is passed straight through to getsockopt().
    // Returns TRUE on success, FALSE on failure.
    _ASSERTE(m_hSocket != INVALID_SOCKET);
    char* pchVal = static_cast<char*>(pVal);
    const BOOL bOk = (getsockopt(m_hSocket, nLevel, nOption, pchVal, pValLen) != SOCKET_ERROR) ? TRUE : FALSE;
    if (!bOk)
    {
        TRACE(_T("getsockopt: failed %d\n"), GetLastError());
    }
    return bOk;
}
//static
CeSockAddr CeSocket::GetHostByName(LPCTSTR pchName, const USHORT ushPort /* = 0 */)
{
    // Resolves host name pchName with gethostbyname(). Returns an AF_INET
    // address built from the first listed IP, with port ushPort.
    // If resolution fails, the returned address is all zeros.
    _ASSERTE(pchName != NULL);
    // gethostbyname() takes an 8-bit string. Each TCHAR is narrowed by
    // truncation, which is only correct for ASCII host names. Copy up to the
    // terminating NUL ('\0', not the digit '0'), always leaving room for the
    // terminator; names longer than the buffer are truncated.
    char szBuf[256];
    const int nMaxChars = (int) sizeof(szBuf) - 1;
    int ii = 0;
    for (; ii < nMaxChars && pchName[ii] != _T('\0'); ii++)
    {
        szBuf[ii] = static_cast<char>(pchName[ii]);
    }
    szBuf[ii] = '\0';

    SOCKADDR_IN sockTemp;
    memset(&sockTemp, 0, sizeof sockTemp);
    hostent* pHostEnt = gethostbyname(szBuf);
    if (pHostEnt == NULL)
    {
        TRACE(_T("gethostbyname: failed %d\n"), WSAGetLastError());
    }
    else if (pHostEnt->h_addr_list[0] == NULL)
    {
        // resolved, but no address listed: don't dereference a NULL entry
        TRACE(_T("gethostbyname: no address\n"));
    }
    else
    {
        const ULONG* pulAddr = (const ULONG*) pHostEnt->h_addr_list[0];
        sockTemp.sin_family = AF_INET;
        sockTemp.sin_port = htons(ushPort);
        sockTemp.sin_addr.s_addr = *pulAddr; // address is already in network byte order
    }
    return sockTemp;
}
#ifdef UNICODE
//static
CeSockAddr CeSocket::GetHostByName(const char *pchName, const USHORT ushPort /* = 0 */)
{
    // 8-bit-string overload (compiled only when UNICODE is defined).
    // Resolves pchName with gethostbyname() and returns an AF_INET address
    // built from the first listed IP, with port ushPort.
    // If resolution fails, the returned address is all zeros.
    _ASSERTE(pchName != NULL);
    SOCKADDR_IN sockTemp;
    memset(&sockTemp, 0, sizeof sockTemp);
    hostent* pHostEnt = gethostbyname(pchName);
    if (pHostEnt == NULL)
    {
        TRACE(_T("gethostbyname: failed %d\n"), WSAGetLastError());
    }
    else if (pHostEnt->h_addr_list[0] == NULL)
    {
        // resolved, but no address listed: don't dereference a NULL entry
        TRACE(_T("gethostbyname: no address\n"));
    }
    else
    {
        const ULONG* pulAddr = (const ULONG*) pHostEnt->h_addr_list[0];
        sockTemp.sin_family = AF_INET;
        sockTemp.sin_port = htons(ushPort);
        sockTemp.sin_addr.s_addr = *pulAddr; // address is already in network byte order
    }
    return sockTemp;
}
#endif // UNICODE
//static
const char* CeSocket::GetHostByAddr(LPCSOCKADDR psa)
{
    // Reverse-resolves the 4-byte IPv4 address in *psa, which is read as a
    // SOCKADDR_IN. Returns the host name, or NULL if the lookup failed.
    // The string lives in the hostent returned by gethostbyaddr(); the caller
    // must not delete it.
    // NOTE(review): per the Winsock docs this storage may be reused by later
    // Winsock calls on the same thread, so callers should copy the name
    // promptly — confirm for the CE stack.
    _ASSERTE(psa != NULL);
    LPSOCKADDR_IN pAddrIn = (LPSOCKADDR_IN) psa;
    hostent* pHost = gethostbyaddr((char*) &pAddrIn->sin_addr.s_addr, 4, PF_INET);
    if (pHost != NULL)
    {
        return pHost->h_name;
    }
    TRACE(_T("gethostbyaddr: failed %d\n"), WSAGetLastError());
    return NULL;
}