diff --git a/pkg/postgres/access.go b/pkg/postgres/access.go index e00aaa8d..fd9bea2d 100644 --- a/pkg/postgres/access.go +++ b/pkg/postgres/access.go @@ -3,63 +3,44 @@ package postgres import ( "context" "database/sql" - "strings" ) -var prepareDdlStatements = []string{ - "alter default privileges in schema public grant CHANGEME on tables to cloudsqliamuser;", - "alter default privileges in schema public grant CHANGEME on sequences to cloudsqliamuser;", - "grant CHANGEME on all tables in schema public to cloudsqliamuser;", - "grant CHANGEME on all sequences in schema public to cloudsqliamuser;", -} - -func PrepareAccess(ctx context.Context, appName, namespace, cluster, database string, allPrivs bool) error { - dbInfo, err := NewDBInfo(appName, namespace, cluster, database) - if err != nil { - return err - } - - connectionInfo, err := dbInfo.DBConnection(ctx) - if err != nil { - return err - } - - db, err := sql.Open("cloudsqlpostgres", connectionInfo.ConnectionString()) - if err != nil { - return err - } - defer db.Close() +var grantAllPrivs = `ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO cloudsqliamuser; + ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO cloudsqliamuser; + GRANT ALL ON ALL TABLES IN SCHEMA public TO cloudsqliamuser; + GRANT ALL ON ALL SEQUENCES IN SCHEMA public TO cloudsqliamuser; + GRANT CREATE ON SCHEMA public TO cloudsqliamuser;` - for _, ddl := range prepareDdlStatements { - _, err = db.ExecContext(ctx, setGrant(ddl, allPrivs)) - if err != nil { - return err - } - } +var grantSelectPrivs = `ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO cloudsqliamuser; + ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON SEQUENCES TO cloudsqliamuser; + GRANT SELECT ON ALL TABLES IN SCHEMA public TO cloudsqliamuser; + GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO cloudsqliamuser;` - return nil -} +// this is used for all privileges and select, as it covers both cases +var revokeAllPrivs = `ALTER DEFAULT PRIVILEGES IN SCHEMA public REVOKE ALL ON TABLES FROM cloudsqliamuser; + ALTER DEFAULT PRIVILEGES IN SCHEMA public REVOKE ALL ON SEQUENCES FROM cloudsqliamuser; + REVOKE ALL ON ALL TABLES IN SCHEMA public FROM cloudsqliamuser; + REVOKE ALL ON ALL SEQUENCES IN SCHEMA public FROM cloudsqliamuser; + REVOKE CREATE ON SCHEMA public FROM cloudsqliamuser;` -func setGrant(sql string, allPrivs bool) string { - sqlGrant := "SELECT" +func PrepareAccess(ctx context.Context, appName, namespace, cluster, database string, allPrivs bool) error { if allPrivs { - sqlGrant = "ALL" + return sqlExecAsAppUser(ctx, appName, namespace, cluster, database, grantAllPrivs) + } else { + return sqlExecAsAppUser(ctx, appName, namespace, cluster, database, grantSelectPrivs) } - return strings.Replace(sql, "CHANGEME", sqlGrant, 1) } -var revokeDdlStatements = []string{ - "alter default privileges in schema public revoke ALL on tables from cloudsqliamuser;", - "alter default privileges in schema public revoke ALL on sequences from cloudsqliamuser;", - "revoke ALL on all tables in schema public from cloudsqliamuser;", - "revoke ALL on all sequences in schema public from cloudsqliamuser;", +func RevokeAccess(ctx context.Context, appName, namespace, cluster, database string) error { + return sqlExecAsAppUser(ctx, appName, namespace, cluster, database, revokeAllPrivs) } -func RevokeAccess(ctx context.Context, appName, namespace, cluster, database string) error { +func sqlExecAsAppUser(ctx context.Context, appName, namespace, cluster, database, statement string) error { dbInfo, err := NewDBInfo(appName, namespace, cluster, database) if err != nil { return err } + connectionInfo, err := dbInfo.DBConnection(ctx) if err != nil { return err @@ -71,11 +52,9 @@ func RevokeAccess(ctx context.Context, appName, namespace, cluster, database str } defer db.Close() - for _, ddl := range revokeDdlStatements { - _, err = db.ExecContext(ctx, ddl) - if err != nil { - return formatInvalidGrantError(err) - } + _, err = db.ExecContext(ctx, statement) + if err != nil { + return formatInvalidGrantError(err) } return nil diff --git a/pkg/postgres/access_test.go b/pkg/postgres/access_test.go deleted file mode 100644 index bdd15aff..00000000 --- a/pkg/postgres/access_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package postgres - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestPermissionsAll(t *testing.T) { - expected := "blabla whatever ALL stuff" - actual := setGrant("blabla whatever CHANGEME stuff", true) - assert.Equal(t, expected, actual) -} - -func TestPermissionsOther(t *testing.T) { - expected := "blabla whatever SELECT stuff" - actual := setGrant("blabla whatever CHANGEME stuff", false) - assert.Equal(t, expected, actual) -}