using System.Data.Common;
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Http;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.Extensions.Logging;
using Npgsql;

namespace WorkClub.Infrastructure.Data.Interceptors;

/// <summary>
/// Sets the PostgreSQL RLS tenant context (<c>app.current_tenant_id</c>) with transaction-local
/// scope, i.e. the equivalent of <c>SET LOCAL</c>.
/// For auto-commit commands (no transaction): wraps the command in an owned transaction, applies the
/// tenant setting, and commits after a scalar/non-query completes or when a reader is disposed.
/// For EF-started transactions (e.g. SaveChanges): applies the tenant setting once when the
/// transaction starts.
/// If the tenant setting cannot be applied, the exception propagates (fail closed) so no command
/// runs without tenant context.
/// </summary>
/// <remarks>
/// NOTE(review): transactions supplied to EF via <c>UseTransaction</c> raise <c>TransactionUsed</c>,
/// which is not intercepted here, and commands enlisted in them skip the auto-commit path — confirm
/// callers set the tenant on such transactions themselves.
/// </remarks>
public class TenantDbTransactionInterceptor : DbCommandInterceptor, IDbTransactionInterceptor
{
    // Transaction-local (is_local = true) equivalent of SET LOCAL; parameterized so the tenant id
    // never becomes part of the SQL text.
    private const string SetTenantSql = "SELECT set_config('app.current_tenant_id', @tenant_id, true)";

    private readonly IHttpContextAccessor _httpContextAccessor;
    private readonly ILogger<TenantDbTransactionInterceptor> _logger;

    // Transactions this interceptor created (and therefore must commit/roll back and dispose).
    // Ownership sits on the command until a reader exists, then moves to the reader.
    private readonly ConditionalWeakTable<DbCommand, DbTransaction> _ownedTxByCommand = new();
    private readonly ConditionalWeakTable<DbDataReader, DbTransaction> _ownedTxByReader = new();

    public TenantDbTransactionInterceptor(
        IHttpContextAccessor httpContextAccessor,
        ILogger<TenantDbTransactionInterceptor> logger)
    {
        _httpContextAccessor = httpContextAccessor;
        _logger = logger;
    }

    // === READER COMMANDS (SELECT queries) ===

    public override InterceptionResult<DbDataReader> ReaderExecuting(
        DbCommand command, CommandEventData eventData, InterceptionResult<DbDataReader> result)
    {
        EnsureTransactionAndTenant(command);
        return base.ReaderExecuting(command, eventData, result);
    }

    public override async ValueTask<InterceptionResult<DbDataReader>> ReaderExecutingAsync(
        DbCommand command, CommandEventData eventData, InterceptionResult<DbDataReader> result,
        CancellationToken cancellationToken = default)
    {
        await EnsureTransactionAndTenantAsync(command, cancellationToken);
        return await base.ReaderExecutingAsync(command, eventData, result, cancellationToken);
    }

    // After the reader executes, the owned transaction must outlive the command: move it to the reader.
    public override DbDataReader ReaderExecuted(
        DbCommand command, CommandExecutedEventData eventData, DbDataReader result)
    {
        TransferOwnershipToReader(command, result);
        return base.ReaderExecuted(command, eventData, result);
    }

    public override ValueTask<DbDataReader> ReaderExecutedAsync(
        DbCommand command, CommandExecutedEventData eventData, DbDataReader result,
        CancellationToken cancellationToken = default)
    {
        TransferOwnershipToReader(command, result);
        return base.ReaderExecutedAsync(command, eventData, result, cancellationToken);
    }

    // When the reader is disposed, commit and dispose the owned transaction.
    public override InterceptionResult DataReaderDisposing(
        DbCommand command, DataReaderDisposingEventData eventData, InterceptionResult result)
    {
        if (TakeOwnedTransaction(_ownedTxByReader, eventData.DataReader) is { } tx)
        {
            try
            {
                tx.Commit();
                _logger.LogDebug("Committed owned transaction for reader disposal");
            }
            catch (Exception ex)
            {
                // Deliberately best-effort: this runs on the reader's dispose path.
                // NOTE(review): a reader command may also be a write (e.g. INSERT ... RETURNING issued
                // outside an EF transaction); a failed commit here is only logged — confirm acceptable.
                _logger.LogWarning(ex, "Failed to commit owned transaction on reader disposal");
                TryRollback(tx);
            }
            finally
            {
                tx.Dispose();
            }
        }

        return base.DataReaderDisposing(command, eventData, result);
    }

    // === SCALAR COMMANDS ===

    public override InterceptionResult<object> ScalarExecuting(
        DbCommand command, CommandEventData eventData, InterceptionResult<object> result)
    {
        EnsureTransactionAndTenant(command);
        return base.ScalarExecuting(command, eventData, result);
    }

    public override async ValueTask<InterceptionResult<object>> ScalarExecutingAsync(
        DbCommand command, CommandEventData eventData, InterceptionResult<object> result,
        CancellationToken cancellationToken = default)
    {
        await EnsureTransactionAndTenantAsync(command, cancellationToken);
        return await base.ScalarExecutingAsync(command, eventData, result, cancellationToken);
    }

    // Commit the owned transaction immediately after scalar execution.
    public override object? ScalarExecuted(
        DbCommand command, CommandExecutedEventData eventData, object? result)
    {
        CommitOwnedTransaction(command);
        return base.ScalarExecuted(command, eventData, result);
    }

    public override async ValueTask<object?> ScalarExecutedAsync(
        DbCommand command, CommandExecutedEventData eventData, object? result,
        CancellationToken cancellationToken = default)
    {
        await CommitOwnedTransactionAsync(command, cancellationToken);
        return await base.ScalarExecutedAsync(command, eventData, result, cancellationToken);
    }

    // === NON-QUERY COMMANDS ===

    public override InterceptionResult<int> NonQueryExecuting(
        DbCommand command, CommandEventData eventData, InterceptionResult<int> result)
    {
        EnsureTransactionAndTenant(command);
        return base.NonQueryExecuting(command, eventData, result);
    }

    public override async ValueTask<InterceptionResult<int>> NonQueryExecutingAsync(
        DbCommand command, CommandEventData eventData, InterceptionResult<int> result,
        CancellationToken cancellationToken = default)
    {
        await EnsureTransactionAndTenantAsync(command, cancellationToken);
        return await base.NonQueryExecutingAsync(command, eventData, result, cancellationToken);
    }

    public override int NonQueryExecuted(
        DbCommand command, CommandExecutedEventData eventData, int result)
    {
        CommitOwnedTransaction(command);
        return base.NonQueryExecuted(command, eventData, result);
    }

    public override async ValueTask<int> NonQueryExecutedAsync(
        DbCommand command, CommandExecutedEventData eventData, int result,
        CancellationToken cancellationToken = default)
    {
        await CommitOwnedTransactionAsync(command, cancellationToken);
        return await base.NonQueryExecutedAsync(command, eventData, result, cancellationToken);
    }

    // === ERROR / CANCELLATION HANDLING ===

    public override void CommandFailed(DbCommand command, CommandErrorEventData eventData)
    {
        RollbackOwnedTransaction(command);
        _logger.LogError(eventData.Exception, "Command failed: {Sql}", TruncateSql(command.CommandText));
        base.CommandFailed(command, eventData);
    }

    public override async Task CommandFailedAsync(
        DbCommand command, CommandErrorEventData eventData, CancellationToken cancellationToken = default)
    {
        await RollbackOwnedTransactionAsync(command);
        _logger.LogError(eventData.Exception, "Command failed: {Sql}", TruncateSql(command.CommandText));
        await base.CommandFailedAsync(command, eventData, cancellationToken);
    }

    // Cancellation is reported separately from failure; without this the owned transaction would
    // stay open on the connection.
    public override void CommandCanceled(DbCommand command, CommandEndEventData eventData)
    {
        RollbackOwnedTransaction(command);
        base.CommandCanceled(command, eventData);
    }

    public override async Task CommandCanceledAsync(
        DbCommand command, CommandEndEventData eventData, CancellationToken cancellationToken = default)
    {
        await RollbackOwnedTransactionAsync(command);
        await base.CommandCanceledAsync(command, eventData, cancellationToken);
    }

    // === PRIVATE HELPERS ===

    /// <summary>
    /// Reads <c>HttpContext.Items["TenantId"]</c> and returns it only if it is a non-blank string
    /// that parses as a <see cref="Guid"/>; otherwise returns null (logging a warning for a
    /// non-GUID value).
    /// </summary>
    private string? GetValidatedTenantId()
    {
        var tenantId = _httpContextAccessor.HttpContext?.Items["TenantId"] as string;
        if (string.IsNullOrWhiteSpace(tenantId))
        {
            return null;
        }

        if (!Guid.TryParse(tenantId, out _))
        {
            _logger.LogWarning("Invalid tenant ID format: {TenantId}", tenantId);
            return null;
        }

        return tenantId;
    }

    /// <summary>
    /// Returns the Npgsql connection and tenant id when <paramref name="command"/> is an auto-commit
    /// command that needs an owned transaction; otherwise null.
    /// </summary>
    private (NpgsqlConnection Connection, string TenantId)? GetAutoCommitTarget(DbCommand command)
    {
        // A command already enlisted in a transaction is assumed covered by TransactionStarted.
        if (command.Transaction is not null)
        {
            return null;
        }

        var tenantId = GetValidatedTenantId();
        if (tenantId is null)
        {
            return null;
        }

        if (command.Connection is not NpgsqlConnection connection)
        {
            return null;
        }

        return (connection, tenantId);
    }

    /// <summary>Wraps an auto-commit command in an owned transaction carrying the tenant setting.</summary>
    private void EnsureTransactionAndTenant(DbCommand command)
    {
        if (GetAutoCommitTarget(command) is not { } target)
        {
            return;
        }

        var tx = target.Connection.BeginTransaction();
        AdoptTransaction(command, tx);
        try
        {
            ApplyTenant(target.Connection, tx, target.TenantId);
        }
        catch
        {
            AbandonOwnedTransaction(command, tx);
            throw;
        }
    }

    /// <summary>Async counterpart of <see cref="EnsureTransactionAndTenant"/>.</summary>
    private async ValueTask EnsureTransactionAndTenantAsync(DbCommand command, CancellationToken cancellationToken)
    {
        if (GetAutoCommitTarget(command) is not { } target)
        {
            return;
        }

        var tx = await target.Connection.BeginTransactionAsync(cancellationToken);
        AdoptTransaction(command, tx);
        try
        {
            await ApplyTenantAsync(target.Connection, tx, target.TenantId, cancellationToken);
        }
        catch
        {
            _ownedTxByCommand.Remove(command);
            command.Transaction = null;
            await TryRollbackAsync(tx);
            await tx.DisposeAsync();
            throw;
        }
    }

    /// <summary>Enlists <paramref name="command"/> in <paramref name="tx"/> and records it as owned.</summary>
    private void AdoptTransaction(DbCommand command, DbTransaction tx)
    {
        command.Transaction = tx;
        _ownedTxByCommand.AddOrUpdate(command, tx);
        _logger.LogDebug("Created owned transaction for auto-commit command");
    }

    /// <summary>Undoes <see cref="AdoptTransaction"/> after the tenant setting could not be applied.</summary>
    private void AbandonOwnedTransaction(DbCommand command, DbTransaction tx)
    {
        _ownedTxByCommand.Remove(command);
        command.Transaction = null;
        TryRollback(tx);
        tx.Dispose();
    }

    // Commands created directly on the NpgsqlConnection do not go through EF Core's interception
    // pipeline, so applying the tenant setting cannot re-enter this interceptor.
    private static NpgsqlCommand CreateSetTenantCommand(NpgsqlConnection connection, DbTransaction tx, string tenantId)
    {
        var setCmd = connection.CreateCommand();
        setCmd.Transaction = tx as NpgsqlTransaction;
        setCmd.CommandText = SetTenantSql;
        setCmd.Parameters.AddWithValue("tenant_id", tenantId);
        return setCmd;
    }

    /// <summary>
    /// Applies the transaction-local tenant setting. Logs and rethrows on failure so the caller
    /// never proceeds without tenant context.
    /// </summary>
    private void ApplyTenant(NpgsqlConnection connection, DbTransaction tx, string tenantId)
    {
        try
        {
            using var setCmd = CreateSetTenantCommand(connection, tx, tenantId);
            setCmd.ExecuteNonQuery();
            _logger.LogDebug("Applied SET LOCAL for tenant {TenantId} on tx {TxHashCode}", tenantId, tx.GetHashCode());
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to apply SET LOCAL");
            throw;
        }
    }

    /// <summary>Async counterpart of <see cref="ApplyTenant"/>.</summary>
    private async ValueTask ApplyTenantAsync(
        NpgsqlConnection connection, DbTransaction tx, string tenantId, CancellationToken cancellationToken)
    {
        try
        {
            await using var setCmd = CreateSetTenantCommand(connection, tx, tenantId);
            await setCmd.ExecuteNonQueryAsync(cancellationToken);
            _logger.LogDebug("Applied SET LOCAL for tenant {TenantId} on tx {TxHashCode}", tenantId, tx.GetHashCode());
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to apply SET LOCAL");
            throw;
        }
    }

    /// <summary>
    /// Removes and returns the transaction owned by <paramref name="key"/>, or null if none.
    /// Only the caller whose Remove succeeds receives the transaction.
    /// </summary>
    private static DbTransaction? TakeOwnedTransaction<TKey>(ConditionalWeakTable<TKey, DbTransaction> table, TKey key)
        where TKey : class
        => table.TryGetValue(key, out var tx) && table.Remove(key) ? tx : null;

    private void TransferOwnershipToReader(DbCommand command, DbDataReader reader)
    {
        if (TakeOwnedTransaction(_ownedTxByCommand, command) is { } tx)
        {
            _ownedTxByReader.AddOrUpdate(reader, tx);
        }
    }

    /// <summary>
    /// Commits the command's owned transaction, if any. A failed commit is logged, rolled back and
    /// rethrown: the command itself already succeeded, so the caller must learn its effects were lost.
    /// </summary>
    private void CommitOwnedTransaction(DbCommand command)
    {
        if (TakeOwnedTransaction(_ownedTxByCommand, command) is not { } tx)
        {
            return;
        }

        try
        {
            tx.Commit();
            _logger.LogDebug("Committed owned transaction for scalar/nonquery");
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to commit owned transaction for scalar/nonquery command");
            TryRollback(tx);
            throw;
        }
        finally
        {
            tx.Dispose();
        }
    }

    /// <summary>Async counterpart of <see cref="CommitOwnedTransaction"/>.</summary>
    private async ValueTask CommitOwnedTransactionAsync(DbCommand command, CancellationToken cancellationToken)
    {
        if (TakeOwnedTransaction(_ownedTxByCommand, command) is not { } tx)
        {
            return;
        }

        try
        {
            await tx.CommitAsync(cancellationToken);
            _logger.LogDebug("Committed owned transaction for scalar/nonquery");
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to commit owned transaction for scalar/nonquery command");
            await TryRollbackAsync(tx);
            throw;
        }
        finally
        {
            await tx.DisposeAsync();
        }
    }

    /// <summary>Best-effort rollback and disposal of the command's owned transaction, if any.</summary>
    private void RollbackOwnedTransaction(DbCommand command)
    {
        if (TakeOwnedTransaction(_ownedTxByCommand, command) is not { } tx)
        {
            return;
        }

        try
        {
            tx.Rollback();
            _logger.LogDebug("Rolled back owned transaction on failure");
        }
        catch (Exception ex)
        {
            // Best-effort: the command's own error is what the caller needs to see.
            _logger.LogDebug(ex, "Rollback of owned transaction failed");
        }
        finally
        {
            tx.Dispose();
        }
    }

    /// <summary>Async counterpart of <see cref="RollbackOwnedTransaction"/>.</summary>
    private async ValueTask RollbackOwnedTransactionAsync(DbCommand command)
    {
        if (TakeOwnedTransaction(_ownedTxByCommand, command) is not { } tx)
        {
            return;
        }

        try
        {
            // Not cancellable: the request token may already be cancelled and cleanup must still run.
            await tx.RollbackAsync(CancellationToken.None);
            _logger.LogDebug("Rolled back owned transaction on failure");
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Rollback of owned transaction failed");
        }
        finally
        {
            await tx.DisposeAsync();
        }
    }

    private void TryRollback(DbTransaction tx)
    {
        try
        {
            tx.Rollback();
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Best-effort rollback of owned transaction failed");
        }
    }

    private async ValueTask TryRollbackAsync(DbTransaction tx)
    {
        try
        {
            await tx.RollbackAsync(CancellationToken.None);
        }
        catch (Exception ex)
        {
            _logger.LogDebug(ex, "Best-effort rollback of owned transaction failed");
        }
    }

    private static string TruncateSql(string? sql, int maxLength = 200)
        => sql is null ? string.Empty : sql.Length <= maxLength ? sql : sql[..maxLength];

    // === TRANSACTION INTERCEPTOR (for EF-managed transactions like SaveChanges) ===

    #region IDbTransactionInterceptor implementation

    /// <summary>
    /// Applies the tenant setting to a transaction EF has just started. On failure the transaction
    /// is rolled back and disposed and the exception propagates (fail closed).
    /// </summary>
    public DbTransaction TransactionStarted(DbConnection connection, TransactionEndEventData eventData, DbTransaction result)
    {
        var tenantId = GetValidatedTenantId();
        if (tenantId is null || connection is not NpgsqlConnection npgsqlConnection)
        {
            return result;
        }

        try
        {
            ApplyTenant(npgsqlConnection, result, tenantId);
        }
        catch
        {
            TryRollback(result);
            result.Dispose();
            throw;
        }

        return result;
    }

    /// <summary>Async counterpart of <see cref="TransactionStarted"/>.</summary>
    public async ValueTask<DbTransaction> TransactionStartedAsync(
        DbConnection connection, TransactionEndEventData eventData, DbTransaction result,
        CancellationToken cancellationToken = default)
    {
        var tenantId = GetValidatedTenantId();
        if (tenantId is null || connection is not NpgsqlConnection npgsqlConnection)
        {
            return result;
        }

        try
        {
            await ApplyTenantAsync(npgsqlConnection, result, tenantId, cancellationToken);
        }
        catch
        {
            await TryRollbackAsync(result);
            await result.DisposeAsync();
            throw;
        }

        return result;
    }

    // All other IDbTransactionInterceptor members rely on the interface's default pass-through
    // implementations.

    #endregion
}