hl2_src-leak-2017/src/networksystem/networksystem.cpp

492 lines
13 KiB
C++

//========= Copyright Valve Corporation, All rights reserved. ============//
//
// Purpose:
//
//===========================================================================//
#include "networksystem.h"
#include "filesystem.h"
#include "UDP_Socket.h"
#include "sm_protocol.h"
#include "NetChannel.h"
#include "UDP_Process.h"
#include <winsock.h>
#include "networkclient.h"
#include "networkserver.h"
#include "networksystem/inetworkmessage.h"
#include "mathlib/mathlib.h"
#include "tier2/tier2.h"
//-----------------------------------------------------------------------------
// Singleton instance
//-----------------------------------------------------------------------------
static CNetworkSystem g_NetworkSystem;
CNetworkSystem *g_pNetworkSystemImp = &g_NetworkSystem;
EXPOSE_SINGLE_INTERFACE_GLOBALVAR( CNetworkSystem, INetworkSystem,
NETWORKSYSTEM_INTERFACE_VERSION, g_NetworkSystem );
//-----------------------------------------------------------------------------
// Constructor, destructor
//-----------------------------------------------------------------------------
CNetworkSystem::CNetworkSystem()
{
m_bWinsockInitialized = false;
m_bNetworkEventCreated = false;
m_bInMidPacket = false;
m_pServer = NULL;
m_pClient = NULL;
m_nGroupBits = 1;
m_nTypeBits = LargestPowerOfTwoLessThanOrEqual( net_num_messages );
}
CNetworkSystem::~CNetworkSystem()
{
}
//-----------------------------------------------------------------------------
// Initialization, shutdown
//-----------------------------------------------------------------------------
InitReturnVal_t CNetworkSystem::Init()
{
InitReturnVal_t nRetVal = BaseClass::Init();
if ( nRetVal != INIT_OK )
return nRetVal;
// initialize winsock 2.0
WSAData wsaData;
if ( WSAStartup( MAKEWORD(2,0), &wsaData ) != 0 )
{
Warning( "Error! Failed to load network socket library.\n");
return INIT_OK;
}
else
{
m_bWinsockInitialized = true;
}
LPHOSTENT lp = gethostbyname("localhost");
if ( !lp )
{
Warning( "Error! Failed to query local host info\n");
return INIT_OK;
}
m_LocalHostName = lp->h_name;
lp = gethostbyname( m_LocalHostName );
if ( !lp )
{
Warning( "Error! Failed to query local host info\n");
return INIT_OK;
}
sockaddr ip;
m_LocalAddressString = inet_ntoa( *((in_addr*)(lp->h_addr_list[0])) );
StringToSockaddr( m_LocalAddressString, &ip );
m_LocalAddress.SetFromSockadr( &ip );
return INIT_OK;
}
void CNetworkSystem::Shutdown()
{
if ( m_bWinsockInitialized )
{
WSACleanup();
}
CleanupNetworkMessages();
BaseClass::Shutdown();
}
//-----------------------------------------------------------------------------
// Connect, disconnect
//-----------------------------------------------------------------------------
bool CNetworkSystem::Connect( CreateInterfaceFn factory )
{
if ( !BaseClass::Connect( factory ) )
return false;
if ( !g_pFullFileSystem )
{
Warning( "The network system requires the filesystem to run!\n" );
return false;
}
return INIT_OK;
}
//-----------------------------------------------------------------------------
// Returns the current time
//-----------------------------------------------------------------------------
float CNetworkSystem::GetTime( void )
{
return Plat_FloatTime();
}
//-----------------------------------------------------------------------------
// Installs network message factories to be used with all connections
//-----------------------------------------------------------------------------
bool CNetworkSystem::RegisterMessage( INetworkMessage *pMessage )
{
if ( m_pServer || m_pClient )
{
Warning( "Cannot register messages after connection has started.\n" );
return false;
}
if ( pMessage->GetGroup() == 0 )
{
Warning( "Network message group 0 is reserved by the network system.\n" );
return false;
}
// Look for already registered message
if ( m_NetworkMessages.Find( pMessage ) >= 0 )
return false;
// Allocate more space in messages
int nGroupBits = LargestPowerOfTwoLessThanOrEqual( pMessage->GetGroup() );
int nTypeBits = LargestPowerOfTwoLessThanOrEqual( pMessage->GetType() );
m_nGroupBits = max( nGroupBits, m_nGroupBits );
m_nTypeBits = max( nTypeBits, m_nTypeBits );
m_NetworkMessages.AddToTail( pMessage );
return true;
}
void CNetworkSystem::CleanupNetworkMessages( )
{
int nCount = m_NetworkMessages.Count();
for ( int i = 0; i < nCount; ++i )
{
m_NetworkMessages[i]->Release();
}
m_NetworkMessages.RemoveAll();
}
//-----------------------------------------------------------------------------
// Finds a network message given a particular message type
//-----------------------------------------------------------------------------
INetworkMessage* CNetworkSystem::FindNetworkMessage( int group, int type )
{
int nCount = m_NetworkMessages.Count();
for (int i=0; i < nCount; i++ )
{
if ( ( m_NetworkMessages[i]->GetGroup() == group ) && ( m_NetworkMessages[i]->GetType() == type ) )
return m_NetworkMessages[i];
}
return NULL;
}
//-----------------------------------------------------------------------------
// Purpose:
//-----------------------------------------------------------------------------
bool CNetworkSystem::StringToSockaddr( const char *s, struct sockaddr *sadr )
{
struct hostent *h;
char *colon;
char copy[128];
Q_memset (sadr, 0, sizeof(*sadr));
((struct sockaddr_in *)sadr)->sin_family = AF_INET;
((struct sockaddr_in *)sadr)->sin_port = 0;
Q_strncpy (copy, s, sizeof( copy ) );
// strip off a trailing :port if present
for (colon = copy ; *colon ; colon++)
if (*colon == ':')
{
*colon = 0;
((struct sockaddr_in *)sadr)->sin_port = htons((short)atoi(colon+1));
}
if (copy[0] >= '0' && copy[0] <= '9' && Q_strstr( copy, "." ) )
{
*(int *)&((struct sockaddr_in *)sadr)->sin_addr = inet_addr(copy);
}
else
{
// if ( net_nodns )
// return false; // DNS names disabled
if ( (h = gethostbyname(copy)) == NULL )
return false;
*(int *)&((struct sockaddr_in *)sadr)->sin_addr = *(int *)h->h_addr_list[0];
}
return true;
}
//-----------------------------------------------------------------------------
// Returns the local address
//-----------------------------------------------------------------------------
const char* CNetworkSystem::GetLocalHostName( void ) const
{
return m_LocalHostName;
}
const char* CNetworkSystem::GetLocalAddress( void ) const
{
return m_LocalAddressString;
}
//-----------------------------------------------------------------------------
// Start, shutdown a server
//-----------------------------------------------------------------------------
bool CNetworkSystem::StartServer( unsigned short nServerListenPort )
{
if ( !m_bWinsockInitialized )
return false;
Assert( !m_pServer );
m_pServer = new CNetworkServer;
return m_pServer->Init( nServerListenPort );
}
void CNetworkSystem::ShutdownServer( )
{
if ( m_pServer )
{
m_pServer->Shutdown();
delete m_pServer;
m_pServer = NULL;
}
}
//-----------------------------------------------------------------------------
// Server update
//-----------------------------------------------------------------------------
void CNetworkSystem::ServerReceiveMessages()
{
if ( m_pServer )
{
m_pServer->ReadPackets();
}
}
void CNetworkSystem::ServerSendMessages()
{
if ( m_pServer )
{
m_pServer->SendUpdates();
}
}
//-----------------------------------------------------------------------------
// Start, shutdown a client
//-----------------------------------------------------------------------------
bool CNetworkSystem::StartClient( unsigned short nClientListenPort )
{
if ( !m_bWinsockInitialized )
return false;
Assert( !m_pClient );
m_pClient = new CNetworkClient;
return m_pClient->Init( nClientListenPort );
}
void CNetworkSystem::ShutdownClient( )
{
if ( m_pClient )
{
m_pClient->Shutdown();
delete m_pClient;
m_pClient = NULL;
}
}
//-----------------------------------------------------------------------------
// Server update
//-----------------------------------------------------------------------------
void CNetworkSystem::ClientReceiveMessages()
{
if ( m_pClient )
{
m_pClient->ReadPackets();
}
}
void CNetworkSystem::ClientSendMessages()
{
if ( m_pClient )
{
m_pClient->SendUpdate();
}
}
//-----------------------------------------------------------------------------
// Server update
//-----------------------------------------------------------------------------
INetChannel* CNetworkSystem::ConnectClientToServer( const char *pServer, int nServerListenPort )
{
if ( m_pClient )
{
if ( m_pClient->Connect( pServer, nServerListenPort ) )
return m_pClient->GetNetChannel();
}
return NULL;
}
void CNetworkSystem::DisconnectClientFromServer( INetChannel* pChannel )
{
if ( m_pClient && ( m_pClient->GetNetChannel() == pChannel ) )
{
m_pClient->Disconnect();
}
}
//-----------------------------------------------------------------------------
// Queues up a network packet
//-----------------------------------------------------------------------------
void CNetworkSystem::EnqueueConnectionlessNetworkPacket( CNetPacket *pPacket, IConnectionlessPacketHandler *pHandler )
{
int i = m_PacketQueue.AddToTail( );
PacketInfo_t& info = m_PacketQueue[i];
info.m_pPacket = pPacket;
info.m_pHandler = pHandler;
info.m_pNetChannel = NULL;
}
void CNetworkSystem::EnqueueNetworkPacket( CNetPacket *pPacket, CNetChannel *pNetChannel )
{
int i = m_PacketQueue.AddToTail( );
PacketInfo_t& info = m_PacketQueue[i];
info.m_pPacket = pPacket;
info.m_pHandler = NULL;
info.m_pNetChannel = pNetChannel;
}
//-----------------------------------------------------------------------------
// Network event iteration helpers
//-----------------------------------------------------------------------------
bool CNetworkSystem::StartProcessingNewPacket()
{
PacketInfo_t& info = m_PacketQueue[ m_nProcessingPacket ];
if ( info.m_pHandler )
{
UDP_ProcessConnectionlessPacket( info.m_pPacket, info.m_pHandler );
return false;
}
if ( !info.m_pNetChannel )
{
// Not an error that may happen during connect or disconnect
Warning( "Sequenced packet without connection from %s\n" , info.m_pPacket->m_From.ToString() );
return false;
}
return info.m_pNetChannel->StartProcessingPacket( info.m_pPacket );
}
bool CNetworkSystem::AdvanceProcessingNetworkPacket( )
{
m_PacketQueue[ m_nProcessingPacket ].m_pPacket->Release();
m_PacketQueue[ m_nProcessingPacket ].m_pPacket = NULL;
++m_nProcessingPacket;
bool bOverflowed = ( m_nProcessingPacket >= m_PacketQueue.Count() );
if ( bOverflowed )
{
m_PacketQueue.RemoveAll();
m_nProcessingPacket = 0;
}
return !bOverflowed;
}
//-----------------------------------------------------------------------------
// Network event iteration
//-----------------------------------------------------------------------------
NetworkEvent_t *CNetworkSystem::FirstNetworkEvent( )
{
Assert( !m_bInMidPacket && !m_nProcessingPacket );
m_nProcessingPacket = 0;
m_bInMidPacket = false;
return NextNetworkEvent();
}
NetworkEvent_t *CNetworkSystem::NextNetworkEvent( )
{
int nPacketCount = m_PacketQueue.Count();
if ( m_nProcessingPacket >= nPacketCount )
return NULL;
while( true )
{
// Continue processing the packet we're currently on
if ( m_bInMidPacket )
{
PacketInfo_t& info = m_PacketQueue[ m_nProcessingPacket ];
while( info.m_pNetChannel->ProcessPacket( info.m_pPacket ) )
{
Assert( m_bNetworkEventCreated );
m_bNetworkEventCreated = false;
return (NetworkEvent_t*)m_EventMessageBuffer;
}
info.m_pNetChannel->EndProcessingPacket( info.m_pPacket );
m_bInMidPacket = false;
if ( !AdvanceProcessingNetworkPacket() )
return NULL;
}
// Keep reading packets until we find one that either generates an event,
// one that encounters a packet we need to process, or we exhaust the event queue
while( !StartProcessingNewPacket() )
{
bool bOverflowed = !AdvanceProcessingNetworkPacket();
if ( m_bNetworkEventCreated )
{
m_bNetworkEventCreated = false;
return (NetworkEvent_t*)m_EventMessageBuffer;
}
if ( bOverflowed )
return NULL;
}
m_bInMidPacket = true;
}
}
//-----------------------------------------------------------------------------
// Network event creation
//-----------------------------------------------------------------------------
NetworkEvent_t* CNetworkSystem::CreateNetworkEvent( int nSizeInBytes )
{
Assert( nSizeInBytes <= sizeof( m_EventMessageBuffer ) );
// If this assertion fails, it means two or more network events were created
// before the main network event loop had a chance to inform external code
Assert( !m_bNetworkEventCreated );
m_bNetworkEventCreated = true;
return ( NetworkEvent_t* )m_EventMessageBuffer;
}
bool CNetworkSystem::IsNetworkEventCreated()
{
return m_bNetworkEventCreated;
}