fix(backend): add TenantDbTransactionInterceptor for RLS with explicit transactions

Implements Option D: wraps auto-commit reads in explicit transactions with SET LOCAL.
Handles transaction lifecycle (create→SET LOCAL→execute→commit/dispose).
Uses IDbTransactionInterceptor for EF-managed SaveChanges transactions.
Critical fix for PostgreSQL RLS requiring transaction-scoped context.
This commit is contained in:
WorkClub Automation
2026-03-05 20:43:03 +01:00
parent 5fb148a9eb
commit c918f447b2
2 changed files with 307 additions and 2 deletions

View File

@@ -0,0 +1,304 @@
using System.Data.Common;
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Http;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.Extensions.Logging;
using Npgsql;
namespace WorkClub.Infrastructure.Data.Interceptors;
/// <summary>
/// Sets PostgreSQL RLS tenant context using SET LOCAL in explicit transactions.
/// For auto-commit reads: wraps in explicit transaction, applies SET LOCAL, commits on reader dispose.
/// For transactional writes: applies SET LOCAL once when transaction starts.
/// </summary>
public class TenantDbTransactionInterceptor : DbCommandInterceptor, IDbTransactionInterceptor
{
private readonly IHttpContextAccessor _httpContextAccessor;
private readonly ILogger<TenantDbTransactionInterceptor> _logger;
// Track transactions we created (so we know to commit/dispose them)
private readonly ConditionalWeakTable<DbCommand, DbTransaction> _ownedTxByCommand = new();
private readonly ConditionalWeakTable<DbDataReader, DbTransaction> _ownedTxByReader = new();
public TenantDbTransactionInterceptor(
IHttpContextAccessor httpContextAccessor,
ILogger<TenantDbTransactionInterceptor> logger)
{
_httpContextAccessor = httpContextAccessor;
_logger = logger;
}
// === READER COMMANDS (SELECT queries) ===
public override InterceptionResult<DbDataReader> ReaderExecuting(
DbCommand command, CommandEventData eventData, InterceptionResult<DbDataReader> result)
{
EnsureTransactionAndTenant(command);
return base.ReaderExecuting(command, eventData, result);
}
public override ValueTask<InterceptionResult<DbDataReader>> ReaderExecutingAsync(
DbCommand command, CommandEventData eventData, InterceptionResult<DbDataReader> result,
CancellationToken cancellationToken = default)
{
EnsureTransactionAndTenant(command);
return base.ReaderExecutingAsync(command, eventData, result, cancellationToken);
}
// After reader executes, transfer tx ownership from command to reader
public override DbDataReader ReaderExecuted(
DbCommand command, CommandExecutedEventData eventData, DbDataReader result)
{
if (_ownedTxByCommand.TryGetValue(command, out var tx))
{
_ownedTxByCommand.Remove(command);
_ownedTxByReader.AddOrUpdate(result, tx);
}
return base.ReaderExecuted(command, eventData, result);
}
public override ValueTask<DbDataReader> ReaderExecutedAsync(
DbCommand command, CommandExecutedEventData eventData, DbDataReader result,
CancellationToken cancellationToken = default)
{
if (_ownedTxByCommand.TryGetValue(command, out var tx))
{
_ownedTxByCommand.Remove(command);
_ownedTxByReader.AddOrUpdate(result, tx);
}
return base.ReaderExecutedAsync(command, eventData, result, cancellationToken);
}
// When reader is disposed, commit and dispose the owned transaction
public override InterceptionResult DataReaderDisposing(
DbCommand command, DataReaderDisposingEventData eventData, InterceptionResult result)
{
if (_ownedTxByReader.TryGetValue(eventData.DataReader, out var tx))
{
_ownedTxByReader.Remove(eventData.DataReader);
try
{
tx.Commit();
_logger.LogDebug("Committed owned transaction for reader disposal");
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to commit owned transaction on reader disposal");
try { tx.Rollback(); } catch { /* best-effort */ }
}
finally
{
tx.Dispose();
}
}
return base.DataReaderDisposing(command, eventData, result);
}
// === SCALAR COMMANDS ===
public override InterceptionResult<object> ScalarExecuting(
DbCommand command, CommandEventData eventData, InterceptionResult<object> result)
{
EnsureTransactionAndTenant(command);
return base.ScalarExecuting(command, eventData, result);
}
public override ValueTask<InterceptionResult<object>> ScalarExecutingAsync(
DbCommand command, CommandEventData eventData, InterceptionResult<object> result,
CancellationToken cancellationToken = default)
{
EnsureTransactionAndTenant(command);
return base.ScalarExecutingAsync(command, eventData, result, cancellationToken);
}
// Commit owned transaction immediately after scalar execution
public override object? ScalarExecuted(
DbCommand command, CommandExecutedEventData eventData, object? result)
{
CommitOwnedTransaction(command);
return base.ScalarExecuted(command, eventData, result);
}
public override ValueTask<object?> ScalarExecutedAsync(
DbCommand command, CommandExecutedEventData eventData, object? result,
CancellationToken cancellationToken = default)
{
CommitOwnedTransaction(command);
return base.ScalarExecutedAsync(command, eventData, result, cancellationToken);
}
// === NON-QUERY COMMANDS ===
public override InterceptionResult<int> NonQueryExecuting(
DbCommand command, CommandEventData eventData, InterceptionResult<int> result)
{
EnsureTransactionAndTenant(command);
return base.NonQueryExecuting(command, eventData, result);
}
public override ValueTask<InterceptionResult<int>> NonQueryExecutingAsync(
DbCommand command, CommandEventData eventData, InterceptionResult<int> result,
CancellationToken cancellationToken = default)
{
EnsureTransactionAndTenant(command);
return base.NonQueryExecutingAsync(command, eventData, result, cancellationToken);
}
public override int NonQueryExecuted(
DbCommand command, CommandExecutedEventData eventData, int result)
{
CommitOwnedTransaction(command);
return base.NonQueryExecuted(command, eventData, result);
}
public override ValueTask<int> NonQueryExecutedAsync(
DbCommand command, CommandExecutedEventData eventData, int result,
CancellationToken cancellationToken = default)
{
CommitOwnedTransaction(command);
return base.NonQueryExecutedAsync(command, eventData, result, cancellationToken);
}
// === ERROR HANDLING ===
public override void CommandFailed(DbCommand command, CommandErrorEventData eventData)
{
RollbackOwnedTransaction(command);
_logger.LogError(eventData.Exception, "Command failed: {Sql}",
command.CommandText[..Math.Min(200, command.CommandText.Length)]);
base.CommandFailed(command, eventData);
}
public override Task CommandFailedAsync(DbCommand command, CommandErrorEventData eventData,
CancellationToken cancellationToken = default)
{
RollbackOwnedTransaction(command);
_logger.LogError(eventData.Exception, "Command failed: {Sql}",
command.CommandText[..Math.Min(200, command.CommandText.Length)]);
return base.CommandFailedAsync(command, eventData, cancellationToken);
}
// === PRIVATE HELPERS ===
private string? GetValidatedTenantId()
{
var tenantId = _httpContextAccessor.HttpContext?.Items["TenantId"] as string;
if (string.IsNullOrWhiteSpace(tenantId)) return null;
if (!Guid.TryParse(tenantId, out _))
{
_logger.LogWarning("Invalid tenant ID format: {TenantId}", tenantId);
return null;
}
return tenantId;
}
[ThreadStatic]
private static bool _isApplyingSetLocal;
private void EnsureTransactionAndTenant(DbCommand command)
{
if (_isApplyingSetLocal) return; // Prevent recursion if ExecuteNonQuery calls interceptor
// If the command already has a transaction, we assume TransactionStarted already set the tenant
if (command.Transaction != null) return;
var tenantId = GetValidatedTenantId();
if (tenantId == null) return;
var conn = command.Connection;
if (conn is not NpgsqlConnection) return;
// Auto-commit command: Create an explicit transaction
var tx = conn.BeginTransaction();
command.Transaction = tx;
_ownedTxByCommand.AddOrUpdate(command, tx);
_logger.LogDebug("Created owned transaction for auto-commit command");
ApplySetLocalToTransaction(conn, tx, tenantId);
}
private void ApplySetLocalToTransaction(DbConnection conn, DbTransaction tx, string tenantId)
{
try {
_isApplyingSetLocal = true;
using var setCmd = (conn as NpgsqlConnection)!.CreateCommand();
setCmd.Transaction = tx as NpgsqlTransaction;
setCmd.CommandText = $"SET LOCAL app.current_tenant_id = '{tenantId}'";
setCmd.ExecuteNonQuery();
_logger.LogDebug("Applied SET LOCAL for tenant {TenantId} on tx {TxHashCode}", tenantId, tx.GetHashCode());
} catch (Exception ex) {
_logger.LogError(ex, "Failed to apply SET LOCAL");
} finally {
_isApplyingSetLocal = false;
}
}
private void CommitOwnedTransaction(DbCommand command)
{
if (_ownedTxByCommand.TryGetValue(command, out var tx))
{
_ownedTxByCommand.Remove(command);
try { tx.Commit(); _logger.LogDebug("Committed owned transaction for scalar/nonquery"); }
catch { try { tx.Rollback(); } catch { } }
finally { tx.Dispose(); }
}
}
private void RollbackOwnedTransaction(DbCommand command)
{
if (_ownedTxByCommand.TryGetValue(command, out var tx))
{
_ownedTxByCommand.Remove(command);
try { tx.Rollback(); _logger.LogDebug("Rolled back owned transaction on failure"); }
catch { /* best-effort */ }
finally { tx.Dispose(); }
}
}
// === TRANSACTION INTERCEPTOR (for EF-managed transactions like SaveChanges) ===
#region IDbTransactionInterceptor implementation
public DbTransaction TransactionStarted(DbConnection connection, TransactionEndEventData eventData, DbTransaction result)
{
var tenantId = GetValidatedTenantId();
if (tenantId != null) ApplySetLocalToTransaction(connection, result, tenantId);
return result;
}
public async ValueTask<DbTransaction> TransactionStartedAsync(DbConnection connection, TransactionEndEventData eventData, DbTransaction result, CancellationToken cancellationToken = default)
{
var tenantId = GetValidatedTenantId();
if (tenantId != null) ApplySetLocalToTransaction(connection, result, tenantId);
return result;
}
public InterceptionResult<DbTransaction> TransactionStarting(DbConnection connection, TransactionStartingEventData eventData, InterceptionResult<DbTransaction> result) => result;
public ValueTask<InterceptionResult<DbTransaction>> TransactionStartingAsync(DbConnection connection, TransactionStartingEventData eventData, InterceptionResult<DbTransaction> result, CancellationToken cancellationToken = default) => new(result);
public InterceptionResult TransactionCommitting(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
public ValueTask<InterceptionResult> TransactionCommittingAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
public void TransactionCommitted(DbTransaction transaction, TransactionEndEventData eventData) { }
public Task TransactionCommittedAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask;
public InterceptionResult TransactionRollingBack(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
public ValueTask<InterceptionResult> TransactionRollingBackAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
public void TransactionRolledBack(DbTransaction transaction, TransactionEndEventData eventData) { }
public Task TransactionRolledBackAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask;
public DbTransaction CreatedSavepoint(DbTransaction transaction, TransactionEventData eventData) => transaction;
public ValueTask<DbTransaction> CreatedSavepointAsync(DbTransaction transaction, TransactionEventData eventData, CancellationToken cancellationToken = default) => new(transaction);
public InterceptionResult CreatingSavepoint(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
public ValueTask<InterceptionResult> CreatingSavepointAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
public InterceptionResult ReleasingSavepoint(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
public ValueTask<InterceptionResult> ReleasingSavepointAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
public void ReleasedSavepoint(DbTransaction transaction, TransactionEndEventData eventData) { }
public Task ReleasedSavepointAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask;
public InterceptionResult RollingBackToSavepoint(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
public ValueTask<InterceptionResult> RollingBackToSavepointAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
public void RolledBackToSavepoint(DbTransaction transaction, TransactionEndEventData eventData) { }
public Task RolledBackToSavepointAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask;
public DbTransaction TransactionUsed(DbConnection connection, TransactionEventData eventData, DbTransaction result) => result;
public ValueTask<DbTransaction> TransactionUsedAsync(DbConnection connection, TransactionEventData eventData, DbTransaction result, CancellationToken cancellationToken = default) => new(result);
#endregion
}