using Dapper;
using Microsoft.EntityFrameworkCore;
using Npgsql;
using Testcontainers.PostgreSql;
using WorkClub.Domain.Entities;
using WorkClub.Infrastructure.Data;

namespace WorkClub.Tests.Integration.Data;

/// <summary>
/// Integration tests for the PostgreSQL row-level-security (RLS) tenant isolation policies
/// (<c>tenant_isolation_policy</c>) on <c>clubs</c>, <c>members</c>, <c>work_items</c>,
/// <c>shifts</c> and <c>shift_signups</c>.
/// </summary>
/// <remarks>
/// xUnit creates a new instance of the test class for every test, so <see cref="InitializeAsync"/>
/// starts a fresh PostgreSQL container per test and <see cref="SeedTestDataAsync"/> always runs
/// against an empty database.
/// Queries under test run as <c>rls_test_user</c>, a role created without superuser rights that
/// does not own the tables, so the forced RLS policies apply to it. The container's <c>app_user</c>
/// is used for migrations, seeding and the bypass test.
/// </remarks>
public class RlsTests : IAsyncLifetime
{
    private PostgreSqlContainer? _container;

    // Connection string for rls_test_user: the role whose access RLS restricts.
    private string? _connectionString;

    // Connection string for the container's app_user: runs migrations, seeding and the bypass test.
    private string? _adminConnectionString;

    /// <summary>
    /// Starts the PostgreSQL container and creates the restricted <c>rls_test_user</c> login.
    /// Table-level grants are issued later, in <see cref="SeedTestDataAsync"/>, once migrations
    /// have created the tables.
    /// </summary>
    public async Task InitializeAsync()
    {
        _container = new PostgreSqlBuilder()
            .WithImage("postgres:16-alpine")
            .WithDatabase("workclub")
            .WithUsername("app_user")
            .WithPassword("apppass")
            .Build();

        await _container.StartAsync();
        _adminConnectionString = _container.GetConnectionString();

        await using var adminConn = new NpgsqlConnection(_adminConnectionString);
        await adminConn.OpenAsync();
        await adminConn.ExecuteAsync(@"
            CREATE USER rls_test_user WITH PASSWORD 'rlspass';
            GRANT CONNECT ON DATABASE workclub TO rls_test_user;
        ");

        _connectionString = new NpgsqlConnectionStringBuilder(_adminConnectionString)
        {
            Username = "rls_test_user",
            Password = "rlspass"
        }.ConnectionString;
    }

    /// <summary>Stops and removes the container started by <see cref="InitializeAsync"/>.</summary>
    public async Task DisposeAsync()
    {
        if (_container is not null)
        {
            await _container.DisposeAsync();
        }
    }

    [Fact]
    public async Task RLS_BlocksAccess_WithoutTenantContext()
    {
        await SeedTestDataAsync();

        // No tenant is set, so the policy's comparison against
        // current_setting('app.current_tenant_id', true) has no matching TenantId.
        var clubs = await QueryAsTenantAsync(null, "SELECT * FROM clubs");

        Assert.Empty(clubs);
    }

    [Fact]
    public async Task RLS_AllowsAccess_WithCorrectTenantContext()
    {
        await SeedTestDataAsync();

        // Deliberately no WHERE clause: the filtering must come from the RLS policy alone.
        // Otherwise the Assert.All below would hold even with RLS disabled.
        var clubs = await QueryAsTenantAsync("tenant-1", "SELECT * FROM clubs");

        Assert.NotEmpty(clubs);
        Assert.All(clubs, c => Assert.Equal("tenant-1", c.TenantId));
    }

    [Fact]
    public async Task RLS_IsolatesData_AcrossTenants()
    {
        await SeedTestDataAsync();

        var tenant1Clubs = await QueryAsTenantAsync("tenant-1", "SELECT * FROM clubs");
        var tenant2Clubs = await QueryAsTenantAsync("tenant-2", "SELECT * FROM clubs");

        Assert.NotEmpty(tenant1Clubs);
        Assert.NotEmpty(tenant2Clubs);
        Assert.All(tenant1Clubs, c => Assert.Equal("tenant-1", c.TenantId));
        Assert.All(tenant2Clubs, c => Assert.Equal("tenant-2", c.TenantId));

        var tenant1Ids = tenant1Clubs.Select(c => c.Id).ToHashSet();
        var tenant2Ids = tenant2Clubs.Select(c => c.Id).ToHashSet();
        Assert.Empty(tenant1Ids.Intersect(tenant2Ids));
    }

    [Fact]
    public async Task RLS_CountsCorrectly_PerTenant()
    {
        await SeedTestDataAsync();

        // COUNT(*) is bigint in PostgreSQL, so read it as long.
        // Comparing an int literal to the boxed result of the non-generic ExecuteScalarAsync
        // would compare Int32 with Int64 and always fail.
        var tenant1Count = await CountAsTenantAsync("tenant-1", "SELECT COUNT(*) FROM work_items");
        var tenant2Count = await CountAsTenantAsync("tenant-2", "SELECT COUNT(*) FROM work_items");

        Assert.Equal(5L, tenant1Count);
        Assert.Equal(3L, tenant2Count);
    }

    [Fact]
    public async Task RLS_AllowsBypass_ForAdminRole()
    {
        await SeedTestDataAsync();

        // NOTE(review): this relies on app_user bypassing RLS even though the tables use
        // FORCE ROW LEVEL SECURITY. That only holds if app_user is a superuser or has BYPASSRLS.
        // It is assumed from how the postgres image sets up its configured user; confirm.
        await using var connection = new NpgsqlConnection(_adminConnectionString);
        await connection.OpenAsync();
        var allClubs = (await connection.QueryAsync("SELECT * FROM clubs")).ToList();

        Assert.True(allClubs.Count >= 2);
        Assert.Contains(allClubs, c => c.TenantId == "tenant-1");
        Assert.Contains(allClubs, c => c.TenantId == "tenant-2");
    }

    [Fact]
    public async Task RLS_HandlesShiftSignups_WithSubquery()
    {
        await SeedTestDataAsync();

        // The shift_signups policy filters through a subquery on shifts.
        // Signups exist for both tenants, so each tenant must see only its own.
        var tenant1Signups = await QueryAsTenantAsync("tenant-1", "SELECT * FROM shift_signups");
        var tenant2Signups = await QueryAsTenantAsync("tenant-2", "SELECT * FROM shift_signups");

        Assert.NotEmpty(tenant1Signups);
        Assert.All(tenant1Signups, s => Assert.Equal("tenant-1", s.TenantId));
        Assert.NotEmpty(tenant2Signups);
        Assert.All(tenant2Signups, s => Assert.Equal("tenant-2", s.TenantId));
    }

    /// <summary>
    /// Runs <paramref name="sql"/> as <c>rls_test_user</c> and returns all rows as Dapper dynamic rows.
    /// </summary>
    /// <param name="tenantId">Tenant id to scope the transaction to, or <c>null</c> to leave it unset.</param>
    /// <param name="sql">The query to execute.</param>
    private Task<List<dynamic>> QueryAsTenantAsync(string? tenantId, string sql) =>
        RunAsTenantAsync(tenantId, async (conn, txn) =>
            (await conn.QueryAsync(sql, transaction: txn)).ToList());

    /// <summary>
    /// Runs a single-value <paramref name="sql"/> (such as <c>COUNT(*)</c>) as <c>rls_test_user</c>.
    /// </summary>
    /// <param name="tenantId">Tenant id to scope the transaction to, or <c>null</c> to leave it unset.</param>
    /// <param name="sql">The scalar query to execute.</param>
    private Task<long> CountAsTenantAsync(string? tenantId, string sql) =>
        RunAsTenantAsync(tenantId, (conn, txn) =>
            conn.ExecuteScalarAsync<long>(sql, transaction: txn));

    /// <summary>
    /// Opens an <c>rls_test_user</c> connection and begins a transaction.
    /// If <paramref name="tenantId"/> is given, it sets <c>app.current_tenant_id</c> for that
    /// transaction only. It then runs <paramref name="query"/> and commits.
    /// </summary>
    /// <typeparam name="T">Result type produced by <paramref name="query"/>.</typeparam>
    /// <param name="tenantId">Tenant id to set, or <c>null</c> to leave the setting untouched.</param>
    /// <param name="query">Work to run inside the transaction.</param>
    /// <returns>The value produced by <paramref name="query"/>.</returns>
    private async Task<T> RunAsTenantAsync<T>(
        string? tenantId,
        Func<NpgsqlConnection, NpgsqlTransaction, Task<T>> query)
    {
        await using var connection = new NpgsqlConnection(_connectionString);
        await connection.OpenAsync();
        await using var txn = await connection.BeginTransactionAsync();

        if (tenantId is not null)
        {
            // set_config(name, value, is_local => true) is the parameterizable form of SET LOCAL.
            // The value only lasts for the current transaction, so it cannot leak into a pooled connection.
            await connection.ExecuteAsync(
                "SELECT set_config('app.current_tenant_id', @TenantId, true)",
                new { TenantId = tenantId },
                txn);
        }

        var result = await query(connection, txn);
        await txn.CommitAsync();
        return result;
    }

    /// <summary>
    /// Applies EF Core migrations, grants <c>rls_test_user</c> access to the tables, enables and
    /// forces RLS, creates the tenant policies if they are missing, and inserts fixture data.
    /// </summary>
    /// <remarks>
    /// Fixture data:
    /// <list type="bullet">
    /// <item>one club per tenant;</item>
    /// <item>5 <c>work_items</c> for tenant-1 and 3 for tenant-2;</item>
    /// <item>one shift per tenant, each with one signup.</item>
    /// </list>
    /// </remarks>
    private async Task SeedTestDataAsync()
    {
        var options = new DbContextOptionsBuilder<AppDbContext>()
            .UseNpgsql(_adminConnectionString)
            .Options;

        await using (var context = new AppDbContext(options))
        {
            await context.Database.MigrateAsync();
        }

        await using var adminConn = new NpgsqlConnection(_adminConnectionString);
        await adminConn.OpenAsync();
        await using var txn = await adminConn.BeginTransactionAsync();

        await adminConn.ExecuteAsync(@"
            GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO rls_test_user;
            GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO rls_test_user;
        ");

        // FORCE applies the policies to the table owner as well, not just to other roles.
        await adminConn.ExecuteAsync(@"
            ALTER TABLE clubs ENABLE ROW LEVEL SECURITY;
            ALTER TABLE clubs FORCE ROW LEVEL SECURITY;
            ALTER TABLE members ENABLE ROW LEVEL SECURITY;
            ALTER TABLE members FORCE ROW LEVEL SECURITY;
            ALTER TABLE work_items ENABLE ROW LEVEL SECURITY;
            ALTER TABLE work_items FORCE ROW LEVEL SECURITY;
            ALTER TABLE shifts ENABLE ROW LEVEL SECURITY;
            ALTER TABLE shifts FORCE ROW LEVEL SECURITY;
            ALTER TABLE shift_signups ENABLE ROW LEVEL SECURITY;
            ALTER TABLE shift_signups FORCE ROW LEVEL SECURITY;
        ");

        // Each policy is created only if missing (guarded via pg_policies).
        // shift_signups is filtered indirectly, through the tenant of its parent shift.
        await adminConn.ExecuteAsync(@"
            DO $$
            BEGIN
                IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='clubs' AND policyname='tenant_isolation_policy') THEN
                    CREATE POLICY tenant_isolation_policy ON clubs FOR ALL USING ((""TenantId"")::text = current_setting('app.current_tenant_id', true));
                END IF;
                IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='members' AND policyname='tenant_isolation_policy') THEN
                    CREATE POLICY tenant_isolation_policy ON members FOR ALL USING ((""TenantId"")::text = current_setting('app.current_tenant_id', true));
                END IF;
                IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='work_items' AND policyname='tenant_isolation_policy') THEN
                    CREATE POLICY tenant_isolation_policy ON work_items FOR ALL USING ((""TenantId"")::text = current_setting('app.current_tenant_id', true));
                END IF;
                IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='shifts' AND policyname='tenant_isolation_policy') THEN
                    CREATE POLICY tenant_isolation_policy ON shifts FOR ALL USING ((""TenantId"")::text = current_setting('app.current_tenant_id', true));
                END IF;
                IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='shift_signups' AND policyname='tenant_isolation_policy') THEN
                    CREATE POLICY tenant_isolation_policy ON shift_signups FOR ALL USING (""ShiftId"" IN (SELECT ""Id"" FROM shifts WHERE (""TenantId"")::text = current_setting('app.current_tenant_id', true)));
                END IF;
            END $$;
        ");

        var club1Id = Guid.NewGuid();
        var club2Id = Guid.NewGuid();
        await adminConn.ExecuteAsync(@"
            INSERT INTO clubs (""Id"", ""TenantId"", ""Name"", ""SportType"", ""CreatedAt"", ""UpdatedAt"")
            VALUES
                (@Id1, 'tenant-1', 'Club 1', 0, NOW(), NOW()),
                (@Id2, 'tenant-2', 'Club 2', 1, NOW(), NOW())",
            new { Id1 = club1Id, Id2 = club2Id });

        await adminConn.ExecuteAsync(@"
            INSERT INTO work_items (""Id"", ""TenantId"", ""Title"", ""Status"", ""CreatedById"", ""ClubId"", ""CreatedAt"", ""UpdatedAt"")
            SELECT gen_random_uuid(), 'tenant-1', 'Task ' || i, 0, gen_random_uuid(), @ClubId, NOW(), NOW()
            FROM generate_series(1, 5) i",
            new { ClubId = club1Id });

        await adminConn.ExecuteAsync(@"
            INSERT INTO work_items (""Id"", ""TenantId"", ""Title"", ""Status"", ""CreatedById"", ""ClubId"", ""CreatedAt"", ""UpdatedAt"")
            SELECT gen_random_uuid(), 'tenant-2', 'Task ' || i, 0, gen_random_uuid(), @ClubId, NOW(), NOW()
            FROM generate_series(1, 3) i",
            new { ClubId = club2Id });

        // One shift and one signup per tenant.
        // Seeding both tenants lets the subquery-based shift_signups policy be checked for leakage.
        var shift1Id = Guid.NewGuid();
        var shift2Id = Guid.NewGuid();
        await adminConn.ExecuteAsync(@"
            INSERT INTO shifts (""Id"", ""TenantId"", ""Title"", ""StartTime"", ""EndTime"", ""ClubId"", ""CreatedById"", ""CreatedAt"", ""UpdatedAt"")
            VALUES
                (@Id1, 'tenant-1', 'Shift 1', NOW(), NOW() + interval '2 hours', @ClubId1, gen_random_uuid(), NOW(), NOW()),
                (@Id2, 'tenant-2', 'Shift 2', NOW(), NOW() + interval '2 hours', @ClubId2, gen_random_uuid(), NOW(), NOW())",
            new { Id1 = shift1Id, ClubId1 = club1Id, Id2 = shift2Id, ClubId2 = club2Id });

        await adminConn.ExecuteAsync(@"
            INSERT INTO shift_signups (""Id"", ""TenantId"", ""ShiftId"", ""MemberId"", ""SignedUpAt"")
            VALUES
                (gen_random_uuid(), 'tenant-1', @ShiftId1, gen_random_uuid(), NOW()),
                (gen_random_uuid(), 'tenant-2', @ShiftId2, gen_random_uuid(), NOW())",
            new { ShiftId1 = shift1Id, ShiftId2 = shift2Id });

        await txn.CommitAsync();
    }
}