// hl2_src-leak-2017/src/utils/vmpi/ThreadedTCPSocketEmu.cpp
//
// 345 lines
// 6.8 KiB
// C++

//========= Copyright Valve Corporation, All rights reserved. ============//
//
// Purpose:
//
// $NoKeywords: $
//=============================================================================//
#include <windows.h>
#include "tcpsocket.h"
#include "IThreadedTCPSocket.h"
#include "ThreadedTCPSocketEmu.h"
#include "ThreadHelpers.h"
// ---------------------------------------------------------------------------------------- //
// CThreadedTCPSocketEmu. This uses IThreadedTCPSocket to emulate the polling-type interface
// in ITCPSocket.
// ---------------------------------------------------------------------------------------- //
// This class uses the IThreadedTCPSocket interface to emulate the old ITCPSocket.
// Emulates the polling-style ITCPSocket interface on top of IThreadedTCPSocket.
// Incoming packets and errors arrive through the ITCPSocketHandler callbacks
// (OnPacketReceived / OnError). They are stored under critical sections so that
// Recv() / IsConnected() / GetDisconnectReason() can poll them later.
// NOTE(review): the callbacks are presumably invoked from the threaded socket's
// worker thread (hence the locks) -- confirm against IThreadedTCPSocket.
class CThreadedTCPSocketEmu : public ITCPSocket, public ITCPSocketHandler, public IHandlerCreator
{
public:
	CThreadedTCPSocketEmu()
	{
		m_pSocket = NULL;
		m_LocalPort = 0xFFFF;	// "Not bound yet" sentinel; asserted against in BeginConnect().
		m_pConnectSocket = NULL;
		m_RecvPacketsEvent.Init( false, false );
		m_bError = false;
	}

	virtual ~CThreadedTCPSocketEmu()
	{
		Term();

		// Term() only drops the sockets. Packets queued by OnPacketReceived() but
		// never consumed by Recv() still hold a reference and must be released
		// here, otherwise they leak.
		CCriticalSectionLock csLock( &m_RecvPacketsCS );
		csLock.Lock();
		while ( m_RecvPackets.Count() > 0 )
		{
			int iHead = m_RecvPackets.Head();
			m_RecvPackets[ iHead ]->Release();
			m_RecvPackets.Remove( iHead );
		}
	}

	// Adopts an already-connected threaded socket (the listen emulation calls this
	// for accepted connections). The socket is Release()d later by Term().
	void Init( IThreadedTCPSocket *pSocket )
	{
		m_pSocket = pSocket;
	}

	// Releases the connected socket and any in-progress connector. Safe to call
	// more than once; IsConnected() also calls it after an error was reported.
	// Queued received packets are kept, so Recv() can still drain them.
	void Term()
	{
		if ( m_pSocket )
		{
			m_pSocket->Release();
			m_pSocket = NULL;
		}

		if ( m_pConnectSocket )
		{
			m_pConnectSocket->Release();
			m_pConnectSocket = NULL;
		}
	}

// ITCPSocketHandler implementation.
private:
	// Queues a received packet (this object now holds the packet's reference)
	// and signals m_RecvPacketsEvent so a waiting Recv() wakes up.
	virtual void OnPacketReceived( CTCPPacket *pPacket )
	{
		CCriticalSectionLock csLock( &m_RecvPacketsCS );
		csLock.Lock();

		m_RecvPackets.AddToTail( pPacket );
		m_RecvPacketsEvent.SetEvent();
	}

	// Records the error text (NUL terminator included) and flags the error.
	// The socket itself is torn down later by IsConnected().
	virtual void OnError( int errorCode, const char *pErrorString )
	{
		// Guard against a NULL description so strlen() below cannot crash.
		if ( !pErrorString )
			pErrorString = "";

		CCriticalSectionLock csLock( &m_ErrorStringCS );
		csLock.Lock();

		m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 );
		m_bError = true;
	}

// IHandlerCreator implementation.
public:
	// Used while connecting: the connector's socket reports back to this object.
	virtual ITCPSocketHandler* CreateNewHandler()
	{
		return this;
	}

// ITCPSocket implementation.
public:
	virtual void Release()
	{
		delete this;
	}

	// Only records the port. The actual bind happens in BeginConnect(), which
	// passes it as the local address to ThreadedTCP_CreateConnector().
	virtual bool BindToAny( const unsigned short port )
	{
		m_LocalPort = port;
		return true;
	}

	// Starts an asynchronous connect to addr, replacing any previous connector.
	// Returns false if the connector could not be created.
	virtual bool BeginConnect( const CIPAddr &addr )
	{
		// They should have "bound" to a port before trying to connect.
		Assert( m_LocalPort != 0xFFFF );

		if ( m_pConnectSocket )
			m_pConnectSocket->Release();

		m_pConnectSocket = ThreadedTCP_CreateConnector(
			addr,
			CIPAddr( 0, 0, 0, 0, m_LocalPort ),
			this );

		return m_pConnectSocket != 0;
	}

	// Polls the pending connect. Returns true once connected (the connector is
	// then released). Returns false while still pending, or after the connector
	// reports failure (in which case the connector is released).
	virtual bool UpdateConnect()
	{
		Assert( !m_pSocket );
		if ( !m_pConnectSocket )
			return false;

		if ( m_pConnectSocket->Update( &m_pSocket ) )
		{
			if ( m_pSocket )
			{
				// Ok, we're connected now.
				m_pConnectSocket->Release();
				m_pConnectSocket = NULL;
				return true;
			}
			else
			{
				return false;
			}
		}
		else
		{
			Assert( false );
			m_pConnectSocket->Release();
			m_pConnectSocket = NULL;
			return false;
		}
	}

	// Returns true while a connected socket exists and no error was reported.
	// On the first call after OnError() the sockets are released via Term().
	// NOTE(review): m_bError is written in OnError() under m_ErrorStringCS but
	// read here without the lock -- confirm this benign race is acceptable.
	virtual bool IsConnected()
	{
		if ( m_bError )
		{
			Term();
			return false;
		}
		else
		{
			return m_pSocket != NULL;
		}
	}

	// Copies the last error text recorded by OnError() (empty if none).
	virtual void GetDisconnectReason( CUtlVector<char> &reason )
	{
		CCriticalSectionLock csLock( &m_ErrorStringCS );
		csLock.Lock();
		reason = m_ErrorString;
	}

	virtual bool Send( const void *pData, int size )
	{
		Assert( m_pSocket );
		if ( !m_pSocket )
			return false;

		return m_pSocket->Send( pData, size );
	}

	virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks )
	{
		Assert( m_pSocket );
		if ( !m_pSocket || !m_pSocket->IsValid() )
			return false;

		return m_pSocket->SendChunks( pChunks, pChunkLengths, nChunks );
	}

	// Waits up to flTimeout seconds for a queued packet. On success, copies the
	// oldest packet's bytes into data, releases the packet and returns true.
	// Returns false on timeout.
	virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout )
	{
		// Use our m_RecvPacketsEvent event to determine if there is data to receive yet.
		DWORD nMilliseconds = (DWORD)( flTimeout * 1000.0f );
		DWORD ret = WaitForSingleObject( m_RecvPacketsEvent.GetEventHandle(), nMilliseconds );
		if ( ret == WAIT_OBJECT_0 )
		{
			// Ok, there's a packet.
			CCriticalSectionLock csLock( &m_RecvPacketsCS );
			csLock.Lock();

			Assert( m_RecvPackets.Count() > 0 );
			// Assert is compiled out in release builds. Never index Head() of an
			// empty list on a spurious wakeup.
			if ( m_RecvPackets.Count() == 0 )
				return false;

			int iHead = m_RecvPackets.Head();
			CTCPPacket *pPacket = m_RecvPackets[ iHead ];
			data.CopyArray( (const unsigned char*)pPacket->GetData(), pPacket->GetLen() );

			pPacket->Release();
			m_RecvPackets.Remove( iHead );

			// Re-set the event if there are more packets left to receive.
			if ( m_RecvPackets.Count() > 0 )
			{
				m_RecvPacketsEvent.SetEvent();
			}

			return true;
		}
		else
		{
			return false;
		}
	}

private:
	IThreadedTCPSocket *m_pSocket;
	unsigned short m_LocalPort;			// The port we bind to when we want to connect.
	ITCPConnectSocket *m_pConnectSocket;

	// All the received data is stored in here.
	CEvent m_RecvPacketsEvent;
	CCriticalSection m_RecvPacketsCS;
	CUtlLinkedList<CTCPPacket*, int> m_RecvPackets;

	CCriticalSection m_ErrorStringCS;
	CUtlVector<char> m_ErrorString;
	bool m_bError;	// Set to true when there's an error. Next chance we get in the main thread, we'll close the socket.
};
// Factory for the polling-style ITCPSocket emulation. The returned object
// deletes itself when the caller invokes Release() on it.
ITCPSocket* CreateTCPSocketEmu()
{
	CThreadedTCPSocketEmu *pEmu = new CThreadedTCPSocketEmu;
	return pEmu;
}
// ---------------------------------------------------------------------------------------- //
// CThreadedTCPListenSocketEmu implementation.
// ---------------------------------------------------------------------------------------- //
// Emulates the polling-style ITCPListenSocket on top of the threaded listener
// returned by ThreadedTCP_CreateListener(). Each accepted connection is wrapped
// in a CThreadedTCPSocketEmu created through CreateNewHandler().
class CThreadedTCPListenSocketEmu : public ITCPListenSocket, public IHandlerCreator
{
public:
	CThreadedTCPListenSocketEmu()
	{
		m_pListener = NULL;
		m_pLastCreatedSocket = NULL;
	}

	virtual ~CThreadedTCPListenSocketEmu()
	{
		if ( m_pListener )
			m_pListener->Release();

		// NOTE(review): a handler created by CreateNewHandler() that was never
		// handed out by UpdateListen() is not freed here. Deleting it is only safe
		// if releasing the listener guarantees no socket still calls into it --
		// confirm before adding the delete.
	}

	// Creates the underlying threaded listener with this object as the handler
	// factory. Returns false if the listener could not be created.
	bool StartListening( const unsigned short port, int nQueueLength )
	{
		// Don't leak a previously created listener if this is called again.
		if ( m_pListener )
		{
			m_pListener->Release();
			m_pListener = NULL;
		}

		m_pListener = ThreadedTCP_CreateListener(
			this,
			port,
			nQueueLength );

		return m_pListener != 0;
	}

// ITCPListenSocket implementation.
private:
	virtual void Release()
	{
		delete this;
	}

	// Polls the listener. When a connection has been accepted, this writes its
	// remote address to *pAddr (if non-NULL). It then returns the
	// CThreadedTCPSocketEmu handler created for it, with the socket attached.
	// Returns NULL if nothing is ready.
	virtual ITCPSocket* UpdateListen( CIPAddr *pAddr )
	{
		if ( !m_pListener )
			return NULL;

		// Initialized so a true return that leaves it untouched cannot be
		// mistaken for an accepted socket.
		IThreadedTCPSocket *pSocket = NULL;
		if ( !m_pListener->Update( &pSocket ) || !pSocket )
			return NULL;

		if ( pAddr )
			*pAddr = pSocket->GetRemoteAddr();

		// This is pretty hacky, but this stuff is just around for test code.
		// It assumes CreateNewHandler() was called exactly once for this socket.
		CThreadedTCPSocketEmu *pLast = m_pLastCreatedSocket;
		Assert( pLast );
		if ( !pLast )
		{
			// No handler to wrap the socket in. Drop it instead of crashing or
			// leaking it.
			pSocket->Release();
			return NULL;
		}

		pLast->Init( pSocket );
		m_pLastCreatedSocket = NULL;
		return pLast;
	}

// IHandlerCreator implementation.
private:
	// Creates the handler for the connection currently being accepted. The
	// handler is remembered so UpdateListen() can hand it back to the caller.
	virtual ITCPSocketHandler* CreateNewHandler()
	{
		m_pLastCreatedSocket = new CThreadedTCPSocketEmu;
		return m_pLastCreatedSocket;
	}

private:
	ITCPConnectSocket *m_pListener;
	CThreadedTCPSocketEmu *m_pLastCreatedSocket;
};
// Creates a listen-socket emulation that listens on `port` with the given
// backlog length. Returns NULL if the underlying threaded listener could not
// be created.
ITCPListenSocket* CreateTCPListenSocketEmu( const unsigned short port, int nQueueLength )
{
	CThreadedTCPListenSocketEmu *pListenEmu = new CThreadedTCPListenSocketEmu;

	if ( !pListenEmu->StartListening( port, nQueueLength ) )
	{
		// Listener creation failed; don't hand back a half-initialized object.
		delete pListenEmu;
		return NULL;
	}

	return pListenEmu;
}