//========= Copyright Valve Corporation, All rights reserved. ============// // // Purpose: Provides access to SQL at a high level // //============================================================================= #include "stdafx.h" #include "gcsdk/sqlaccess/sqlaccess.h" #include "gcsdk/gcsqlquery.h" // memdbgon must be the last include file in a .cpp file!!! #include "tier0/memdbgon.h" template< typename LISTENER_FUNC > static void RunAndClearListenerList( std::vector< LISTENER_FUNC > &vecListeners ) { // Let us not underestimate the ability of random listeners to re-enter everything. std::vector< LISTENER_FUNC > listenerCopy; listenerCopy.swap( vecListeners ); vecListeners.clear(); // Why would you consider such a thing DO_NOT_YIELD_THIS_SCOPE(); for ( const auto &listener : listenerCopy ) { listener(); } } namespace GCSDK { //------------------------------------------------------------------------------------ // Purpose: Constructor //------------------------------------------------------------------------------------ CSQLAccess::CSQLAccess( ESchemaCatalog eSchemaCatalog ) : m_eSchemaCatalog( eSchemaCatalog) , m_pCurrentQuery( NULL ) , m_bInTransaction( false ) { m_pQueryGroup = CGCSQLQueryGroup::Alloc(); } //------------------------------------------------------------------------------------ // Purpose: Destructor //------------------------------------------------------------------------------------ CSQLAccess::~CSQLAccess( ) { SAFE_RELEASE( m_pQueryGroup ); Assert( !m_pCurrentQuery ); SAFE_DELETE( m_pCurrentQuery ); AssertMsg( !m_bInTransaction, "GCSDK::CSQLAccess object being destroyed with a transaction pending. Use BCommitTransaction or RollbackTransaction to match your BBeginTransaction call." ); } //------------------------------------------------------------------------------------ // Purpose: Perform a query //------------------------------------------------------------------------------------ bool CSQLAccess::BYieldingExecute( const char *pchName, const char *pchSQLCommand, uint32 *pcRowsAffected, bool bSpewOnError ) { if ( NULL == pchName ) { pchName = pchSQLCommand; } bool bStandalone = !BInTransaction(); if( bStandalone ) { BBeginTransaction( pchName ); } CurrentQuery()->SetCommand( pchSQLCommand ); m_pQueryGroup->AddQuery( m_pCurrentQuery ); m_pCurrentQuery = NULL; bool bSuccess = true; if( bStandalone ) { bSuccess = BCommitTransaction(); if( bSuccess && pcRowsAffected ) { *pcRowsAffected = m_pQueryGroup->GetResults()->GetRowsAffected( 0 ); } } return bSuccess; } //------------------------------------------------------------------------------------ // Purpose: Starts a transaction //------------------------------------------------------------------------------------ bool CSQLAccess::BBeginTransaction( const char *pchName ) { Assert( !m_bInTransaction ); if( m_bInTransaction ) return false; m_pQueryGroup->Clear(); m_pQueryGroup->SetName( pchName ); m_bInTransaction = true; return true; } //------------------------------------------------------------------------------------ // Purpose: Returns the string last passed to BBeginTransaction //------------------------------------------------------------------------------------ const char *CSQLAccess::PchTransactionName( ) const { return m_pQueryGroup->PchName(); } //------------------------------------------------------------------------------------ // Purpose: Commits a transaction to the database //------------------------------------------------------------------------------------ bool CSQLAccess::BCommitTransaction( bool bAllowEmpty ) { Assert( BInTransaction() ); if( !BInTransaction() ) return false; if( !m_pCurrentQuery && !m_pQueryGroup->GetStatementCount() ) { if( bAllowEmpty ) { // No-op success m_bInTransaction = false; RunListeners_Commit(); return true; } else { AssertMsg1( false, "BCommitTransaction with empty transaction at %s", m_pQueryGroup->PchName() ); return false; } } AssertMsg1( !m_pCurrentQuery, "Unexecuted query present in BCommitTransaction: %s", m_pCurrentQuery->PchCommand() ); if( m_pCurrentQuery ) return false; m_bInTransaction = false; if( !GJobCur().BYieldingRunQuery( m_pQueryGroup, m_eSchemaCatalog ) ) { // Notify listeners that the transaction did not succeed RunListeners_Rollback(); return false; } // The transaction presumably did make the database, so we do not notify rollback listeners beyond here. RunListeners_Commit(); if( !m_pQueryGroup->GetResults() ) return false; return true; } //------------------------------------------------------------------------------------ // Purpose: Rolls back a transaction and clears any queries //------------------------------------------------------------------------------------ void CSQLAccess::RollbackTransaction() { bool bWasTransaction = BInTransaction(); Assert( bWasTransaction ); SAFE_DELETE( m_pCurrentQuery ); m_bInTransaction = false; if ( bWasTransaction ) { RunListeners_Rollback(); } else { m_vecCommitListeners.clear(); m_vecRollbackListeners.clear(); } } //------------------------------------------------------------------------------------ // Purpose: Adds a listener to be called synchronously should the transaction successfully commit //------------------------------------------------------------------------------------ void CSQLAccess::AddCommitListener( std::function &&listener ) { if ( !BInTransaction() ) { AssertMsg( BInTransaction(), "Adding a listener to a non-transaction access, will never fire" ); return; } m_vecCommitListeners.push_back( std::move( listener ) ); } //------------------------------------------------------------------------------------ // Purpose: Adds a listener to be called synchronously should the transaction fail or explicitly rollback //------------------------------------------------------------------------------------ void CSQLAccess::AddRollbackListener( std::function &&listener ) { if ( !BInTransaction() ) { AssertMsg( BInTransaction(), "Adding a listener to a non-transaction access, will never fire" ); return; } m_vecRollbackListeners.push_back( std::move( listener ) ); } //------------------------------------------------------------------------------------ // Purpose: Notifies listeners of successful commit. //------------------------------------------------------------------------------------ void CSQLAccess::RunListeners_Commit() { RunAndClearListenerList( m_vecCommitListeners ); // Clear the unused set m_vecRollbackListeners.clear(); } //------------------------------------------------------------------------------------ // Purpose: Notifies listeners of a implicitly or explicitly rolled back transactions and clears the listener list. //------------------------------------------------------------------------------------ void CSQLAccess::RunListeners_Rollback() { RunAndClearListenerList( m_vecRollbackListeners ); // Clear the unused set m_vecCommitListeners.clear(); } //------------------------------------------------------------------------------------ // Purpose: Perform a query that returns a single string //------------------------------------------------------------------------------------ CSQLAccess::EReadSingleResultResult CSQLAccess::BYieldingExecuteSingleResultDataInternal( const char *pchName, const char *pchSQLCommand, EGCSQLType eType, uint8 **ppubData, uint32 *punSize, uint32 *pcRowsAffected, bool bHasDefaultValue ) { AssertMsg( !BInTransaction(), "BYieldingExecuteSingleResultData is not supported in a transaction" ); if( BInTransaction() ) return eReadSingle_Error; bool bRet = BYieldingExecute( pchName, pchSQLCommand, pcRowsAffected ); if ( !bRet ) return eReadSingle_Error; if( m_pQueryGroup->GetResults()->GetResultSetCount() != 1 ) { AssertMsg1( false, "Expected single result set, found %d", m_pQueryGroup->GetResults()->GetResultSetCount() ); return eReadSingle_Error; } IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 ); // If we have a default value, getting back zero rows is acceptable. if( pResultSet->GetRowCount() == 0 && bHasDefaultValue ) { return eReadSingle_UseDefault; } // If we either have more than one row or no default value specified, that's an error. if( pResultSet->GetRowCount() != 1 ) { AssertMsg1( false, "Expected single result, found %d", pResultSet->GetRowCount() ); return eReadSingle_Error; } if( pResultSet->GetColumnCount() != 1 ) { AssertMsg1( false, "Expected single column, found %d", pResultSet->GetColumnCount() ); return eReadSingle_Error; } if( pResultSet->GetColumnType( 0 ) != eType ) { AssertMsg2( false, "Expected column of type %s, found %s", PchNameFromEGCSQLType( eType ), PchNameFromEGCSQLType( pResultSet->GetColumnType( 0 ) ) ); return eReadSingle_Error; } return pResultSet->GetData( 0, 0, ppubData, punSize ) ? eReadSingle_ResultFound : eReadSingle_Error; } //------------------------------------------------------------------------------------ // Purpose: Perform a query that returns a single string //------------------------------------------------------------------------------------ bool CSQLAccess::BYieldingExecuteString( const char *pchName, const char *pchSQLCommand, CFmtStr1024 *psResult, uint32 *pcRowsAffected ) { uint8 *pubData; uint32 cubData; if( CSQLAccess::BYieldingExecuteSingleResultDataInternal( pchName, pchSQLCommand, k_EGCSQLType_String, &pubData, &cubData, pcRowsAffected, false ) != eReadSingle_ResultFound ) return false; *psResult = (char *)pubData; return true; } //------------------------------------------------------------------------------------ // Purpose: Perform a query that returns a single int //------------------------------------------------------------------------------------ bool CSQLAccess::BYieldingExecuteScalarInt( const char *pchName, const char *pchSQLCommand, int *pnResult, uint32 *pcRowsAffected ) { return BYieldingExecuteSingleResult( pchName, pchSQLCommand, k_EGCSQLType_int32, pnResult, pcRowsAffected ); } bool CSQLAccess::BYieldingExecuteScalarIntWithDefault( const char *pchName, const char *pchSQLCommand, int *pnResult, int iDefaultValue, uint32 *pcRowsAffected ) { return BYieldingExecuteSingleResultWithDefault( pchName, pchSQLCommand, k_EGCSQLType_int32, pnResult, iDefaultValue, pcRowsAffected ); } //------------------------------------------------------------------------------------ // Purpose: Perform a query that returns a single uint32 //------------------------------------------------------------------------------------ bool CSQLAccess::BYieldingExecuteScalarUint32( const char *pchName, const char *pchSQLCommand, uint32 *punResult, uint32 *pcRowsAffected ) { return BYieldingExecuteSingleResult( pchName, pchSQLCommand, k_EGCSQLType_int32, punResult, pcRowsAffected ); } bool CSQLAccess::BYieldingExecuteScalarUint32WithDefault( const char *pchName, const char *pchSQLCommand, uint32 *punResult, uint32 unDefaultValue, uint32 *pcRowsAffected ) { return BYieldingExecuteSingleResultWithDefault( pchName, pchSQLCommand, k_EGCSQLType_int32, punResult, unDefaultValue, pcRowsAffected ); } //------------------------------------------------------------------------------------ // Purpose: A bunch of pass throughs to the query itself //------------------------------------------------------------------------------------ void CSQLAccess::AddBindParam( const char *pchValue ) { CurrentQuery()->AddBindParam( pchValue ); } void CSQLAccess::AddBindParam( const int16 nValue ) { CurrentQuery()->AddBindParam( nValue ); } void CSQLAccess::AddBindParam( const uint16 uValue ) { CurrentQuery()->AddBindParam( uValue ); } void CSQLAccess::AddBindParam( const int32 nValue ) { CurrentQuery()->AddBindParam( nValue ); } void CSQLAccess::AddBindParam( const uint32 uValue ) { CurrentQuery()->AddBindParam( uValue ); } void CSQLAccess::AddBindParam( const uint64 ulValue ) { CurrentQuery()->AddBindParam( ulValue ); } void CSQLAccess::AddBindParam( const uint8 *ubValue, const int cubValue ) { CurrentQuery()->AddBindParam( ubValue, cubValue ); } void CSQLAccess::AddBindParam( const float fValue ) { CurrentQuery()->AddBindParam( fValue ); } void CSQLAccess::AddBindParam( const double dValue ) { CurrentQuery()->AddBindParam( dValue ); } void CSQLAccess::AddBindParamRaw( EGCSQLType eType, const byte *pubData, uint32 cubData ) { CurrentQuery()->AddBindParamRaw( eType, pubData, cubData ); } void CSQLAccess::ClearParams() { if( m_pCurrentQuery ) { delete m_pCurrentQuery; m_pCurrentQuery = NULL; } } IGCSQLResultSetList *CSQLAccess::GetResults() { return m_pQueryGroup->GetResults(); } //------------------------------------------------------------------------------------ // Purpose: Returns the number of result sets //------------------------------------------------------------------------------------ uint32 CSQLAccess::GetResultSetCount() { if( m_pQueryGroup->GetResults() ) return m_pQueryGroup->GetResults()->GetResultSetCount(); else return 0; } //------------------------------------------------------------------------------------ // Purpose: Returns the number of rows in a result set //------------------------------------------------------------------------------------ uint32 CSQLAccess::GetResultSetRowCount( uint32 unResultSet ) { if( m_pQueryGroup->GetResults() && unResultSet < m_pQueryGroup->GetResults()->GetResultSetCount() ) return m_pQueryGroup->GetResults()->GetResultSet( unResultSet )->GetRowCount(); else return 0; } //------------------------------------------------------------------------------------ // Purpose: Returns a CSQLRecord object that represents a row in a result set //------------------------------------------------------------------------------------ CSQLRecord CSQLAccess::GetResultRecord( uint32 unResultSet, uint32 unRow ) { if( m_pQueryGroup->GetResults() && unResultSet < m_pQueryGroup->GetResults()->GetResultSetCount() ) { IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( unResultSet ); if( unRow < pResultSet->GetRowCount() ) return CSQLRecord( unRow, pResultSet ); } return CSQLRecord(); // if there was a problem return an empty record } //----------------------------------------------------------------------------- // Purpose: Inserts a new record into the DS // Input: pRecordBase - record to insert // Output: true if successful, false otherwise //----------------------------------------------------------------------------- bool CSQLAccess::BYieldingInsertRecord( const CRecordBase *pRecordBase ) { ClearParams(); const CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo(); int cColumns = pRecordInfo->GetNumColumns(); for ( int nColumn = 0; nColumn < cColumns; nColumn++ ) { const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn ); if ( !columnInfo.BIsInsertable() ) continue; uint8 *pubData; uint32 cubData; DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) ); CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData ); } uint32 nRows; const char *pchStatement = pRecordBase->GetPSchema()->GetInsertStatementText(); bool bRet = BYieldingExecute( pchStatement, pchStatement, &nRows ); return ( nRows == 1 || BInTransaction() ) && bRet; } //----------------------------------------------------------------------------- // Purpose: Inserts a new record into the DS if such row doesn't exist // Input: pRecordBase - record to insert // Output: true if successful, false otherwise //----------------------------------------------------------------------------- bool CSQLAccess::BYieldingInsertWhenNotMatchedOnPK( CRecordBase *pRecordBase ) { ClearParams(); const CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo(); int cColumns = pRecordInfo->GetNumColumns(); for ( int nColumn = 0; nColumn < cColumns; nColumn++ ) { const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn ); if ( !columnInfo.BIsInsertable() ) { Assert( columnInfo.BIsInsertable() ); return false; } uint8 *pubData; uint32 cubData; DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) ); CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData ); } uint32 nRows; const char *pchStatement = pRecordBase->GetPSchema()->GetMergeStatementTextOnPKWhenNotMatchedInsert(); bool bRet = BYieldingExecute( pchStatement, pchStatement, &nRows ); return ( nRows == 1 || nRows == 0 || BInTransaction() ) && bRet; } //----------------------------------------------------------------------------- // Purpose: Inserts a new record into the DS if such row doesn't exist // updates an existing row if such row is matched by PK // Input: pRecordBase - record to insert // Output: true if successful, false otherwise //----------------------------------------------------------------------------- bool CSQLAccess::BYieldingInsertOrUpdateOnPK( CRecordBase *pRecordBase ) { ClearParams(); const CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo(); int cColumns = pRecordInfo->GetNumColumns(); for ( int nColumn = 0; nColumn < cColumns; nColumn++ ) { const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn ); if ( !columnInfo.BIsInsertable() ) { Assert( columnInfo.BIsInsertable() ); return false; } uint8 *pubData; uint32 cubData; DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) ); CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData ); } uint32 nRows; const char *pchStatement = pRecordBase->GetPSchema()->GetMergeStatementTextOnPKWhenMatchedUpdateWhenNotMatchedInsert(); bool bRet = BYieldingExecute( pchStatement, pchStatement, &nRows ); return ( nRows == 1 || BInTransaction() ) && bRet; } //----------------------------------------------------------------------------- // Purpose: Inserts a new record into the DB and reads non-insertable fields back // into the record. // Input: pRecordBase - record to insert // Output: true if successful, false otherwise //----------------------------------------------------------------------------- bool CSQLAccess::BYieldingInsertWithIdentity( CRecordBase* pRecordBase ) { AssertMsg( !BInTransaction(), "BYieldingInsertWithIdentity is not supported in a transaction" ); if( BInTransaction() ) return false; ClearParams(); TSQLCmdStr sStatement; CUtlVector vecOutputFields; CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo(); BuildInsertAndReadStatementText( &sStatement, &vecOutputFields, pRecordInfo ); AssertMsg( vecOutputFields.Count() > 0, "BYieldingInsertAndReadRecord called for a record type with no non-insertable columns" ); if ( vecOutputFields.Count() == 0 ) return false; int cColumns = pRecordInfo->GetNumColumns(); for ( int nColumn = 0; nColumn < cColumns; nColumn++ ) { const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn ); if ( !columnInfo.BIsInsertable() ) { continue; } uint8 *pubData; uint32 cubData; DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) ); CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData ); } bool bRet = BYieldingExecute( sStatement, sStatement ); if( !bRet ) return false; Assert( 1 == GetResultSetCount() ); if ( 1 != GetResultSetCount() ) return false; IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 ); Assert( 1 == pResultSet->GetRowCount() ); if ( 1 != pResultSet->GetRowCount() ) return false; Assert( (uint32)vecOutputFields.Count() == pResultSet->GetColumnCount() ); if ( (uint32)vecOutputFields.Count() != pResultSet->GetColumnCount() ) return false; for( uint32 nColumn = 0; nColumn < pResultSet->GetColumnCount(); nColumn++ ) { uint8 *pubData; uint32 cubData; DbgVerify( pResultSet->GetData( 0, nColumn, &pubData, &cubData ) ); int nSchColumn = vecOutputFields[nColumn]; Assert( pResultSet->GetColumnType( nColumn ) == pRecordInfo->GetColumnInfo( nSchColumn ).GetType() ); DbgVerify( pRecordBase->BSetField( nSchColumn, pubData, cubData ) ); } return true; } //----------------------------------------------------------------------------- // Purpose: Reads a list of records from the DB according to the specified where // clause // Input: pRecordBase - record to read // readSet - The set of columns to read // whereSet - The set of columns to query on // Output: true if successful, false otherwise //----------------------------------------------------------------------------- EResult CSQLAccess::YieldingReadRecordWithWhereColumns( CRecordBase *pRecord, const CColumnSet & readSet, const CColumnSet & whereSet, const char* pchOrderClause ) { AssertMsg( !BInTransaction(), "BYieldingReadRecordWithWhereColumns is not supported in a transaction" ); if( BInTransaction() ) return k_EResultInvalidState; //if there is an order by clause, only take the top one, if there isn't, then validate that we have a single instance const char* pszTopClause = ( pchOrderClause ) ? "TOP (1)" : "TOP (2)"; TSQLCmdStr sStatement; BuildSelectStatementText( &sStatement, readSet, pszTopClause ); // if we actually have some columns for the where clause, // append a where clause. if( whereSet.GetColumnCount() ) { sStatement.Append( " WHERE " ); AppendWhereClauseText( &sStatement, whereSet ); AddRecordParameters( *pRecord, whereSet ); } //append the order by if they added one if( pchOrderClause ) { sStatement.Append( " ORDER BY " ); sStatement.Append( pchOrderClause ); } Assert(!readSet.IsEmpty() ); if( !BYieldingExecute( sStatement, sStatement ) ) return k_EResultFail; if ( GetResultSetCount() != 1 ) { AssertMsg( GetResultSetCount() == 1, "Unexpected number of result sets returned from select statement" ); return k_EResultFail; } // make sure the types are the same IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 ); if ( pResultSet->GetRowCount() == 0 ) return k_EResultNoMatch; //note that since we only take the top one when there is an order by clause, we don't need to handle that case down here, only if top 2 is selected if( pResultSet->GetRowCount() != 1 ) { // Make sure we aren't failing because there are multiple matching records. // That is probably a misuse of the API or some unexpected condition. AssertMsg1( false, "BYieldingReadRecordWithWhereColumns from %s failing because multiple records match WHERE clause", readSet.GetRecordInfo()->GetName() ); return k_EResultLimitExceeded; } FOR_EACH_COLUMN_IN_SET( readSet, nColumnIndex ) { EGCSQLType eRecordType = readSet.GetColumnInfo( nColumnIndex ).GetType(); EGCSQLType eResultType = pResultSet->GetColumnType( nColumnIndex ); AssertMsg2( eResultType == eRecordType, "Column %d type mismatch in %s", nColumnIndex, readSet.GetRecordInfo()->GetName() ); if( eRecordType != eResultType ) return k_EResultInvalidParam; } CSQLRecord sqlRecord = GetResultRecord( 0, 0 ); FOR_EACH_COLUMN_IN_SET( readSet, nColumnIndex ) { uint8 *pubData; uint32 cubData; DbgVerify( sqlRecord.BGetColumnData( nColumnIndex, &pubData, (int*)&cubData ) ); DbgVerify( pRecord->BSetField( readSet.GetColumn( nColumnIndex), pubData, cubData ) ); } return k_EResultOK; } //----------------------------------------------------------------------------- // Purpose: Updates a record in the DB // Input: record - data source for columns to match against (whereColumns) and // columns to assign (updateColumns) // whereColumns - columns to match against // updateColumns - columns to update // Output: true if successful, false otherwise //----------------------------------------------------------------------------- bool CSQLAccess::BYieldingUpdateRecord( const CRecordBase & record, const CColumnSet & whereColumns, const CColumnSet & updateColumns, const CSQLOutputParams *pOptionalOutputParams /* = NULL */ ) { return BYieldingUpdateRecords( record, whereColumns, record, updateColumns, pOptionalOutputParams ); } //----------------------------------------------------------------------------- // Purpose: //----------------------------------------------------------------------------- bool CSQLAccess::BYieldingUpdateRecords( const CRecordBase & whereRecord, const CColumnSet & whereColumns, const CRecordBase & updateRecord, const CColumnSet & updateColumns, const CSQLOutputParams *pOptionalOutputParams /* = NULL */ ) { ClearParams(); Assert( whereColumns.GetRecordInfo() == updateColumns.GetRecordInfo() ); if ( whereColumns.GetRecordInfo() != updateColumns.GetRecordInfo() ) return false; Assert( whereColumns.GetRecordInfo() == whereRecord.GetPSchema()->GetRecordInfo() ); if ( whereColumns.GetRecordInfo() != whereRecord.GetPSchema()->GetRecordInfo() ) return false; Assert( whereColumns.GetRecordInfo() == updateRecord.GetPSchema()->GetRecordInfo() ); if ( whereColumns.GetRecordInfo() != updateRecord.GetPSchema()->GetRecordInfo() ) return false; AssertMsg( !updateColumns.IsEmpty(), "Someone is calling BYieldingUpdateRecord with no columns to update." ); if ( updateColumns.IsEmpty() ) return false; // add the columns we're updating as bound params TSQLCmdStr sStatement; BuildUpdateStatementText( &sStatement, updateColumns ); AddRecordParameters( updateRecord, updateColumns ); // did the users specify an OUTPUT block? if ( pOptionalOutputParams ) { TSQLCmdStr sOutput; BuildOutputClauseText( &sOutput, pOptionalOutputParams->GetColumnSet() ); sStatement.Append( sOutput ); AddRecordParameters( pOptionalOutputParams->GetRecord(), pOptionalOutputParams->GetColumnSet() ); } if ( !whereColumns.IsEmpty() ) { sStatement.Append( " WHERE " ); AppendWhereClauseText( &sStatement, whereColumns ); // add the columns we're querying on as bound params AddRecordParameters( whereRecord, whereColumns ); } return BYieldingExecute( sStatement, sStatement ); } //----------------------------------------------------------------------------- // Purpose: Deletes this record's row in the table // Input: record - record to delete // whereColumns - columns to use when searching for this record //----------------------------------------------------------------------------- bool CSQLAccess::BYieldingDeleteRecords( const CRecordBase & record, const CColumnSet & whereColumns ) { Assert( whereColumns.GetRecordInfo() == record.GetPSchema()->GetRecordInfo() ); if ( whereColumns.GetRecordInfo() != record.GetPSchema()->GetRecordInfo() ) return false; ClearParams(); AddRecordParameters( record, whereColumns ); TSQLCmdStr sStatement; BuildDeleteStatementText( &sStatement, record.GetPRecordInfo() ); sStatement.Append( " WHERE " ); AppendWhereClauseText( &sStatement, whereColumns ); uint32 unRowsAffected; if( !BYieldingExecute( sStatement, sStatement, &unRowsAffected ) ) return false; return unRowsAffected > 0 || BInTransaction(); } //-------------------------------------------------------------------------------------------------------------------------------- // CSQLUpdateOrInsert //-------------------------------------------------------------------------------------------------------------------------------- CSQLUpdateOrInsert::CSQLUpdateOrInsert( const char* pszName, int nTable, const CColumnSet & whereColumns, const CColumnSet & updateColumns, const char* pszWhereClause, const char* pszUpdateClause ) { const CRecordInfo* pRecordInfo = GSchemaFull().GetSchema( nTable ).GetRecordInfo(); //how many columns do we have const int nNumColumns = pRecordInfo->GetNumColumns(); TSQLCmdStr sStatement; sStatement = "MERGE INTO "; sStatement.Append( GSchemaFull().GetDefaultSchemaNameForCatalog( pRecordInfo->GetESchemaCatalog() ) ); sStatement.Append( '.' ); sStatement.Append( pRecordInfo->GetName() ); sStatement.Append( " WITH(HOLDLOCK) AS D USING(VALUES(" ); sStatement.AppendFormat( "%.*s", GetInsertArgStringChars( nNumColumns ), GetInsertArgString() ); sStatement.Append( "))AS S(" ); //add each column that we are adding the values for, along with the parameter from the structure for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ ) { const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn ); if( nCurrColumn != 0 ) sStatement.Append( ',' ); sStatement.Append( colInfo.GetName() ); } //our where clause sStatement.Append( ")ON " ); if( pszWhereClause ) { sStatement.Append( pszWhereClause ); } else { FOR_EACH_COLUMN_IN_SET( whereColumns, nCurrColumn ) { const char* pszColName = pRecordInfo->GetColumnInfo( whereColumns.GetColumn( nCurrColumn ) ).GetName(); if( nCurrColumn > 0 ) sStatement.Append( " AND " ); sStatement.AppendFormat( "D.%s=S.%s", pszColName, pszColName ); } } //our update clause (if they have provided fields that they want to update) if( pszUpdateClause || !updateColumns.IsEmpty() ) { sStatement.Append( " WHEN MATCHED THEN UPDATE SET " ); if( pszUpdateClause ) { sStatement.Append( pszUpdateClause ); } else { FOR_EACH_COLUMN_IN_SET( updateColumns, nCurrColumn ) { const char* pszColName = pRecordInfo->GetColumnInfo( updateColumns.GetColumn( nCurrColumn ) ).GetName(); if( nCurrColumn > 0 ) sStatement.Append( ',' ); sStatement.AppendFormat( "%s=S.%s", pszColName, pszColName ); } } } //our insert clause sStatement.Append( " WHEN NOT MATCHED THEN INSERT(" ); bool bFirstColumn = true; for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ ) { const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn ); if( !colInfo.BIsInsertable() ) continue; if( !bFirstColumn ) sStatement.Append( ',' ); bFirstColumn = false; sStatement.Append( colInfo.GetName() ); } sStatement.Append( ")VALUES(" ); bFirstColumn = true; for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ ) { const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn ); if( !colInfo.BIsInsertable() ) continue; if( !bFirstColumn ) sStatement.Append( ',' ); bFirstColumn = false; sStatement.AppendFormat( "S.%s", colInfo.GetName() ); } sStatement.Append( ");" ); //save our results so we can execute it in the future m_nTable = nTable; m_sName = pszName; m_sQuery = sStatement; } bool CSQLUpdateOrInsert::BYieldingExecute( CSQLAccess& sqlAccess, const CRecordBase& record, uint32 *out_punRowsAffected /* = NULL */ ) const { AssertMsg2( record.GetITable() == m_nTable, "Error: Merge was compiled for table %s, but was attempted to be executed against %s", GSchemaFull().GetSchema( m_nTable ).GetRecordInfo()->GetName(), record.GetPRecordInfo()->GetName() ); const CRecordInfo* pRecordInfo = record.GetPRecordInfo(); //how many columns do we have const int nNumColumns = pRecordInfo->GetNumColumns(); sqlAccess.ClearParams(); for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ ) { const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn ); uint8 *pubData; uint32 cubData; DbgVerify( record.BGetField( nCurrColumn, &pubData, &cubData ) ); sqlAccess.AddBindParamRaw( colInfo.GetType(), pubData, cubData ); } return sqlAccess.BYieldingExecute( m_sName, m_sQuery, out_punRowsAffected ); } //----------------------------------------------------------------------------- // Purpose: Adds bind parameters to the list based on a set of fields in a record // Input: record - record to insert // columnSet - The set of columns to add as params //----------------------------------------------------------------------------- void CSQLAccess::AddRecordParameters( const CRecordBase &record, const CColumnSet & columnSet ) { Assert( record.GetPSchema()->GetRecordInfo() == columnSet.GetRecordInfo() ); if ( record.GetPSchema()->GetRecordInfo() != columnSet.GetRecordInfo() ) return; FOR_EACH_COLUMN_IN_SET( columnSet, nColumnIndex ) { const CColumnInfo &columnInfo = columnSet.GetColumnInfo( nColumnIndex ); uint8 *pubData; uint32 cubData; DbgVerify( record.BGetField( columnSet.GetColumn( nColumnIndex ), &pubData, &cubData ) ); EGCSQLType eType = columnInfo.GetType(); CurrentQuery()->AddBindParamRaw( eType, pubData, cubData ); } } //----------------------------------------------------------------------------- // Purpose: Deletes all records from a table // Input: iTable - table to wipe // Output: true if the operation was successful // Note: PERFORMANCE WARNING: this is slow on big tables, not intended for use // in production //----------------------------------------------------------------------------- bool CSQLAccess::BYieldingWipeTable( int iTable ) { // make a wipe operation CRecordInfo *pRecordInfo = GSchemaFull().GetSchema( iTable ).GetRecordInfo(); CUtlString buf; buf.Format( "DELETE FROM %s", pRecordInfo->GetName() ); return BYieldingExecute( buf.String(), buf.String() ); } //----------------------------------------------------------------------------- // Purpose: Returns the current query to add stuff to, creating it if there isn't // already a current query //----------------------------------------------------------------------------- CGCSQLQuery *CSQLAccess::CurrentQuery() { if( m_pCurrentQuery ) return m_pCurrentQuery; m_pCurrentQuery = new CGCSQLQuery(); return m_pCurrentQuery; } } // namespace GCSDK