From c918f447b2434bd29dcb0edc9ccee22fc4c6ea98 Mon Sep 17 00:00:00 2001 From: WorkClub Automation Date: Thu, 5 Mar 2026 20:43:03 +0100 Subject: [PATCH] fix(backend): add TenantDbTransactionInterceptor for RLS with explicit transactions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Option D: wraps auto-commit reads in explicit transactions with SET LOCAL. Handles transaction lifecycle (create→SET LOCAL→execute→commit/dispose). Uses IDbTransactionInterceptor for EF-managed SaveChanges transactions. Critical fix for PostgreSQL RLS requiring transaction-scoped context. --- backend/WorkClub.Api/Program.cs | 5 +- .../TenantDbTransactionInterceptor.cs | 304 ++++++++++++++++++ 2 files changed, 307 insertions(+), 2 deletions(-) create mode 100644 backend/WorkClub.Infrastructure/Data/Interceptors/TenantDbTransactionInterceptor.cs diff --git a/backend/WorkClub.Api/Program.cs b/backend/WorkClub.Api/Program.cs index 643d04f..be8f050 100644 --- a/backend/WorkClub.Api/Program.cs +++ b/backend/WorkClub.Api/Program.cs @@ -27,7 +27,7 @@ builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddScoped(); -builder.Services.AddScoped(); +builder.Services.AddScoped(); builder.Services.AddSingleton(); builder.Services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme) @@ -36,6 +36,7 @@ builder.Services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme) options.Authority = builder.Configuration["Keycloak:Authority"]; options.Audience = builder.Configuration["Keycloak:Audience"]; options.RequireHttpsMetadata = false; + options.MapInboundClaims = false; options.TokenValidationParameters = new Microsoft.IdentityModel.Tokens.TokenValidationParameters { ValidateIssuer = false, // Disabled for local dev - external clients use localhost:8080, internal use keycloak:8080 @@ -56,7 +57,7 @@ builder.Services.AddAuthorizationBuilder() builder.Services.AddDbContext((sp, options) => options.UseNpgsql(builder.Configuration.GetConnectionString("DefaultConnection")) .AddInterceptors( - sp.GetRequiredService(), + sp.GetRequiredService(), sp.GetRequiredService())); var connectionString = builder.Configuration.GetConnectionString("DefaultConnection"); diff --git a/backend/WorkClub.Infrastructure/Data/Interceptors/TenantDbTransactionInterceptor.cs b/backend/WorkClub.Infrastructure/Data/Interceptors/TenantDbTransactionInterceptor.cs new file mode 100644 index 0000000..a55f037 --- /dev/null +++ b/backend/WorkClub.Infrastructure/Data/Interceptors/TenantDbTransactionInterceptor.cs @@ -0,0 +1,304 @@ +using System.Data.Common; +using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Http; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.Extensions.Logging; +using Npgsql; + +namespace WorkClub.Infrastructure.Data.Interceptors; + +/// +/// Sets PostgreSQL RLS tenant context using SET LOCAL in explicit transactions. +/// For auto-commit reads: wraps in explicit transaction, applies SET LOCAL, commits on reader dispose. +/// For transactional writes: applies SET LOCAL once when transaction starts. +/// +public class TenantDbTransactionInterceptor : DbCommandInterceptor, IDbTransactionInterceptor +{ + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly ILogger _logger; + + // Track transactions we created (so we know to commit/dispose them) + private readonly ConditionalWeakTable _ownedTxByCommand = new(); + private readonly ConditionalWeakTable _ownedTxByReader = new(); + + public TenantDbTransactionInterceptor( + IHttpContextAccessor httpContextAccessor, + ILogger logger) + { + _httpContextAccessor = httpContextAccessor; + _logger = logger; + } + + // === READER COMMANDS (SELECT queries) === + + public override InterceptionResult ReaderExecuting( + DbCommand command, CommandEventData eventData, InterceptionResult result) + { + EnsureTransactionAndTenant(command); + return base.ReaderExecuting(command, eventData, result); + } + + public override ValueTask> ReaderExecutingAsync( + DbCommand command, CommandEventData eventData, InterceptionResult result, + CancellationToken cancellationToken = default) + { + EnsureTransactionAndTenant(command); + return base.ReaderExecutingAsync(command, eventData, result, cancellationToken); + } + + // After reader executes, transfer tx ownership from command to reader + public override DbDataReader ReaderExecuted( + DbCommand command, CommandExecutedEventData eventData, DbDataReader result) + { + if (_ownedTxByCommand.TryGetValue(command, out var tx)) + { + _ownedTxByCommand.Remove(command); + _ownedTxByReader.AddOrUpdate(result, tx); + } + return base.ReaderExecuted(command, eventData, result); + } + + public override ValueTask ReaderExecutedAsync( + DbCommand command, CommandExecutedEventData eventData, DbDataReader result, + CancellationToken cancellationToken = default) + { + if (_ownedTxByCommand.TryGetValue(command, out var tx)) + { + _ownedTxByCommand.Remove(command); + _ownedTxByReader.AddOrUpdate(result, tx); + } + return base.ReaderExecutedAsync(command, eventData, result, cancellationToken); + } + + // When reader is disposed, commit and dispose the owned transaction + public override InterceptionResult DataReaderDisposing( + DbCommand command, DataReaderDisposingEventData eventData, InterceptionResult result) + { + if (_ownedTxByReader.TryGetValue(eventData.DataReader, out var tx)) + { + _ownedTxByReader.Remove(eventData.DataReader); + try + { + tx.Commit(); + _logger.LogDebug("Committed owned transaction for reader disposal"); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to commit owned transaction on reader disposal"); + try { tx.Rollback(); } catch { /* best-effort */ } + } + finally + { + tx.Dispose(); + } + } + return base.DataReaderDisposing(command, eventData, result); + } + + // === SCALAR COMMANDS === + + public override InterceptionResult ScalarExecuting( + DbCommand command, CommandEventData eventData, InterceptionResult result) + { + EnsureTransactionAndTenant(command); + return base.ScalarExecuting(command, eventData, result); + } + + public override ValueTask> ScalarExecutingAsync( + DbCommand command, CommandEventData eventData, InterceptionResult result, + CancellationToken cancellationToken = default) + { + EnsureTransactionAndTenant(command); + return base.ScalarExecutingAsync(command, eventData, result, cancellationToken); + } + + // Commit owned transaction immediately after scalar execution + public override object? ScalarExecuted( + DbCommand command, CommandExecutedEventData eventData, object? result) + { + CommitOwnedTransaction(command); + return base.ScalarExecuted(command, eventData, result); + } + + public override ValueTask ScalarExecutedAsync( + DbCommand command, CommandExecutedEventData eventData, object? result, + CancellationToken cancellationToken = default) + { + CommitOwnedTransaction(command); + return base.ScalarExecutedAsync(command, eventData, result, cancellationToken); + } + + // === NON-QUERY COMMANDS === + + public override InterceptionResult NonQueryExecuting( + DbCommand command, CommandEventData eventData, InterceptionResult result) + { + EnsureTransactionAndTenant(command); + return base.NonQueryExecuting(command, eventData, result); + } + + public override ValueTask> NonQueryExecutingAsync( + DbCommand command, CommandEventData eventData, InterceptionResult result, + CancellationToken cancellationToken = default) + { + EnsureTransactionAndTenant(command); + return base.NonQueryExecutingAsync(command, eventData, result, cancellationToken); + } + + public override int NonQueryExecuted( + DbCommand command, CommandExecutedEventData eventData, int result) + { + CommitOwnedTransaction(command); + return base.NonQueryExecuted(command, eventData, result); + } + + public override ValueTask NonQueryExecutedAsync( + DbCommand command, CommandExecutedEventData eventData, int result, + CancellationToken cancellationToken = default) + { + CommitOwnedTransaction(command); + return base.NonQueryExecutedAsync(command, eventData, result, cancellationToken); + } + + // === ERROR HANDLING === + + public override void CommandFailed(DbCommand command, CommandErrorEventData eventData) + { + RollbackOwnedTransaction(command); + _logger.LogError(eventData.Exception, "Command failed: {Sql}", + command.CommandText[..Math.Min(200, command.CommandText.Length)]); + base.CommandFailed(command, eventData); + } + + public override Task CommandFailedAsync(DbCommand command, CommandErrorEventData eventData, + CancellationToken cancellationToken = default) + { + RollbackOwnedTransaction(command); + _logger.LogError(eventData.Exception, "Command failed: {Sql}", + command.CommandText[..Math.Min(200, command.CommandText.Length)]); + return base.CommandFailedAsync(command, eventData, cancellationToken); + } + + // === PRIVATE HELPERS === + + private string? GetValidatedTenantId() + { + var tenantId = _httpContextAccessor.HttpContext?.Items["TenantId"] as string; + if (string.IsNullOrWhiteSpace(tenantId)) return null; + if (!Guid.TryParse(tenantId, out _)) + { + _logger.LogWarning("Invalid tenant ID format: {TenantId}", tenantId); + return null; + } + return tenantId; + } + + [ThreadStatic] + private static bool _isApplyingSetLocal; + + private void EnsureTransactionAndTenant(DbCommand command) + { + if (_isApplyingSetLocal) return; // Prevent recursion if ExecuteNonQuery calls interceptor + + // If the command already has a transaction, we assume TransactionStarted already set the tenant + if (command.Transaction != null) return; + + var tenantId = GetValidatedTenantId(); + if (tenantId == null) return; + + var conn = command.Connection; + if (conn is not NpgsqlConnection) return; + + // Auto-commit command: Create an explicit transaction + var tx = conn.BeginTransaction(); + command.Transaction = tx; + _ownedTxByCommand.AddOrUpdate(command, tx); + _logger.LogDebug("Created owned transaction for auto-commit command"); + + ApplySetLocalToTransaction(conn, tx, tenantId); + } + + private void ApplySetLocalToTransaction(DbConnection conn, DbTransaction tx, string tenantId) + { + try { + _isApplyingSetLocal = true; + using var setCmd = (conn as NpgsqlConnection)!.CreateCommand(); + setCmd.Transaction = tx as NpgsqlTransaction; + setCmd.CommandText = $"SET LOCAL app.current_tenant_id = '{tenantId}'"; + setCmd.ExecuteNonQuery(); + + _logger.LogDebug("Applied SET LOCAL for tenant {TenantId} on tx {TxHashCode}", tenantId, tx.GetHashCode()); + } catch (Exception ex) { + _logger.LogError(ex, "Failed to apply SET LOCAL"); + } finally { + _isApplyingSetLocal = false; + } + } + + private void CommitOwnedTransaction(DbCommand command) + { + if (_ownedTxByCommand.TryGetValue(command, out var tx)) + { + _ownedTxByCommand.Remove(command); + try { tx.Commit(); _logger.LogDebug("Committed owned transaction for scalar/nonquery"); } + catch { try { tx.Rollback(); } catch { } } + finally { tx.Dispose(); } + } + } + + private void RollbackOwnedTransaction(DbCommand command) + { + if (_ownedTxByCommand.TryGetValue(command, out var tx)) + { + _ownedTxByCommand.Remove(command); + try { tx.Rollback(); _logger.LogDebug("Rolled back owned transaction on failure"); } + catch { /* best-effort */ } + finally { tx.Dispose(); } + } + } + + // === TRANSACTION INTERCEPTOR (for EF-managed transactions like SaveChanges) === + + #region IDbTransactionInterceptor implementation + + public DbTransaction TransactionStarted(DbConnection connection, TransactionEndEventData eventData, DbTransaction result) + { + var tenantId = GetValidatedTenantId(); + if (tenantId != null) ApplySetLocalToTransaction(connection, result, tenantId); + return result; + } + + public async ValueTask TransactionStartedAsync(DbConnection connection, TransactionEndEventData eventData, DbTransaction result, CancellationToken cancellationToken = default) + { + var tenantId = GetValidatedTenantId(); + if (tenantId != null) ApplySetLocalToTransaction(connection, result, tenantId); + return result; + } + + public InterceptionResult TransactionStarting(DbConnection connection, TransactionStartingEventData eventData, InterceptionResult result) => result; + public ValueTask> TransactionStartingAsync(DbConnection connection, TransactionStartingEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result); + public InterceptionResult TransactionCommitting(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result; + public ValueTask TransactionCommittingAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result); + public void TransactionCommitted(DbTransaction transaction, TransactionEndEventData eventData) { } + public Task TransactionCommittedAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask; + public InterceptionResult TransactionRollingBack(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result; + public ValueTask TransactionRollingBackAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result); + public void TransactionRolledBack(DbTransaction transaction, TransactionEndEventData eventData) { } + public Task TransactionRolledBackAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask; + public DbTransaction CreatedSavepoint(DbTransaction transaction, TransactionEventData eventData) => transaction; + public ValueTask CreatedSavepointAsync(DbTransaction transaction, TransactionEventData eventData, CancellationToken cancellationToken = default) => new(transaction); + public InterceptionResult CreatingSavepoint(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result; + public ValueTask CreatingSavepointAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result); + public InterceptionResult ReleasingSavepoint(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result; + public ValueTask ReleasingSavepointAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result); + public void ReleasedSavepoint(DbTransaction transaction, TransactionEndEventData eventData) { } + public Task ReleasedSavepointAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask; + public InterceptionResult RollingBackToSavepoint(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result; + public ValueTask RollingBackToSavepointAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result); + public void RolledBackToSavepoint(DbTransaction transaction, TransactionEndEventData eventData) { } + public Task RolledBackToSavepointAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask; + public DbTransaction TransactionUsed(DbConnection connection, TransactionEventData eventData, DbTransaction result) => result; + public ValueTask TransactionUsedAsync(DbConnection connection, TransactionEventData eventData, DbTransaction result, CancellationToken cancellationToken = default) => new(result); + + #endregion +}