diff --git a/external/sourcesdk/bitbuf.cpp b/external/sourcesdk/bitbuf.cpp index 2b3ded6..6c7e37a 100644 --- a/external/sourcesdk/bitbuf.cpp +++ b/external/sourcesdk/bitbuf.cpp @@ -58,27 +58,11 @@ inline unsigned int CountTrailingZeros(unsigned int elem) #define FAST_BIT_SCAN 0 #endif - -static BitBufErrorHandler g_BitBufErrorHandler = 0; - inline int BitForBitnum(int bitnum) { return GetBitForBitnum(bitnum); } -void InternalBitBufErrorHandler( BitBufErrorType errorType, const char *pDebugName ) -{ - if ( g_BitBufErrorHandler ) - g_BitBufErrorHandler( errorType, pDebugName ); -} - - -void SetBitBufErrorHandler( BitBufErrorHandler fn ) -{ - g_BitBufErrorHandler = fn; -} - - // #define BB_PROFILING unsigned long g_LittleBits[32]; @@ -130,6 +114,7 @@ bf_write::bf_write() m_bOverflow = false; m_bAssertOnOverflow = true; m_pDebugName = NULL; + m_errorHandler = NULL; } bf_write::bf_write( const char *pDebugName, void *pData, int nBytes, int nBits ) @@ -196,6 +181,19 @@ void bf_write::SetDebugName( const char *pDebugName ) m_pDebugName = pDebugName; } +void bf_write::SetErrorHandler(IBitBufOverErrorHandler* handler) +{ + m_errorHandler = handler; +} + +bool bf_write::CallErrorHandler(BitBufErrorType errorType) +{ + if (m_errorHandler) + { + return m_errorHandler->HandleError(errorType, GetDebugName()); + } + return false; +} void bf_write::SeekToBit( int bitPos ) { @@ -452,11 +450,14 @@ bool bf_write::WriteBits(const void *pInData, int nBits) int nBitsLeft = nBits; // Bounds checking.. - if ( (m_iCurBit+nBits) > m_nDataBits ) + if (GetNumBitsLeft() < nBits) { - SetOverflowFlag(); - CallErrorHandler( BITBUFERROR_BUFFER_OVERRUN, GetDebugName() ); - return false; + const bool recovered = CallErrorHandler(BITBUFERROR_BUFFER_OVERRUN); + if (!recovered || (recovered && (GetNumBitsLeft() < nBits))) + { + SetOverflowFlag(); + return false; + } } // Align output to dword boundary @@ -794,6 +795,7 @@ bf_read::bf_read() m_bOverflow = false; m_bAssertOnOverflow = true; m_pDebugName = NULL; + m_errorHandler = NULL; } bf_read::bf_read( const void *pData, int nBytes, int nBits ) @@ -847,6 +849,20 @@ void bf_read::SetDebugName( const char *pName ) m_pDebugName = pName; } +void bf_read::SetErrorHandler(IBitBufOverErrorHandler* handler) +{ + m_errorHandler = handler; +} + +bool bf_read::CallErrorHandler(BitBufErrorType errorType) +{ + if (m_errorHandler) + { + return m_errorHandler->HandleError(errorType, GetDebugName()); + } + return false; +} + void bf_read::SetOverflowFlag() { if ( m_bAssertOnOverflow ) diff --git a/external/sourcesdk/include/sourcesdk/bitbuf.h b/external/sourcesdk/include/sourcesdk/bitbuf.h index 5cce666..98e6642 100644 --- a/external/sourcesdk/include/sourcesdk/bitbuf.h +++ b/external/sourcesdk/include/sourcesdk/bitbuf.h @@ -36,21 +36,11 @@ typedef enum BITBUFERROR_NUM_ERRORS } BitBufErrorType; - -typedef void (*BitBufErrorHandler)( BitBufErrorType errorType, const char *pDebugName ); - - -#if defined( _DEBUG ) - extern void InternalBitBufErrorHandler( BitBufErrorType errorType, const char *pDebugName ); - #define CallErrorHandler( errorType, pDebugName ) InternalBitBufErrorHandler( errorType, pDebugName ); -#else - #define CallErrorHandler( errorType, pDebugName ) -#endif - - -// Use this to install the error handler. Call with NULL to uninstall your error handler. -void SetBitBufErrorHandler( BitBufErrorHandler fn ); - +class IBitBufOverErrorHandler +{ +public: + virtual bool HandleError(BitBufErrorType errorType, const char *pDebugName) = 0; +}; //----------------------------------------------------------------------------- // Helpers. @@ -150,6 +140,8 @@ public: const char* GetDebugName(); void SetDebugName( const char *pDebugName ); + void SetErrorHandler(IBitBufOverErrorHandler* handler); + bool CallErrorHandler(BitBufErrorType errorType); // Seek to a specific position. public: @@ -253,6 +245,7 @@ private: bool m_bAssertOnOverflow; const char *m_pDebugName; + IBitBufOverErrorHandler* m_errorHandler; }; @@ -298,10 +291,13 @@ inline const unsigned char* bf_write::GetData() const inline bool bf_write::CheckForOverflow(int nBits) { - if ( m_iCurBit + nBits > m_nDataBits ) + if (GetNumBitsLeft() < nBits) { - SetOverflowFlag(); - CallErrorHandler( BITBUFERROR_BUFFER_OVERRUN, GetDebugName() ); + const bool recovered = CallErrorHandler(BITBUFERROR_BUFFER_OVERRUN); + if (!recovered || (recovered && GetNumBitsLeft() < nBits)) + { + SetOverflowFlag(); + } } return m_bOverflow; @@ -340,9 +336,12 @@ inline void bf_write::WriteOneBit(int nValue) { if( m_iCurBit >= m_nDataBits ) { - SetOverflowFlag(); - CallErrorHandler( BITBUFERROR_BUFFER_OVERRUN, GetDebugName() ); - return; + const bool recovered = CallErrorHandler(BITBUFERROR_BUFFER_OVERRUN); + if (!recovered || (recovered && (m_iCurBit >= m_nDataBits))) + { + SetOverflowFlag(); + return; + } } WriteOneBitNoCheck( nValue ); } @@ -352,9 +351,12 @@ inline void bf_write::WriteOneBitAt( int iBit, int nValue ) { if( iBit >= m_nDataBits ) { - SetOverflowFlag(); - CallErrorHandler( BITBUFERROR_BUFFER_OVERRUN, GetDebugName() ); - return; + const bool recovered = CallErrorHandler(BITBUFERROR_BUFFER_OVERRUN); + if (!recovered || (recovered && (iBit >= m_nDataBits))) + { + SetOverflowFlag(); + return; + } } #if __i386__ @@ -379,7 +381,7 @@ inline void bf_write::WriteUBitLong( unsigned int curData, int numbits, bool bCh { if ( curData >= (unsigned long)(1 << numbits) ) { - CallErrorHandler( BITBUFERROR_VALUE_OUT_OF_RANGE, GetDebugName() ); + CallErrorHandler(BITBUFERROR_VALUE_OUT_OF_RANGE); } } Assert( numbits >= 0 && numbits <= 32 ); @@ -387,10 +389,12 @@ inline void bf_write::WriteUBitLong( unsigned int curData, int numbits, bool bCh if ( GetNumBitsLeft() < numbits ) { - m_iCurBit = m_nDataBits; - SetOverflowFlag(); - CallErrorHandler( BITBUFERROR_BUFFER_OVERRUN, GetDebugName() ); - return; + const bool recovered = CallErrorHandler(BITBUFERROR_BUFFER_OVERRUN); + if (!recovered || (recovered && (GetNumBitsLeft() < numbits))) + { + SetOverflowFlag(); + return; + } } int iCurBitMasked = m_iCurBit & 31; @@ -504,6 +508,9 @@ public: const char* GetDebugName() const { return m_pDebugName; } void SetDebugName( const char *pName ); + void SetErrorHandler(IBitBufOverErrorHandler* handler); + bool CallErrorHandler(BitBufErrorType errorType); + void ExciseBits( int startbit, int bitstoremove ); @@ -651,6 +658,7 @@ private: bool m_bAssertOnOverflow; const char *m_pDebugName; + IBitBufOverErrorHandler* m_errorHandler; }; //----------------------------------------------------------------------------- @@ -700,10 +708,13 @@ inline bool bf_read::SeekRelative(int iBitDelta) inline bool bf_read::CheckForOverflow(int nBits) { - if( m_iCurBit + nBits > m_nDataBits ) + if (GetNumBitsLeft() < nBits) { - SetOverflowFlag(); - CallErrorHandler( BITBUFERROR_BUFFER_OVERRUN, GetDebugName() ); + const bool recovered = CallErrorHandler(BITBUFERROR_BUFFER_OVERRUN); + if (!recovered || (recovered && (GetNumBitsLeft() < nBits))) + { + SetOverflowFlag(); + } } return m_bOverflow; @@ -730,9 +741,12 @@ inline int bf_read::ReadOneBit() { if( GetNumBitsLeft() <= 0 ) { - SetOverflowFlag(); - CallErrorHandler( BITBUFERROR_BUFFER_OVERRUN, GetDebugName() ); - return 0; + const bool recovered = CallErrorHandler(BITBUFERROR_BUFFER_OVERRUN); + if (!recovered || (recovered && (GetNumBitsLeft() <= 0))) + { + SetOverflowFlag(); + return 0; + } } return ReadOneBitNoCheck(); } @@ -762,10 +776,12 @@ inline unsigned int bf_read::ReadUBitLong( int numbits ) __restrict if ( GetNumBitsLeft() < numbits ) { - m_iCurBit = m_nDataBits; - SetOverflowFlag(); - CallErrorHandler( BITBUFERROR_BUFFER_OVERRUN, GetDebugName() ); - return 0; + const bool recovered = CallErrorHandler(BITBUFERROR_BUFFER_OVERRUN); + if (!recovered || (recovered && (GetNumBitsLeft() < numbits))) + { + SetOverflowFlag(); + return 0; + } } unsigned int iStartBit = m_iCurBit & 31u;