fix(backend): add TenantDbTransactionInterceptor for RLS with explicit transactions
Implements Option D: wraps auto-commit reads in explicit transactions with SET LOCAL. Handles transaction lifecycle (create→SET LOCAL→execute→commit/dispose). Uses IDbTransactionInterceptor for EF-managed SaveChanges transactions. Critical fix for PostgreSQL RLS requiring transaction-scoped context.
This commit is contained in:
@@ -0,0 +1,304 @@
|
||||
using System.Data.Common;
|
||||
using System.Runtime.CompilerServices;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.EntityFrameworkCore.Diagnostics;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Npgsql;
|
||||
|
||||
namespace WorkClub.Infrastructure.Data.Interceptors;
|
||||
|
||||
/// <summary>
|
||||
/// Sets PostgreSQL RLS tenant context using SET LOCAL in explicit transactions.
|
||||
/// For auto-commit reads: wraps in explicit transaction, applies SET LOCAL, commits on reader dispose.
|
||||
/// For transactional writes: applies SET LOCAL once when transaction starts.
|
||||
/// </summary>
|
||||
public class TenantDbTransactionInterceptor : DbCommandInterceptor, IDbTransactionInterceptor
|
||||
{
|
||||
private readonly IHttpContextAccessor _httpContextAccessor;
|
||||
private readonly ILogger<TenantDbTransactionInterceptor> _logger;
|
||||
|
||||
// Track transactions we created (so we know to commit/dispose them)
|
||||
private readonly ConditionalWeakTable<DbCommand, DbTransaction> _ownedTxByCommand = new();
|
||||
private readonly ConditionalWeakTable<DbDataReader, DbTransaction> _ownedTxByReader = new();
|
||||
|
||||
public TenantDbTransactionInterceptor(
|
||||
IHttpContextAccessor httpContextAccessor,
|
||||
ILogger<TenantDbTransactionInterceptor> logger)
|
||||
{
|
||||
_httpContextAccessor = httpContextAccessor;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
// === READER COMMANDS (SELECT queries) ===
|
||||
|
||||
public override InterceptionResult<DbDataReader> ReaderExecuting(
|
||||
DbCommand command, CommandEventData eventData, InterceptionResult<DbDataReader> result)
|
||||
{
|
||||
EnsureTransactionAndTenant(command);
|
||||
return base.ReaderExecuting(command, eventData, result);
|
||||
}
|
||||
|
||||
public override ValueTask<InterceptionResult<DbDataReader>> ReaderExecutingAsync(
|
||||
DbCommand command, CommandEventData eventData, InterceptionResult<DbDataReader> result,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
EnsureTransactionAndTenant(command);
|
||||
return base.ReaderExecutingAsync(command, eventData, result, cancellationToken);
|
||||
}
|
||||
|
||||
// After reader executes, transfer tx ownership from command to reader
|
||||
public override DbDataReader ReaderExecuted(
|
||||
DbCommand command, CommandExecutedEventData eventData, DbDataReader result)
|
||||
{
|
||||
if (_ownedTxByCommand.TryGetValue(command, out var tx))
|
||||
{
|
||||
_ownedTxByCommand.Remove(command);
|
||||
_ownedTxByReader.AddOrUpdate(result, tx);
|
||||
}
|
||||
return base.ReaderExecuted(command, eventData, result);
|
||||
}
|
||||
|
||||
public override ValueTask<DbDataReader> ReaderExecutedAsync(
|
||||
DbCommand command, CommandExecutedEventData eventData, DbDataReader result,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (_ownedTxByCommand.TryGetValue(command, out var tx))
|
||||
{
|
||||
_ownedTxByCommand.Remove(command);
|
||||
_ownedTxByReader.AddOrUpdate(result, tx);
|
||||
}
|
||||
return base.ReaderExecutedAsync(command, eventData, result, cancellationToken);
|
||||
}
|
||||
|
||||
// When reader is disposed, commit and dispose the owned transaction
|
||||
public override InterceptionResult DataReaderDisposing(
|
||||
DbCommand command, DataReaderDisposingEventData eventData, InterceptionResult result)
|
||||
{
|
||||
if (_ownedTxByReader.TryGetValue(eventData.DataReader, out var tx))
|
||||
{
|
||||
_ownedTxByReader.Remove(eventData.DataReader);
|
||||
try
|
||||
{
|
||||
tx.Commit();
|
||||
_logger.LogDebug("Committed owned transaction for reader disposal");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning(ex, "Failed to commit owned transaction on reader disposal");
|
||||
try { tx.Rollback(); } catch { /* best-effort */ }
|
||||
}
|
||||
finally
|
||||
{
|
||||
tx.Dispose();
|
||||
}
|
||||
}
|
||||
return base.DataReaderDisposing(command, eventData, result);
|
||||
}
|
||||
|
||||
// === SCALAR COMMANDS ===
|
||||
|
||||
public override InterceptionResult<object> ScalarExecuting(
|
||||
DbCommand command, CommandEventData eventData, InterceptionResult<object> result)
|
||||
{
|
||||
EnsureTransactionAndTenant(command);
|
||||
return base.ScalarExecuting(command, eventData, result);
|
||||
}
|
||||
|
||||
public override ValueTask<InterceptionResult<object>> ScalarExecutingAsync(
|
||||
DbCommand command, CommandEventData eventData, InterceptionResult<object> result,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
EnsureTransactionAndTenant(command);
|
||||
return base.ScalarExecutingAsync(command, eventData, result, cancellationToken);
|
||||
}
|
||||
|
||||
// Commit owned transaction immediately after scalar execution
|
||||
public override object? ScalarExecuted(
|
||||
DbCommand command, CommandExecutedEventData eventData, object? result)
|
||||
{
|
||||
CommitOwnedTransaction(command);
|
||||
return base.ScalarExecuted(command, eventData, result);
|
||||
}
|
||||
|
||||
public override ValueTask<object?> ScalarExecutedAsync(
|
||||
DbCommand command, CommandExecutedEventData eventData, object? result,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
CommitOwnedTransaction(command);
|
||||
return base.ScalarExecutedAsync(command, eventData, result, cancellationToken);
|
||||
}
|
||||
|
||||
// === NON-QUERY COMMANDS ===
|
||||
|
||||
public override InterceptionResult<int> NonQueryExecuting(
|
||||
DbCommand command, CommandEventData eventData, InterceptionResult<int> result)
|
||||
{
|
||||
EnsureTransactionAndTenant(command);
|
||||
return base.NonQueryExecuting(command, eventData, result);
|
||||
}
|
||||
|
||||
public override ValueTask<InterceptionResult<int>> NonQueryExecutingAsync(
|
||||
DbCommand command, CommandEventData eventData, InterceptionResult<int> result,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
EnsureTransactionAndTenant(command);
|
||||
return base.NonQueryExecutingAsync(command, eventData, result, cancellationToken);
|
||||
}
|
||||
|
||||
public override int NonQueryExecuted(
|
||||
DbCommand command, CommandExecutedEventData eventData, int result)
|
||||
{
|
||||
CommitOwnedTransaction(command);
|
||||
return base.NonQueryExecuted(command, eventData, result);
|
||||
}
|
||||
|
||||
public override ValueTask<int> NonQueryExecutedAsync(
|
||||
DbCommand command, CommandExecutedEventData eventData, int result,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
CommitOwnedTransaction(command);
|
||||
return base.NonQueryExecutedAsync(command, eventData, result, cancellationToken);
|
||||
}
|
||||
|
||||
// === ERROR HANDLING ===
|
||||
|
||||
public override void CommandFailed(DbCommand command, CommandErrorEventData eventData)
|
||||
{
|
||||
RollbackOwnedTransaction(command);
|
||||
_logger.LogError(eventData.Exception, "Command failed: {Sql}",
|
||||
command.CommandText[..Math.Min(200, command.CommandText.Length)]);
|
||||
base.CommandFailed(command, eventData);
|
||||
}
|
||||
|
||||
public override Task CommandFailedAsync(DbCommand command, CommandErrorEventData eventData,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
RollbackOwnedTransaction(command);
|
||||
_logger.LogError(eventData.Exception, "Command failed: {Sql}",
|
||||
command.CommandText[..Math.Min(200, command.CommandText.Length)]);
|
||||
return base.CommandFailedAsync(command, eventData, cancellationToken);
|
||||
}
|
||||
|
||||
// === PRIVATE HELPERS ===
|
||||
|
||||
private string? GetValidatedTenantId()
|
||||
{
|
||||
var tenantId = _httpContextAccessor.HttpContext?.Items["TenantId"] as string;
|
||||
if (string.IsNullOrWhiteSpace(tenantId)) return null;
|
||||
if (!Guid.TryParse(tenantId, out _))
|
||||
{
|
||||
_logger.LogWarning("Invalid tenant ID format: {TenantId}", tenantId);
|
||||
return null;
|
||||
}
|
||||
return tenantId;
|
||||
}
|
||||
|
||||
[ThreadStatic]
|
||||
private static bool _isApplyingSetLocal;
|
||||
|
||||
private void EnsureTransactionAndTenant(DbCommand command)
|
||||
{
|
||||
if (_isApplyingSetLocal) return; // Prevent recursion if ExecuteNonQuery calls interceptor
|
||||
|
||||
// If the command already has a transaction, we assume TransactionStarted already set the tenant
|
||||
if (command.Transaction != null) return;
|
||||
|
||||
var tenantId = GetValidatedTenantId();
|
||||
if (tenantId == null) return;
|
||||
|
||||
var conn = command.Connection;
|
||||
if (conn is not NpgsqlConnection) return;
|
||||
|
||||
// Auto-commit command: Create an explicit transaction
|
||||
var tx = conn.BeginTransaction();
|
||||
command.Transaction = tx;
|
||||
_ownedTxByCommand.AddOrUpdate(command, tx);
|
||||
_logger.LogDebug("Created owned transaction for auto-commit command");
|
||||
|
||||
ApplySetLocalToTransaction(conn, tx, tenantId);
|
||||
}
|
||||
|
||||
private void ApplySetLocalToTransaction(DbConnection conn, DbTransaction tx, string tenantId)
|
||||
{
|
||||
try {
|
||||
_isApplyingSetLocal = true;
|
||||
using var setCmd = (conn as NpgsqlConnection)!.CreateCommand();
|
||||
setCmd.Transaction = tx as NpgsqlTransaction;
|
||||
setCmd.CommandText = $"SET LOCAL app.current_tenant_id = '{tenantId}'";
|
||||
setCmd.ExecuteNonQuery();
|
||||
|
||||
_logger.LogDebug("Applied SET LOCAL for tenant {TenantId} on tx {TxHashCode}", tenantId, tx.GetHashCode());
|
||||
} catch (Exception ex) {
|
||||
_logger.LogError(ex, "Failed to apply SET LOCAL");
|
||||
} finally {
|
||||
_isApplyingSetLocal = false;
|
||||
}
|
||||
}
|
||||
|
||||
private void CommitOwnedTransaction(DbCommand command)
|
||||
{
|
||||
if (_ownedTxByCommand.TryGetValue(command, out var tx))
|
||||
{
|
||||
_ownedTxByCommand.Remove(command);
|
||||
try { tx.Commit(); _logger.LogDebug("Committed owned transaction for scalar/nonquery"); }
|
||||
catch { try { tx.Rollback(); } catch { } }
|
||||
finally { tx.Dispose(); }
|
||||
}
|
||||
}
|
||||
|
||||
private void RollbackOwnedTransaction(DbCommand command)
|
||||
{
|
||||
if (_ownedTxByCommand.TryGetValue(command, out var tx))
|
||||
{
|
||||
_ownedTxByCommand.Remove(command);
|
||||
try { tx.Rollback(); _logger.LogDebug("Rolled back owned transaction on failure"); }
|
||||
catch { /* best-effort */ }
|
||||
finally { tx.Dispose(); }
|
||||
}
|
||||
}
|
||||
|
||||
// === TRANSACTION INTERCEPTOR (for EF-managed transactions like SaveChanges) ===
|
||||
|
||||
#region IDbTransactionInterceptor implementation
|
||||
|
||||
public DbTransaction TransactionStarted(DbConnection connection, TransactionEndEventData eventData, DbTransaction result)
|
||||
{
|
||||
var tenantId = GetValidatedTenantId();
|
||||
if (tenantId != null) ApplySetLocalToTransaction(connection, result, tenantId);
|
||||
return result;
|
||||
}
|
||||
|
||||
public async ValueTask<DbTransaction> TransactionStartedAsync(DbConnection connection, TransactionEndEventData eventData, DbTransaction result, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var tenantId = GetValidatedTenantId();
|
||||
if (tenantId != null) ApplySetLocalToTransaction(connection, result, tenantId);
|
||||
return result;
|
||||
}
|
||||
|
||||
public InterceptionResult<DbTransaction> TransactionStarting(DbConnection connection, TransactionStartingEventData eventData, InterceptionResult<DbTransaction> result) => result;
|
||||
public ValueTask<InterceptionResult<DbTransaction>> TransactionStartingAsync(DbConnection connection, TransactionStartingEventData eventData, InterceptionResult<DbTransaction> result, CancellationToken cancellationToken = default) => new(result);
|
||||
public InterceptionResult TransactionCommitting(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
|
||||
public ValueTask<InterceptionResult> TransactionCommittingAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
|
||||
public void TransactionCommitted(DbTransaction transaction, TransactionEndEventData eventData) { }
|
||||
public Task TransactionCommittedAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask;
|
||||
public InterceptionResult TransactionRollingBack(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
|
||||
public ValueTask<InterceptionResult> TransactionRollingBackAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
|
||||
public void TransactionRolledBack(DbTransaction transaction, TransactionEndEventData eventData) { }
|
||||
public Task TransactionRolledBackAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask;
|
||||
public DbTransaction CreatedSavepoint(DbTransaction transaction, TransactionEventData eventData) => transaction;
|
||||
public ValueTask<DbTransaction> CreatedSavepointAsync(DbTransaction transaction, TransactionEventData eventData, CancellationToken cancellationToken = default) => new(transaction);
|
||||
public InterceptionResult CreatingSavepoint(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
|
||||
public ValueTask<InterceptionResult> CreatingSavepointAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
|
||||
public InterceptionResult ReleasingSavepoint(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
|
||||
public ValueTask<InterceptionResult> ReleasingSavepointAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
|
||||
public void ReleasedSavepoint(DbTransaction transaction, TransactionEndEventData eventData) { }
|
||||
public Task ReleasedSavepointAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask;
|
||||
public InterceptionResult RollingBackToSavepoint(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result) => result;
|
||||
public ValueTask<InterceptionResult> RollingBackToSavepointAsync(DbTransaction transaction, TransactionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default) => new(result);
|
||||
public void RolledBackToSavepoint(DbTransaction transaction, TransactionEndEventData eventData) { }
|
||||
public Task RolledBackToSavepointAsync(DbTransaction transaction, TransactionEndEventData eventData, CancellationToken cancellationToken = default) => Task.CompletedTask;
|
||||
public DbTransaction TransactionUsed(DbConnection connection, TransactionEventData eventData, DbTransaction result) => result;
|
||||
public ValueTask<DbTransaction> TransactionUsedAsync(DbConnection connection, TransactionEventData eventData, DbTransaction result, CancellationToken cancellationToken = default) => new(result);
|
||||
|
||||
#endregion
|
||||
}
|
||||
Reference in New Issue
Block a user