//========= Copyright Valve Corporation, All rights reserved. ============// // // Purpose: // // $NoKeywords: $ //=============================================================================// //#define PARANOID #if defined( PARANOID ) #include #include #endif #include #include #include "tcpsocket.h" #include "tier1/utllinkedlist.h" #include #include "threadhelpers.h" #include "tier0/dbg.h" #error "I am TCPSocket and I suck. Use IThreadedTCPSocket or ThreadedTCPSocketEmu instead." extern TIMEVAL SetupTimeVal( double flTimeout ); extern void IPAddrToSockAddr( const CIPAddr *pIn, sockaddr_in *pOut ); extern void SockAddrToIPAddr( const sockaddr_in *pIn, CIPAddr *pOut ); #define SENTINEL_DISCONNECT -1 #define SENTINEL_KEEPALIVE -2 #define KEEPALIVE_INTERVAL_MS 3000 // keepalives are sent every N MS #define KEEPALIVE_TIMEOUT_SECONDS 15.0 // connections timeout after this long static bool g_bEnableTCPTimeout = true; class CRecvData { public: int m_Count; unsigned char m_Data[1]; }; SOCKET TCPBind( const CIPAddr *pAddr ) { // Create a socket to send and receive through. SOCKET sock = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED ); if ( sock == INVALID_SOCKET ) { Assert( false ); return INVALID_SOCKET; } // bind to it! sockaddr_in addr; IPAddrToSockAddr( pAddr, &addr ); int status = bind( sock, (sockaddr*)&addr, sizeof(addr) ); if ( status == 0 ) { return sock; } else { closesocket( sock ); return INVALID_SOCKET; } } // ---------------------------------------------------------------------------------------- // // TCP sockets. // ---------------------------------------------------------------------------------------- // enum { OP_RECV=111, OP_SEND }; // We use this for all OVERLAPPED structures. class COverlappedPlus : public WSAOVERLAPPED { public: COverlappedPlus() { memset( this, 0, sizeof( WSAOVERLAPPED ) ); } int m_OPType; // One of the OP_ defines. }; typedef struct SendBuf_t { COverlappedPlus m_Overlapped; int m_Index; // Index into m_SendBufs. int m_DataLength; char m_Data[1]; } SendBuf_s; // These manage a thread that calls SendKeepalive() on all TCPSockets. // AddGlobalTCPSocket shouldn't be called until you're ready for SendKeepalive() to be called. class CTCPSocket; void AddGlobalTCPSocket( CTCPSocket *pSocket ); void RemoveGlobalTCPSocket( CTCPSocket *pSocket ); // ------------------------------------------------------------------------------------------ // // CTCPSocket implementation. // ------------------------------------------------------------------------------------------ // class CTCPSocket : public ITCPSocket { friend class CTCPListenSocket; public: CTCPSocket() { m_Socket = INVALID_SOCKET; m_bConnected = false; m_hIOCP = NULL; m_bShouldExitThreads = false; m_bConnectionLost = false; m_nSizeBytesReceived = 0; m_pIncomingData = NULL; memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) ); m_RecvOverlapped.m_OPType = OP_RECV; m_hRecvSignal = CreateEvent( NULL, FALSE, FALSE, NULL ); m_RecvStage = -1; m_MainThreadID = GetCurrentThreadId(); } virtual ~CTCPSocket() { Term(); CloseHandle( m_hRecvSignal ); } void Term() { Assert( GetCurrentThreadId() == m_MainThreadID ); RemoveGlobalTCPSocket( this ); if ( m_Socket != SOCKET_ERROR && !m_bConnectionLost ) { SendDisconnectSentinel(); // Give the sends a second to complete. SO_LINGER is having trouble for some reason. WaitForSendsToComplete( 1 ); } StopThreads(); if ( m_Socket != INVALID_SOCKET ) { closesocket( m_Socket ); m_Socket = INVALID_SOCKET; } if ( m_hIOCP ) { CloseHandle( m_hIOCP ); m_hIOCP = NULL; } m_bConnected = false; m_bConnectionLost = true; m_RecvStage = -1; FOR_EACH_LL( m_SendBufs, i ) { SendBuf_t *pSendBuf = m_SendBufs[i]; ParanoidMemoryCheck( pSendBuf ); free( pSendBuf ); } m_SendBufs.Purge(); FOR_EACH_LL( m_RecvDatas, j ) { CRecvData *pRecvData = m_RecvDatas[j]; ParanoidMemoryCheck( pRecvData ); free( pRecvData ); } m_RecvDatas.Purge(); if ( m_pIncomingData ) { ParanoidMemoryCheck( m_pIncomingData ); free( m_pIncomingData ); m_pIncomingData = 0; } } virtual void Release() { delete this; } void ParanoidMemoryCheck( void *ptr = NULL ) { #if defined( PARANOID ) Assert( _CrtIsValidHeapPointer( this ) ); if ( ptr ) { Assert( _CrtIsValidHeapPointer( ptr ) ); } Assert( _CrtCheckMemory() == TRUE ); #endif } virtual bool BindToAny( const unsigned short port ) { Term(); CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY m_Socket = TCPBind( &addr ); if ( m_Socket == INVALID_SOCKET ) { return false; } else { SetInitialSocketOptions(); return true; } } // Set the initial socket options that we want. void SetInitialSocketOptions() { // Set nodelay to improve latency. BOOL val = TRUE; setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) ); // Make it linger for 3 seconds when it exits. LINGER linger; linger.l_onoff = 1; linger.l_linger = 3; setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) ); } // Called only by main thread interface functions. // Returns true if the connection is lost. bool CheckConnectionLost() { Assert( GetCurrentThreadId() == m_MainThreadID ); if ( m_Socket == SOCKET_ERROR ) return true; // Have we timed out? if ( g_bEnableTCPTimeout && (Plat_FloatTime() - m_LastRecvTime > KEEPALIVE_TIMEOUT_SECONDS) ) { SetConnectionLost( "Connection timed out." ); } // Has any thread posted that the connection has been lost? CCriticalSectionLock postLock( &m_ConnectionLostCS ); postLock.Lock(); if ( m_bConnectionLost ) { Term(); return true; } else { return false; } } // Called by any thread. All interface functions call CheckConnectionLost() and return errors if it's lost. void SetConnectionLost( const char *pErrorString, int err = -1 ) { CCriticalSectionLock postLock( &m_ConnectionLostCS ); postLock.Lock(); m_bConnectionLost = true; postLock.Unlock(); // Handle it right away if we're in the main thread. If we're in an IO thread, // it has to wait until the next interface function calls CheckConnectionLost(). if ( GetCurrentThreadId() == m_MainThreadID ) { Term(); } if ( pErrorString ) { m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 ); } else { char *lpMsgBuf; FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language (LPTSTR) &lpMsgBuf, 0, NULL ); m_ErrorString.CopyArray( lpMsgBuf, strlen( lpMsgBuf ) + 1 ); LocalFree( lpMsgBuf ); } } // -------------------------------------------------------------------------------------------------- // // The receive code. // -------------------------------------------------------------------------------------------------- // virtual bool StartWaitingForSize( bool bFresh ) { Assert( m_Socket != INVALID_SOCKET ); Assert( m_bConnected ); m_RecvStage = 0; m_RecvDataSize = -1; if ( bFresh ) m_nSizeBytesReceived = 0; DWORD dwNumBytesReceived = 0; WSABUF buf = { sizeof( &m_RecvDataSize ) - m_nSizeBytesReceived, ((char*)&m_RecvDataSize) + m_nSizeBytesReceived }; DWORD dwFlags = 0; int status = WSARecv( m_Socket, &buf, 1, &dwNumBytesReceived, &dwFlags, &m_RecvOverlapped, NULL ); int err = -1; if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) { SetConnectionLost( NULL, err ); return false; } else { return true; } } bool PostNextDataPart() { DWORD dwNumBytesReceived = 0; WSABUF buf = { m_RecvDataSize - m_AmountReceived, (char*)m_pIncomingData->m_Data + m_AmountReceived }; DWORD dwFlags = 0; int status = WSARecv( m_Socket, &buf, 1, &dwNumBytesReceived, &dwFlags, &m_RecvOverlapped, NULL ); int err = -1; if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) { SetConnectionLost( NULL, err ); return false; } else { return true; } } bool StartWaitingForData() { Assert( m_Socket != INVALID_SOCKET ); Assert( m_RecvStage == 0 ); Assert( m_bConnected ); Assert( m_RecvDataSize > 0 ); m_RecvStage = 1; // Add a CRecvData element. ParanoidMemoryCheck(); m_pIncomingData = (CRecvData*)malloc( sizeof( CRecvData ) - 1 + m_RecvDataSize ); if ( !m_pIncomingData ) { char str[512]; _snprintf( str, sizeof( str ), "malloc() failed. m_RecvDataSize = %d\n", m_RecvDataSize ); SetConnectionLost( str ); return false; } m_pIncomingData->m_Count = m_RecvDataSize; m_AmountReceived = 0; return PostNextDataPart(); } virtual bool Recv( CUtlVector &data, double flTimeout ) { if ( CheckConnectionLost() ) return false; // Wait in 50ms chunks, checking for disconnections along the way. bool bGotData = false; DWORD msToWait = (DWORD)( flTimeout * 1000.0 ); do { DWORD curWaitTime = min( msToWait, 50 ); DWORD ret = WaitForSingleObject( m_hRecvSignal, curWaitTime ); if ( ret == WAIT_OBJECT_0 ) { bGotData = true; break; } // Did the connection timeout? if ( CheckConnectionLost() ) return false; msToWait -= curWaitTime; } while ( msToWait ); // If we never got a WAIT_OBJECT_0, then we never received anything. if ( !bGotData ) return false; CCriticalSectionLock csLock( &m_RecvDataCS ); csLock.Lock(); // Pickup the head m_RecvDatas element. CRecvData *pRecvData = m_RecvDatas[ m_RecvDatas.Head() ]; data.CopyArray( pRecvData->m_Data, pRecvData->m_Count ); // Now free it. m_RecvDatas.Remove( m_RecvDatas.Head() ); ParanoidMemoryCheck( pRecvData ); free( pRecvData ); // Set the event again for the next time around, if there is more data waiting. if ( m_RecvDatas.Count() > 0 ) SetEvent( m_hRecvSignal ); return true; } // INSIDE IO THREAD. void HandleRecvCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes ) { if ( dwNumBytes == 0 ) { SetConnectionLost( "Got 0 bytes in HandleRecvCompletion" ); return; } m_LastRecvTime = Plat_FloatTime(); if ( m_RecvStage == 0 ) { m_nSizeBytesReceived += dwNumBytes; if ( m_nSizeBytesReceived == sizeof( m_RecvDataSize ) ) { // Size of -1 means the other size is breaking the connection. if ( m_RecvDataSize == SENTINEL_DISCONNECT ) { SetConnectionLost( "Got a graceful disconnect message." ); return; } else if ( m_RecvDataSize == SENTINEL_KEEPALIVE ) { // No data follows this. Just let m_LastRecvTime get updated. StartWaitingForSize( true ); return; } StartWaitingForData(); } else if ( m_nSizeBytesReceived < sizeof( m_RecvDataSize ) ) { // Handle the case where we only got some of the data (maybe one of the clients got disconnected). StartWaitingForSize( false ); } else { // This case should never ever happen! #if defined( _DEBUG ) __asm int 3; #endif SetConnectionLost( "Received too much data in a packet!" ); return; } } else if ( m_RecvStage == 1 ) { // Got the data, make sure we got it all. m_AmountReceived += dwNumBytes; // Sanity check. #if defined( _DEBUG ) Assert( m_RecvDataSize == m_pIncomingData->m_Count ); Assert( m_AmountReceived <= m_RecvDataSize ); // TODO: make this threadsafe for multiple IO threads. #endif if ( m_AmountReceived == m_RecvDataSize ) { m_RecvStage = 2; // Add the data to the list of packets waiting to be picked up. CCriticalSectionLock csLock( &m_RecvDataCS ); csLock.Lock(); m_RecvDatas.AddToTail( m_pIncomingData ); m_pIncomingData = NULL; if ( m_RecvDatas.Count() == 1 ) SetEvent( m_hRecvSignal ); // Notify the Recv() function. StartWaitingForSize( true ); } else { PostNextDataPart(); } } else { Assert( false ); } } // -------------------------------------------------------------------------------------------------- // // The send code. // -------------------------------------------------------------------------------------------------- // virtual void WaitForSendsToComplete( double flTimeout ) { CWaitTimer waitTimer( flTimeout ); while ( 1 ) { CCriticalSectionLock sendBufLock( &m_SendCS ); sendBufLock.Lock(); if( m_SendBufs.Count() == 0 ) return; sendBufLock.Unlock(); if ( waitTimer.ShouldKeepWaiting() ) Sleep( 10 ); else break; } } // This is called in the keepalive thread. void SendKeepalive() { // Send a message saying we're exiting. ParanoidMemoryCheck(); SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) ); if ( !pBuf ) { SetConnectionLost( "malloc() in SendKeepalive() failed." ); return; } pBuf->m_DataLength = sizeof( int ); *((int*)pBuf->m_Data) = SENTINEL_KEEPALIVE; InternalSendDataBuf( pBuf ); } void SendDisconnectSentinel() { // Send a message saying we're exiting. ParanoidMemoryCheck(); SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) ); if ( pBuf ) { pBuf->m_DataLength = sizeof( int ); *((int*)pBuf->m_Data) = SENTINEL_DISCONNECT; // This signifies that we're exiting. InternalSendDataBuf( pBuf ); } } virtual bool Send( const void *pData, int len ) { const void *pChunks[1] = { pData }; int chunkLengths[1] = { len }; return SendChunks( pChunks, chunkLengths, 1 ); } virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks ) { if ( CheckConnectionLost() ) return false; CChunkWalker walker( pChunks, pChunkLengths, nChunks ); int totalLength = walker.GetTotalLength(); if ( !totalLength ) return true; // Create a buffer to hold the data and copy the data in. ParanoidMemoryCheck(); SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + totalLength + sizeof( int ) ); if ( !pBuf ) { char str[512]; _snprintf( str, sizeof( str ), "malloc() in SendChunks() failed. totalLength = %d.", totalLength ); SetConnectionLost( str ); return false; } pBuf->m_DataLength = totalLength + sizeof( int ); int *pByteCountPos = (int*)pBuf->m_Data; *pByteCountPos = totalLength; char *pDataPos = &pBuf->m_Data[ sizeof( int ) ]; walker.CopyTo( pDataPos, totalLength ); int status = InternalSendDataBuf( pBuf ); int err = -1; if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) { SetConnectionLost( NULL, err ); return false; } else { return true; } } int InternalSendDataBuf( SendBuf_t *pBuf ) { // Protect against interference from the keepalive thread. CCriticalSectionLock csLock( &m_SendCS ); csLock.Lock(); pBuf->m_Overlapped.m_OPType = OP_SEND; pBuf->m_Overlapped.hEvent = NULL; // Add it to our list of buffers. pBuf->m_Index = m_SendBufs.AddToTail( pBuf ); // Tell Winsock to send it. WSABUF buf = { pBuf->m_DataLength, pBuf->m_Data }; DWORD dwNumBytesSent = 0; return WSASend( m_Socket, &buf, 1, &dwNumBytesSent, 0, &pBuf->m_Overlapped, NULL ); } // INSIDE IO THREAD. void HandleSendCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes ) { if ( dwNumBytes == 0 ) { SetConnectionLost( "0 bytes in HandleSendCompletion." ); return; } // Just free the buffer. SendBuf_t *pBuf = (SendBuf_t*)pInfo; Assert( dwNumBytes == (DWORD)pBuf->m_DataLength ); CCriticalSectionLock sendBufLock( &m_SendCS ); sendBufLock.Lock(); m_SendBufs.Remove( pBuf->m_Index ); sendBufLock.Unlock(); ParanoidMemoryCheck( pBuf ); free( pBuf ); } // -------------------------------------------------------------------------------------------------- // // The connect code. // -------------------------------------------------------------------------------------------------- // virtual bool BeginConnect( const CIPAddr &inputAddr ) { sockaddr_in addr; IPAddrToSockAddr( &inputAddr, &addr ); m_bConnected = false; int ret = connect( m_Socket, (struct sockaddr*)&addr, sizeof( addr ) ); ret=ret; return true; } virtual bool UpdateConnect() { // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout. fd_set writeSet; writeSet.fd_count = 1; writeSet.fd_array[0] = m_Socket; TIMEVAL timeVal = SetupTimeVal( 0 ); // See if it has a packet waiting. int status = select( 0, NULL, &writeSet, NULL, &timeVal ); if ( status > 0 ) { SetupConnected(); return true; } return false; } void SetupConnected() { m_bConnected = true; m_bConnectionLost = false; m_LastRecvTime = Plat_FloatTime(); CreateThreads(); StartWaitingForSize( true ); AddGlobalTCPSocket( this ); } virtual bool IsConnected() { CheckConnectionLost(); return m_bConnected; } virtual void GetDisconnectReason( CUtlVector &reason ) { reason = m_ErrorString; } // -------------------------------------------------------------------------------------------------- // // Threads code. // -------------------------------------------------------------------------------------------------- // // Create our IO Completion Port threads. bool CreateThreads() { int nThreads = 1; SetShouldExitThreads( false ); // Create our IO completion port and hook it to our socket. m_hIOCP = CreateIoCompletionPort( INVALID_HANDLE_VALUE, NULL, 0, 0); m_hIOCP = CreateIoCompletionPort( (HANDLE)m_Socket, m_hIOCP, (unsigned long)this, nThreads ); for ( int i=0; i < nThreads; i++ ) { DWORD dwThreadID = 0; HANDLE hThread = CreateThread( NULL, 0, &CTCPSocket::StaticThreadFn, this, 0, &dwThreadID ); if ( hThread ) { SetThreadPriority( hThread, THREAD_PRIORITY_ABOVE_NORMAL ); m_Threads.AddToTail( hThread ); } else { StopThreads(); return false; } } return true; } void StopThreads() { // Tell the threads to exit, then wait for them to do so. SetShouldExitThreads( true ); WaitForMultipleObjects( m_Threads.Count(), m_Threads.Base(), TRUE, INFINITE ); for ( int i=0; i < m_Threads.Count(); i++ ) { CloseHandle( m_Threads[i] ); } m_Threads.Purge(); } void SetShouldExitThreads( bool bShouldExit ) { CCriticalSectionLock lock( &m_ThreadsCS ); lock.Lock(); m_bShouldExitThreads = bShouldExit; } bool ShouldExitThreads() { CCriticalSectionLock lock( &m_ThreadsCS ); lock.Lock(); bool bRet = m_bShouldExitThreads; return bRet; } DWORD ThreadFn() { while ( 1 ) { DWORD dwNumBytes = 0; unsigned long pInputTCPSocket; LPOVERLAPPED pOverlapped; if ( GetQueuedCompletionStatus( m_hIOCP, // the port we're listening on &dwNumBytes, // # bytes received on the port &pInputTCPSocket,// "completion key" = CTCPSocket* &pOverlapped, // the overlapped info that was passed into AcceptEx, WSARecv, or WSASend. 100 // listen for 100ms at a time so we can exit gracefully when the socket is deleted. ) ) { COverlappedPlus *pInfo = (COverlappedPlus*)pOverlapped; ParanoidMemoryCheck( pInfo ); if ( pInfo->m_OPType == OP_RECV ) { Assert( pInfo == &m_RecvOverlapped ); HandleRecvCompletion( pInfo, dwNumBytes ); } else { Assert( pInfo->m_OPType == OP_SEND ); HandleSendCompletion( pInfo, dwNumBytes ); } } if ( ShouldExitThreads() ) break; } return 0; } static DWORD WINAPI StaticThreadFn( LPVOID pParameter ) { return ((CTCPSocket*)pParameter)->ThreadFn(); } private: SOCKET m_Socket; bool m_bConnected; // m_RecvOverlapped is setup to first wait for the size, then the data. // Then it is not posted until the app grabs the data. HANDLE m_hRecvSignal; // Tells Recv() when we have data. COverlappedPlus m_RecvOverlapped; int m_RecvStage; // -1 = not initialized // 0 = waiting for size // 1 = waiting for data // 2 = waiting for app to pickup the data CUtlLinkedList m_RecvDatas; // The head element is the next one to be picked up. CRecvData *m_pIncomingData; // The packet we're currently receiving. CCriticalSection m_RecvDataCS; // This protects adds and removes in the list. // These reference the element at the tail of m_RecvData. It is the current one getting volatile int m_nSizeBytesReceived; // How much of m_RecvDataSize have we received yet? int m_RecvDataSize; // this is received over the network int m_AmountReceived; // How much we've received so far. // Last time we received anything from this connection. Used to determine if the connection is // still active. double m_LastRecvTime; // Outgoing send buffers. CUtlLinkedList m_SendBufs; CCriticalSection m_SendCS; // All the threads waiting for IO. CUtlVector m_Threads; HANDLE m_hIOCP; // Used during shutdown. volatile bool m_bShouldExitThreads; CCriticalSection m_ThreadsCS; // For debugging. DWORD m_MainThreadID; // Set by the main thread or IO threads to signal connection lost. bool m_bConnectionLost; CCriticalSection m_ConnectionLostCS; // This is set when we get disconnected. CUtlVector m_ErrorString; }; // ------------------------------------------------------------------------------------------ // // ITCPListenSocket implementation. // ------------------------------------------------------------------------------------------ // class CTCPListenSocket : public ITCPListenSocket { public: CTCPListenSocket() { m_Socket = INVALID_SOCKET; } virtual ~CTCPListenSocket() { if ( m_Socket != INVALID_SOCKET ) { closesocket( m_Socket ); } } // The main function to create one of these suckers. static ITCPListenSocket* Create( const unsigned short port, int nQueueLength ) { CTCPListenSocket *pRet = new CTCPListenSocket; if ( !pRet ) return NULL; // Bind it to a socket and start listening. CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY pRet->m_Socket = TCPBind( &addr ); if ( pRet->m_Socket == INVALID_SOCKET || listen( pRet->m_Socket, nQueueLength == -1 ? SOMAXCONN : nQueueLength ) != 0 ) { pRet->Release(); return false; } return pRet; } virtual void Release() { delete this; } virtual ITCPSocket* UpdateListen( CIPAddr *pAddr ) { // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout. fd_set readSet; readSet.fd_count = 1; readSet.fd_array[0] = m_Socket; TIMEVAL timeVal = SetupTimeVal( 0 ); // Wait until it connects. int status = select( 0, &readSet, NULL, NULL, &timeVal ); if ( status > 0 ) { sockaddr_in addr; int addrSize = sizeof( addr ); // Now accept the final connection. SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize ); if ( newSock == INVALID_SOCKET ) { Assert( false ); } else { CTCPSocket *pRet = new CTCPSocket; if ( !pRet ) { closesocket( newSock ); return NULL; } pRet->m_Socket = newSock; pRet->SetInitialSocketOptions(); pRet->SetupConnected(); // Report the address.. SockAddrToIPAddr( &addr, pAddr ); return pRet; } } return NULL; } private: SOCKET m_Socket; }; ITCPListenSocket* CreateTCPListenSocket( const unsigned short port, int nQueueLength ) { return CTCPListenSocket::Create( port, nQueueLength ); } ITCPSocket* CreateTCPSocket() { return new CTCPSocket; } void TCPSocket_EnableTimeout( bool bEnable ) { g_bEnableTCPTimeout = bEnable; } // --------------------------------------------------------------------------------- // // This thread sends keepalives on all active TCP sockets. // --------------------------------------------------------------------------------- // HANDLE g_hKeepaliveThread; HANDLE g_hKeepaliveThreadSignal; HANDLE g_hKeepaliveThreadReply; CUtlLinkedList g_TCPSockets; CCriticalSection g_TCPSocketsCS; DWORD WINAPI TCPKeepaliveThread( LPVOID pParameter ) { while ( 1 ) { if ( WaitForSingleObject( g_hKeepaliveThreadSignal, KEEPALIVE_INTERVAL_MS ) == WAIT_OBJECT_0 ) break; // Tell all TCP sockets to send a keepalive. CCriticalSectionLock csLock( &g_TCPSocketsCS ); csLock.Lock(); FOR_EACH_LL( g_TCPSockets, i ) { g_TCPSockets[i]->SendKeepalive(); } } SetEvent( g_hKeepaliveThreadReply ); return 0; } void AddGlobalTCPSocket( CTCPSocket *pSocket ) { CCriticalSectionLock csLock( &g_TCPSocketsCS ); csLock.Lock(); Assert( g_TCPSockets.Find( pSocket ) == g_TCPSockets.InvalidIndex() ); g_TCPSockets.AddToTail( pSocket ); // If this is the first one, create the keepalive thread. if ( g_TCPSockets.Count() == 1 ) { g_hKeepaliveThreadSignal = CreateEvent( NULL, false, false, NULL ); g_hKeepaliveThreadReply = CreateEvent( NULL, false, false, NULL ); DWORD dwThreadID = 0; g_hKeepaliveThread = CreateThread( NULL, 0, TCPKeepaliveThread, NULL, 0, &dwThreadID ); } } void RemoveGlobalTCPSocket( CTCPSocket *pSocket ) { bool bThreadRunning = false; DWORD dwExitCode = 0; if ( GetExitCodeThread( g_hKeepaliveThread, &dwExitCode ) && dwExitCode == STILL_ACTIVE ) { bThreadRunning = true; } CCriticalSectionLock csLock( &g_TCPSocketsCS ); csLock.Lock(); int index = g_TCPSockets.Find( pSocket ); if ( index != g_TCPSockets.InvalidIndex() ) { g_TCPSockets.Remove( index ); // If this was the last one, delete the thread. if ( g_TCPSockets.Count() == 0 ) { csLock.Unlock(); if ( bThreadRunning ) { SetEvent( g_hKeepaliveThreadSignal ); WaitForSingleObject( g_hKeepaliveThreadReply, INFINITE ); } CloseHandle( g_hKeepaliveThreadSignal ); CloseHandle( g_hKeepaliveThreadReply ); CloseHandle( g_hKeepaliveThread ); return; } } csLock.Unlock(); }