hl2_src-leak-2017/src/utils/vmpi/iphelpers.cpp

611 lines
12 KiB
C++

//========= Copyright Valve Corporation, All rights reserved. ============//
//
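// Purpose: Winsock helpers for VMPI: IP address conversion, wait timers,
//          a chunk-list walker, and a non-blocking datagram socket wrapper.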
//
// $NoKeywords: $
//
//=============================================================================//
#pragma warning (disable:4127)
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma warning (default:4127)
#include "iphelpers.h"
#include "basetypes.h"
#include <assert.h>
#include "utllinkedlist.h"
#include "utlvector.h"
#include "tier1/strtools.h"
// This automatically calls WSAStartup for the app at startup.
class CIPStarter
{
public:
CIPStarter()
{
WSADATA wsaData;
WSAStartup( WINSOCK_VERSION, &wsaData );
}
};
static CIPStarter g_Starter;
unsigned long SampleMilliseconds()
{
CCycleCount cnt;
cnt.Sample();
return cnt.GetMilliseconds();
}
// ------------------------------------------------------------------------------------------ //
// CChunkWalker.
// ------------------------------------------------------------------------------------------ //
// Sets up a walker over nChunks buffers.
//   pChunks       - array of nChunks buffer pointers (the array is referenced, not copied).
//   pChunkLengths - array of nChunks byte counts, one per buffer (also referenced, not copied).
// The read cursor starts at the first byte of the first chunk, and the combined
// length of all chunks is computed up front.
CChunkWalker::CChunkWalker( void const * const *pChunks, const int *pChunkLengths, int nChunks )
{
	m_pChunks = pChunks;
	m_pChunkLengths = pChunkLengths;
	m_nChunks = nChunks;

	// Cursor: which chunk we are in, and how far into it.
	m_iCurChunk = 0;
	m_iCurChunkPos = 0;

	m_TotalLength = 0;
	for ( int iChunk = 0; iChunk < nChunks; ++iChunk )
	{
		m_TotalLength += pChunkLengths[iChunk];
	}
}
// Returns the sum of all chunk lengths, as computed by the constructor.
int CChunkWalker::GetTotalLength() const
{
return m_TotalLength;
}
// Copies the next nBytes from the chunk list into pOut, advancing the walker's
// cursor across chunk boundaries as it goes.
//   pOut   - destination buffer that receives the bytes.
//   nBytes - number of bytes to copy. Asking for more than is left in the chunk
//            list is a caller error: it asserts in debug builds, and the copy
//            stops at the end of the last chunk instead of indexing past the
//            chunk arrays.
void CChunkWalker::CopyTo( void *pOut, int nBytes )
{
	unsigned char *pOutPos = (unsigned char*)pOut;
	int nBytesLeft = nBytes;

	// The second condition keeps m_pChunkLengths / m_pChunks from being indexed
	// out of range once every chunk has been consumed.
	while ( nBytesLeft > 0 && m_iCurChunk < m_nChunks )
	{
		int curChunkLen = m_pChunkLengths[m_iCurChunk];
		int amtLeft = curChunkLen - m_iCurChunkPos;

		// Copy whatever is smaller: what the caller still wants, or what this chunk still has.
		int toCopy = ( nBytesLeft > amtLeft ) ? amtLeft : nBytesLeft;

		const unsigned char *pCurChunkData = (const unsigned char*)m_pChunks[m_iCurChunk];
		memcpy( pOutPos, &pCurChunkData[m_iCurChunkPos], toCopy );
		nBytesLeft -= toCopy;
		pOutPos += toCopy;

		// Slide up to the next chunk if we're done with the one we're on.
		m_iCurChunkPos += toCopy;
		assert( m_iCurChunkPos <= curChunkLen );
		if ( m_iCurChunkPos == curChunkLen )
		{
			++m_iCurChunk;
			m_iCurChunkPos = 0;
			if ( m_iCurChunk == m_nChunks )
			{
				// Ran off the end of the data: the caller asked for too much.
				assert( nBytesLeft == 0 );
			}
		}
	}
}
// ------------------------------------------------------------------------------------------ //
// CWaitTimer
// ------------------------------------------------------------------------------------------ //
// When true, CWaitTimer::ShouldKeepWaiting() keeps returning true for any timer
// with a nonzero wait, even after the wait period has elapsed.
bool g_bForceWaitTimers = false;
// Starts a timer now. flSeconds is the wait length in seconds; it is stored
// internally as a whole number of milliseconds.
CWaitTimer::CWaitTimer( double flSeconds )
{
	m_WaitMS = static_cast<unsigned long>( flSeconds * 1000.0 );
	m_StartTime = SampleMilliseconds();
}
// Returns true while the caller should keep waiting.
// A timer built with a zero wait never waits. Otherwise the answer is true until
// more than m_WaitMS milliseconds have passed since construction, or for as long
// as g_bForceWaitTimers is set.
bool CWaitTimer::ShouldKeepWaiting()
{
	if ( m_WaitMS == 0 )
		return false;

	const bool bWithinWaitPeriod = ( SampleMilliseconds() - m_StartTime ) <= m_WaitMS;
	return bWithinWaitPeriod || g_bForceWaitTimers;
}
// ------------------------------------------------------------------------------------------ //
// CIPAddr.
// ------------------------------------------------------------------------------------------ //
// Default constructor: address 0.0.0.0, port 0.
CIPAddr::CIPAddr()
{
Init( 0, 0, 0, 0, 0 );
}
// Builds an address from an array of four address bytes (inputIP[0] is the first
// byte of the address) and a port.
CIPAddr::CIPAddr( const int inputIP[4], const int inputPort )
{
Init( inputIP[0], inputIP[1], inputIP[2], inputIP[3], inputPort );
}
// Builds an address from four individual address bytes and a port.
CIPAddr::CIPAddr( int ip0, int ip1, int ip2, int ip3, int ipPort )
{
Init( ip0, ip1, ip2, ip3, ipPort );
}
void CIPAddr::Init( int ip0, int ip1, int ip2, int ip3, int ipPort )
{
ip[0] = (unsigned char)ip0;
ip[1] = (unsigned char)ip1;
ip[2] = (unsigned char)ip2;
ip[3] = (unsigned char)ip3;
port = (unsigned short)ipPort;
}
// Two addresses are equal when all four address bytes and the port match.
bool CIPAddr::operator==( const CIPAddr &o ) const
{
	for ( int i = 0; i < 4; ++i )
	{
		if ( ip[i] != o.ip[i] )
			return false;
	}

	return port == o.port;
}
// Exact negation of operator==.
bool CIPAddr::operator!=( const CIPAddr &o ) const
{
	const bool bSame = ( *this == o );
	return !bSame;
}
void CIPAddr::SetupLocal( int inPort )
{
ip[0] = 0x7f;
ip[1] = 0;
ip[2] = 0;
ip[3] = 1;
port = inPort;
}
// ------------------------------------------------------------------------------------------ //
// Static helpers.
// ------------------------------------------------------------------------------------------ //
static double IP_FloatTime()
{
CCycleCount cnt;
cnt.Sample();
return cnt.GetSeconds();
}
// Builds a TIMEVAL from a timeout expressed in seconds.
//   flTimeout - timeout in seconds; may have a fractional part.
// Returns a TIMEVAL whose tv_sec holds the whole seconds and whose tv_usec holds
// the fractional remainder.
TIMEVAL SetupTimeVal( double flTimeout )
{
	const long nWholeSeconds = (long)flTimeout;

	TIMEVAL timeVal;
	timeVal.tv_sec = nWholeSeconds;
	// tv_usec is in microseconds, so the fractional second is scaled by 1,000,000.
	timeVal.tv_usec = (long)( ( flTimeout - nWholeSeconds ) * 1000000.0 );
	return timeVal;
}
// Convert a CIPAddr to a sockaddr_in.
void IPAddrToInAddr( const CIPAddr *pIn, in_addr *pOut )
{
u_char *p = (u_char*)pOut;
p[0] = pIn->ip[0];
p[1] = pIn->ip[1];
p[2] = pIn->ip[2];
p[3] = pIn->ip[3];
}
// Convert a CIPAddr to a sockaddr_in. The output struct is zeroed first, then
// filled in as an AF_INET address with the port converted by htons().
void IPAddrToSockAddr( const CIPAddr *pIn, struct sockaddr_in *pOut )
{
	memset( pOut, 0, sizeof( struct sockaddr_in ) );

	IPAddrToInAddr( pIn, &pOut->sin_addr );
	pOut->sin_port = htons( pIn->port );
	pOut->sin_family = AF_INET;
}
// Convert a CIPAddr to a sockaddr_in.
void SockAddrToIPAddr( const struct sockaddr_in *pIn, CIPAddr *pOut )
{
const u_char *p = (const u_char*)&pIn->sin_addr;
pOut->ip[0] = p[0];
pOut->ip[1] = p[1];
pOut->ip[2] = p[2];
pOut->ip[3] = p[3];
pOut->port = ntohs( pIn->sin_port );
}
class CIPSocket : public ISocket
{
public:
CIPSocket()
{
m_Socket = INVALID_SOCKET;
m_bSetupToBroadcast = false;
}
virtual ~CIPSocket()
{
Term();
}
// ISocket implementation.
public:
virtual void Release()
{
delete this;
}
virtual bool CreateSocket()
{
// Clear any old socket we had around.
Term();
// Create a socket to send and receive through.
SOCKET sock = socket( AF_INET, SOCK_DGRAM, IPPROTO_IP );
if ( sock == INVALID_SOCKET )
{
Assert( false );
return false;
}
// Nonblocking please..
int status;
DWORD val = 1;
status = ioctlsocket( sock, FIONBIO, &val );
if ( status != 0 )
{
assert( false );
closesocket( sock );
return false;
}
m_Socket = sock;
return true;
}
// Called after we have a socket.
virtual bool BindPart2( const CIPAddr *pAddr )
{
Assert( m_Socket != INVALID_SOCKET );
// bind to it!
sockaddr_in addr;
IPAddrToSockAddr( pAddr, &addr );
int status = bind( m_Socket, (sockaddr*)&addr, sizeof(addr) );
if ( status == 0 )
{
return true;
}
else
{
Term();
return false;
}
}
virtual bool Bind( const CIPAddr *pAddr )
{
if ( !CreateSocket() )
return false;
return BindPart2( pAddr );
}
virtual bool BindToAny( const unsigned short port )
{
// (INADDR_ANY)
CIPAddr addr;
addr.ip[0] = addr.ip[1] = addr.ip[2] = addr.ip[3] = 0;
addr.port = port;
return Bind( &addr );
}
virtual bool ListenToMulticastStream( const CIPAddr &addr, const CIPAddr &localInterface )
{
ip_mreq mr;
IPAddrToInAddr( &addr, &mr.imr_multiaddr );
IPAddrToInAddr( &localInterface, &mr.imr_interface );
// This helps a lot if the stream is sending really fast.
int rcvBuf = 1024*1024*2;
setsockopt( m_Socket, SOL_SOCKET, SO_RCVBUF, (char*)&rcvBuf, sizeof( rcvBuf ) );
if ( setsockopt( m_Socket, IPPROTO_IP, IP_ADD_MEMBERSHIP, (char*)&mr, sizeof( mr ) ) == 0 )
{
// Remember this so we do IP_DEL_MEMBERSHIP on shutdown.
m_bMulticastGroupMembership = true;
m_MulticastGroupMREQ = mr;
return true;
}
else
{
return false;
}
}
virtual bool Broadcast( const void *pData, const int len, const unsigned short port )
{
assert( m_Socket != INVALID_SOCKET );
// Make sure we're setup to broadcast.
if ( !m_bSetupToBroadcast )
{
BOOL bBroadcast = true;
if ( setsockopt( m_Socket, SOL_SOCKET, SO_BROADCAST, (char*)&bBroadcast, sizeof( bBroadcast ) ) != 0 )
{
assert( false );
return false;
}
m_bSetupToBroadcast = true;
}
CIPAddr addr;
addr.ip[0] = addr.ip[1] = addr.ip[2] = addr.ip[3] = 0xFF;
addr.port = port;
return SendTo( &addr, pData, len );
}
virtual bool SendTo( const CIPAddr *pAddr, const void *pData, const int len )
{
return SendChunksTo( pAddr, &pData, &len, 1 );
}
virtual bool SendChunksTo( const CIPAddr *pAddr, void const * const *pChunks, const int *pChunkLengths, int nChunks )
{
WSABUF bufs[32];
if ( nChunks > 32 )
{
Error( "CIPSocket::SendChunksTo: too many chunks (%d).", nChunks );
}
int nTotalBytes = 0;
for ( int i=0; i < nChunks; i++ )
{
bufs[i].len = pChunkLengths[i];
bufs[i].buf = (char*)pChunks[i];
nTotalBytes += pChunkLengths[i];
}
assert( m_Socket != INVALID_SOCKET );
// Translate the address.
sockaddr_in addr;
IPAddrToSockAddr( pAddr, &addr );
DWORD dwNumBytesSent = 0;
DWORD ret = WSASendTo(
m_Socket,
bufs,
nChunks,
&dwNumBytesSent,
0,
(sockaddr*)&addr,
sizeof( addr ),
NULL,
NULL
);
return ret == 0 && (int)dwNumBytesSent == nTotalBytes;
}
virtual int RecvFrom( void *pData, int maxDataLen, CIPAddr *pFrom )
{
assert( m_Socket != INVALID_SOCKET );
fd_set readSet;
readSet.fd_count = 1;
readSet.fd_array[0] = m_Socket;
TIMEVAL timeVal = SetupTimeVal( 0 );
// See if it has a packet waiting.
int status = select( 0, &readSet, NULL, NULL, &timeVal );
if ( status == 0 || status == SOCKET_ERROR )
return -1;
// Get the data.
sockaddr_in sender;
int fromSize = sizeof( sockaddr_in );
status = recvfrom( m_Socket, (char*)pData, maxDataLen, 0, (struct sockaddr*)&sender, &fromSize );
if ( status == 0 || status == SOCKET_ERROR )
{
return -1;
}
else
{
if ( pFrom )
{
SockAddrToIPAddr( &sender, pFrom );
}
m_flLastRecvTime = IP_FloatTime();
return status;
}
}
virtual double GetRecvTimeout()
{
return IP_FloatTime() - m_flLastRecvTime;
}
private:
void Term()
{
if ( m_Socket != INVALID_SOCKET )
{
if ( m_bMulticastGroupMembership )
{
// Undo our multicast group membership.
setsockopt( m_Socket, IPPROTO_IP, IP_DROP_MEMBERSHIP, (char*)&m_MulticastGroupMREQ, sizeof( m_MulticastGroupMREQ ) );
}
closesocket( m_Socket );
m_Socket = INVALID_SOCKET;
}
m_bSetupToBroadcast = false;
m_bMulticastGroupMembership = false;
}
private:
SOCKET m_Socket;
bool m_bMulticastGroupMembership; // Did we join a multicast group?
ip_mreq m_MulticastGroupMREQ;
bool m_bSetupToBroadcast;
double m_flLastRecvTime;
bool m_bListenSocket;
};
// Allocates a new CIPSocket and returns it through the ISocket interface.
// No socket handle exists yet; the object is freed with ISocket::Release().
ISocket* CreateIPSocket()
{
	ISocket *pNewSocket = new CIPSocket;
	return pNewSocket;
}
// Creates a socket bound to localInterface (using the multicast address's port)
// and joins the multicast group 'addr' on that interface.
// Returns the socket on success, or NULL (after releasing it) if either step fails.
ISocket* CreateMulticastListenSocket(
	const CIPAddr &addr,
	const CIPAddr &localInterface )
{
	CIPSocket *pListener = new CIPSocket;

	// Bind on the local interface, but on the port the stream is sent to.
	CIPAddr localBindAddr = localInterface;
	localBindAddr.port = addr.port;

	bool bReady = pListener->Bind( &localBindAddr );
	if ( bReady )
	{
		bReady = pListener->ListenToMulticastStream( addr, localInterface );
	}

	if ( !bReady )
	{
		pListener->Release();
		return NULL;
	}

	return pListener;
}
bool ConvertStringToIPAddr( const char *pStr, CIPAddr *pOut )
{
char ipStr[512];
const char *pColon = strchr( pStr, ':' );
if ( pColon )
{
int toCopy = pColon - pStr;
if ( toCopy < 2 || toCopy > sizeof(ipStr)-1 )
{
assert( false );
return false;
}
memcpy( ipStr, pStr, toCopy );
ipStr[toCopy] = 0;
pOut->port = (unsigned short)atoi( pColon+1 );
}
else
{
strncpy( ipStr, pStr, sizeof( ipStr ) );
ipStr[ sizeof(ipStr)-1 ] = 0;
}
if ( ipStr[0] >= '0' && ipStr[0] <= '9' )
{
// It's numbers.
int ip[4];
sscanf( ipStr, "%d.%d.%d.%d", &ip[0], &ip[1], &ip[2], &ip[3] );
pOut->ip[0] = (unsigned char)ip[0];
pOut->ip[1] = (unsigned char)ip[1];
pOut->ip[2] = (unsigned char)ip[2];
pOut->ip[3] = (unsigned char)ip[3];
}
else
{
// It's a text string.
struct hostent *pHost = gethostbyname( ipStr );
if( !pHost )
return false;
pOut->ip[0] = pHost->h_addr_list[0][0];
pOut->ip[1] = pHost->h_addr_list[0][1];
pOut->ip[2] = pHost->h_addr_list[0][2];
pOut->ip[3] = pHost->h_addr_list[0][3];
}
return true;
}
bool ConvertIPAddrToString( const CIPAddr *pIn, char *pOut, int outLen )
{
in_addr addr;
addr.S_un.S_un_b.s_b1 = pIn->ip[0];
addr.S_un.S_un_b.s_b2 = pIn->ip[1];
addr.S_un.S_un_b.s_b3 = pIn->ip[2];
addr.S_un.S_un_b.s_b4 = pIn->ip[3];
HOSTENT *pEnt = gethostbyaddr( (char*)&addr, sizeof( addr ), AF_INET );
if ( pEnt )
{
Q_strncpy( pOut, pEnt->h_name, outLen );
return true;
}
else
{
return false;
}
}
void IP_GetLastErrorString( char *pStr, int maxLen )
{
char *lpMsgBuf;
FormatMessage(
FORMAT_MESSAGE_ALLOCATE_BUFFER |
FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL,
GetLastError(),
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language
(LPTSTR) &lpMsgBuf,
0,
NULL
);
Q_strncpy( pStr, lpMsgBuf, maxLen );
LocalFree( lpMsgBuf );
}