Files

270 lines
11 KiB
C#
Raw Permalink Normal View History

using Dapper;
using Microsoft.EntityFrameworkCore;
using Npgsql;
using Testcontainers.PostgreSql;
using WorkClub.Domain.Entities;
using WorkClub.Infrastructure.Data;
namespace WorkClub.Tests.Integration.Data;
/// <summary>
/// Integration tests for PostgreSQL row-level security (RLS) on the tenant-scoped tables
/// (<c>clubs</c>, <c>members</c>, <c>work_items</c>, <c>shifts</c>, <c>shift_signups</c>).
/// </summary>
/// <remarks>
/// xUnit creates a new instance of this class for every test and calls <see cref="InitializeAsync"/>
/// on it, so each test runs against its own PostgreSQL container. Queries that must be filtered by RLS
/// run as <c>rls_test_user</c>, a role that does not own the tables; migrations and seeding run as the
/// container's bootstrap user (<c>app_user</c>).
/// </remarks>
public class RlsTests : IAsyncLifetime
{
    private PostgreSqlContainer? _container;

    // Connection string for rls_test_user; queries on it are subject to tenant_isolation_policy.
    private string? _connectionString;

    // Connection string for the container's bootstrap user; used for setup, seeding and the bypass test.
    private string? _adminConnectionString;

    /// <summary>
    /// Starts a PostgreSQL 16 container and creates the <c>rls_test_user</c> login role.
    /// Table privileges are granted later, in <see cref="SeedTestDataAsync"/>, after migrations
    /// have created the tables.
    /// </summary>
    public async Task InitializeAsync()
    {
        _container = new PostgreSqlBuilder()
            .WithImage("postgres:16-alpine")
            .WithDatabase("workclub")
            .WithUsername("app_user")
            .WithPassword("apppass")
            .Build();

        await _container.StartAsync();
        _adminConnectionString = _container.GetConnectionString();

        await using var adminConn = new NpgsqlConnection(_adminConnectionString);
        await adminConn.OpenAsync();
        await adminConn.ExecuteAsync(@"
CREATE USER rls_test_user WITH PASSWORD 'rlspass';
GRANT CONNECT ON DATABASE workclub TO rls_test_user;
");

        // Same host/port/database as the admin connection, different credentials.
        _connectionString = new NpgsqlConnectionStringBuilder(_adminConnectionString)
        {
            Username = "rls_test_user",
            Password = "rlspass"
        }.ConnectionString;
    }

    /// <summary>Disposes the container, if it was created.</summary>
    public async Task DisposeAsync()
    {
        if (_container is not null)
        {
            await _container.DisposeAsync();
        }
    }

    /// <summary>With no tenant bound, the policy comparison never matches, so no club is visible.</summary>
    [Fact]
    public async Task RLS_BlocksAccess_WithoutTenantContext()
    {
        await SeedTestDataAsync();

        var clubs = await QueryAsTenantAsync<Club>(null, "SELECT * FROM clubs");

        Assert.Empty(clubs);
    }

    /// <summary>With tenant-1 bound, tenant-1 clubs are visible and nothing else is.</summary>
    [Fact]
    public async Task RLS_AllowsAccess_WithCorrectTenantContext()
    {
        await SeedTestDataAsync();

        // Deliberately no WHERE clause: filtering must come from the policy alone. With an explicit
        // tenant filter, the Assert.All below would pass even if RLS leaked other tenants' rows.
        var clubs = await QueryAsTenantAsync<Club>("tenant-1", "SELECT * FROM clubs");

        Assert.NotEmpty(clubs);
        Assert.All(clubs, c => Assert.Equal("tenant-1", c.TenantId));
    }

    /// <summary>Two tenants querying the same table see disjoint, correctly-tagged rows.</summary>
    [Fact]
    public async Task RLS_IsolatesData_AcrossTenants()
    {
        await SeedTestDataAsync();

        var tenant1Clubs = await QueryAsTenantAsync<Club>("tenant-1", "SELECT * FROM clubs");
        var tenant2Clubs = await QueryAsTenantAsync<Club>("tenant-2", "SELECT * FROM clubs");

        Assert.NotEmpty(tenant1Clubs);
        Assert.NotEmpty(tenant2Clubs);
        Assert.All(tenant1Clubs, c => Assert.Equal("tenant-1", c.TenantId));
        Assert.All(tenant2Clubs, c => Assert.Equal("tenant-2", c.TenantId));

        var tenant1Ids = tenant1Clubs.Select(c => c.Id).ToHashSet();
        var tenant2Ids = tenant2Clubs.Select(c => c.Id).ToHashSet();
        Assert.Empty(tenant1Ids.Intersect(tenant2Ids));
    }

    /// <summary>Aggregates are computed over the tenant's visible rows only (5 vs 3 seeded work items).</summary>
    [Fact]
    public async Task RLS_CountsCorrectly_PerTenant()
    {
        await SeedTestDataAsync();

        // COUNT(*) is bigint in PostgreSQL, so read it as long.
        var tenant1Count = await CountAsTenantAsync("tenant-1", "SELECT COUNT(*) FROM work_items");
        var tenant2Count = await CountAsTenantAsync("tenant-2", "SELECT COUNT(*) FROM work_items");

        Assert.Equal(5L, tenant1Count);
        Assert.Equal(3L, tenant2Count);
    }

    /// <summary>The bootstrap user sees every tenant's rows without any tenant context.</summary>
    [Fact]
    public async Task RLS_AllowsBypass_ForAdminRole()
    {
        await SeedTestDataAsync();

        // app_user is passed to the image as POSTGRES_USER, which the official postgres image creates
        // as a superuser. Superusers bypass RLS even when FORCE ROW LEVEL SECURITY is set.
        await using var connection = new NpgsqlConnection(_adminConnectionString);
        await connection.OpenAsync();
        var allClubs = (await connection.QueryAsync<Club>("SELECT * FROM clubs")).ToList();

        Assert.True(allClubs.Count >= 2);
        Assert.Contains(allClubs, c => c.TenantId == "tenant-1");
        Assert.Contains(allClubs, c => c.TenantId == "tenant-2");
    }

    /// <summary>
    /// shift_signups is filtered through its parent shift (policy subquery on shifts),
    /// and only tenant-1 signups are visible to tenant-1.
    /// </summary>
    [Fact]
    public async Task RLS_HandlesShiftSignups_WithSubquery()
    {
        await SeedTestDataAsync();

        var signups = await QueryAsTenantAsync<ShiftSignup>("tenant-1", "SELECT * FROM shift_signups");

        Assert.NotEmpty(signups);
        Assert.All(signups, s => Assert.Equal("tenant-1", s.TenantId));
    }

    /// <summary>
    /// Opens a connection as <c>rls_test_user</c> and starts a transaction. If <paramref name="tenantId"/>
    /// is non-null, binds it to <c>app.current_tenant_id</c> for that transaction only. Then runs
    /// <paramref name="work"/> and commits.
    /// </summary>
    /// <param name="tenantId">Tenant to bind, or <c>null</c> to leave the setting unset.</param>
    /// <param name="work">Query to run on the connection inside the transaction.</param>
    /// <returns>Whatever <paramref name="work"/> returns.</returns>
    private async Task<TResult> RunAsTenantAsync<TResult>(
        string? tenantId,
        Func<NpgsqlConnection, System.Data.Common.DbTransaction, Task<TResult>> work)
    {
        await using var connection = new NpgsqlConnection(_connectionString);
        await connection.OpenAsync();
        await using var txn = await connection.BeginTransactionAsync();

        if (tenantId is not null)
        {
            // set_config(name, value, is_local => true) is the parameterizable form of SET LOCAL:
            // the value lasts only until this transaction ends, so it cannot leak via the connection pool.
            await connection.ExecuteAsync(
                "SELECT set_config('app.current_tenant_id', @TenantId, true)",
                new { TenantId = tenantId },
                txn);
        }

        var result = await work(connection, txn);
        await txn.CommitAsync();
        return result;
    }

    /// <summary>Runs <paramref name="sql"/> as <paramref name="tenantId"/> and materializes the rows.</summary>
    private Task<List<T>> QueryAsTenantAsync<T>(string? tenantId, string sql) =>
        RunAsTenantAsync<List<T>>(
            tenantId,
            async (conn, txn) => (await conn.QueryAsync<T>(sql, transaction: txn)).ToList());

    /// <summary>Runs a single-value count query (e.g. <c>SELECT COUNT(*)</c>) as <paramref name="tenantId"/>.</summary>
    private Task<long> CountAsTenantAsync(string? tenantId, string sql) =>
        RunAsTenantAsync<long>(
            tenantId,
            (conn, txn) => conn.ExecuteScalarAsync<long>(sql, transaction: txn));

    /// <summary>
    /// Applies EF Core migrations and grants <c>rls_test_user</c> DML and sequence rights. Enables and
    /// forces RLS on the five tenant-scoped tables and creates <c>tenant_isolation_policy</c> on each
    /// (if absent). Then seeds one club per tenant, 5 work items for tenant-1 and 3 for tenant-2, and one
    /// tenant-1 shift with one signup. Everything after the migration runs in a single admin transaction.
    /// </summary>
    private async Task SeedTestDataAsync()
    {
        var options = new DbContextOptionsBuilder<AppDbContext>()
            .UseNpgsql(_adminConnectionString)
            .Options;
        await using var context = new AppDbContext(options);
        await context.Database.MigrateAsync();

        await using var adminConn = new NpgsqlConnection(_adminConnectionString);
        await adminConn.OpenAsync();
        await using var txn = await adminConn.BeginTransactionAsync();

        // Must run after migrations: ON ALL TABLES only covers tables that already exist.
        await adminConn.ExecuteAsync(@"
GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO rls_test_user;
GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO rls_test_user;
", transaction: txn);

        // FORCE also applies the policies to the table owner (superusers still bypass them).
        await adminConn.ExecuteAsync(@"
ALTER TABLE clubs ENABLE ROW LEVEL SECURITY;
ALTER TABLE clubs FORCE ROW LEVEL SECURITY;
ALTER TABLE members ENABLE ROW LEVEL SECURITY;
ALTER TABLE members FORCE ROW LEVEL SECURITY;
ALTER TABLE work_items ENABLE ROW LEVEL SECURITY;
ALTER TABLE work_items FORCE ROW LEVEL SECURITY;
ALTER TABLE shifts ENABLE ROW LEVEL SECURITY;
ALTER TABLE shifts FORCE ROW LEVEL SECURITY;
ALTER TABLE shift_signups ENABLE ROW LEVEL SECURITY;
ALTER TABLE shift_signups FORCE ROW LEVEL SECURITY;
", transaction: txn);

        // current_setting(..., true) yields NULL instead of raising when the setting is missing,
        // so an unset tenant matches no rows. shift_signups has no direct tenant check; it is
        // filtered via its parent shift, and that subquery is itself subject to the shifts policy.
        await adminConn.ExecuteAsync(@"
DO $$ BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='clubs' AND policyname='tenant_isolation_policy') THEN
CREATE POLICY tenant_isolation_policy ON clubs FOR ALL USING ((""TenantId"")::text = current_setting('app.current_tenant_id', true));
END IF;
IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='members' AND policyname='tenant_isolation_policy') THEN
CREATE POLICY tenant_isolation_policy ON members FOR ALL USING ((""TenantId"")::text = current_setting('app.current_tenant_id', true));
END IF;
IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='work_items' AND policyname='tenant_isolation_policy') THEN
CREATE POLICY tenant_isolation_policy ON work_items FOR ALL USING ((""TenantId"")::text = current_setting('app.current_tenant_id', true));
END IF;
IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='shifts' AND policyname='tenant_isolation_policy') THEN
CREATE POLICY tenant_isolation_policy ON shifts FOR ALL USING ((""TenantId"")::text = current_setting('app.current_tenant_id', true));
END IF;
IF NOT EXISTS (SELECT 1 FROM pg_policies WHERE tablename='shift_signups' AND policyname='tenant_isolation_policy') THEN
CREATE POLICY tenant_isolation_policy ON shift_signups FOR ALL USING (""ShiftId"" IN (SELECT ""Id"" FROM shifts WHERE (""TenantId"")::text = current_setting('app.current_tenant_id', true)));
END IF;
END $$;
", transaction: txn);

        var club1Id = Guid.NewGuid();
        var club2Id = Guid.NewGuid();
        await adminConn.ExecuteAsync(@"
INSERT INTO clubs (""Id"", ""TenantId"", ""Name"", ""SportType"", ""CreatedAt"", ""UpdatedAt"")
VALUES (@Id1, 'tenant-1', 'Club 1', 0, NOW(), NOW()),
(@Id2, 'tenant-2', 'Club 2', 1, NOW(), NOW())",
            new { Id1 = club1Id, Id2 = club2Id },
            txn);

        // 5 work items for tenant-1 and 3 for tenant-2; RLS_CountsCorrectly_PerTenant depends on these counts.
        await adminConn.ExecuteAsync(@"
INSERT INTO work_items (""Id"", ""TenantId"", ""Title"", ""Status"", ""CreatedById"", ""ClubId"", ""CreatedAt"", ""UpdatedAt"")
SELECT gen_random_uuid(), 'tenant-1', 'Task ' || i, 0, gen_random_uuid(), @ClubId, NOW(), NOW()
FROM generate_series(1, 5) i",
            new { ClubId = club1Id },
            txn);
        await adminConn.ExecuteAsync(@"
INSERT INTO work_items (""Id"", ""TenantId"", ""Title"", ""Status"", ""CreatedById"", ""ClubId"", ""CreatedAt"", ""UpdatedAt"")
SELECT gen_random_uuid(), 'tenant-2', 'Task ' || i, 0, gen_random_uuid(), @ClubId, NOW(), NOW()
FROM generate_series(1, 3) i",
            new { ClubId = club2Id },
            txn);

        var shift1Id = Guid.NewGuid();
        await adminConn.ExecuteAsync(@"
INSERT INTO shifts (""Id"", ""TenantId"", ""Title"", ""StartTime"", ""EndTime"", ""ClubId"", ""CreatedById"", ""CreatedAt"", ""UpdatedAt"")
VALUES (@Id, 'tenant-1', 'Shift 1', NOW(), NOW() + interval '2 hours', @ClubId, gen_random_uuid(), NOW(), NOW())",
            new { Id = shift1Id, ClubId = club1Id },
            txn);
        await adminConn.ExecuteAsync(@"
INSERT INTO shift_signups (""Id"", ""TenantId"", ""ShiftId"", ""MemberId"", ""SignedUpAt"")
VALUES (gen_random_uuid(), 'tenant-1', @ShiftId, gen_random_uuid(), NOW())",
            new { ShiftId = shift1Id },
            txn);

        await txn.CommitAsync();
    }
}