diff --git a/cmd/routing-api/main_test.go b/cmd/routing-api/main_test.go index 48860980..6584a3fa 100644 --- a/cmd/routing-api/main_test.go +++ b/cmd/routing-api/main_test.go @@ -2,6 +2,8 @@ package main_test import ( "fmt" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" "net/http" "os" "os/exec" @@ -11,7 +13,6 @@ import ( "code.cloudfoundry.org/routing-api/cmd/routing-api/testrunner" "code.cloudfoundry.org/routing-api/db" "code.cloudfoundry.org/routing-api/models" - "github.com/jinzhu/gorm" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" . "github.com/onsi/gomega/gbytes" @@ -19,6 +20,7 @@ import ( "github.com/onsi/gomega/ghttp" "github.com/tedsuo/ifrit" ginkgomon "github.com/tedsuo/ifrit/ginkgomon_v2" + "gorm.io/gorm" ) const ( @@ -162,7 +164,8 @@ var _ = Describe("Main", func() { rapiConfig := getRoutingAPIConfig(defaultConfig) connectionString, err := db.ConnectionString(&rapiConfig.SqlDB) Expect(err).NotTo(HaveOccurred()) - gormDB, err := gorm.Open(rapiConfig.SqlDB.Type, connectionString) + + gormDB, err := gorm.Open(getGormDialect(rapiConfig.SqlDB.Type, connectionString), &gorm.Config{}) Expect(err).NotTo(HaveOccurred()) getRoutes := func() string { @@ -239,13 +242,13 @@ var _ = Describe("Main", func() { } connectionString, err := db.ConnectionString(&rapiConfig.SqlDB) Expect(err).NotTo(HaveOccurred()) - gormDB, err = gorm.Open(rapiConfig.SqlDB.Type, connectionString) + gormDB, err = gorm.Open(getGormDialect(rapiConfig.SqlDB.Type, connectionString), &gorm.Config{}) Expect(err).NotTo(HaveOccurred()) }) - AfterEach(func() { - gormDB.AutoMigrate(&models.RouterGroupDB{}) - Expect(os.Remove(configPath)).To(Succeed()) - }) + /* AfterEach(func() { + gormDB.AutoMigrate(&models.RouterGroupDB{}) + Expect(os.Remove(configPath)).To(Succeed()) + })*/ It("should fail with an error", func() { routingAPIRunner := testrunner.New(routingAPIBinPath, routingAPIArgs) proc := ifrit.Invoke(routingAPIRunner) @@ -265,3 +268,16 @@ func RoutingApi(args ...string) *Session { return session } + +func getGormDialect(databaseType string, connectionString string) gorm.Dialector { + var dialect gorm.Dialector + + switch databaseType { + case "postgres": + dialect = postgres.Open(connectionString) + case "mysql": + dialect = mysql.Open(connectionString) + } + + return dialect +} diff --git a/cmd/routing-api/routing_api_suite_test.go b/cmd/routing-api/routing_api_suite_test.go index 20bf079e..e6d9749e 100644 --- a/cmd/routing-api/routing_api_suite_test.go +++ b/cmd/routing-api/routing_api_suite_test.go @@ -32,6 +32,7 @@ import ( "github.com/onsi/gomega/ghttp" "google.golang.org/grpc/grpclog" yaml "gopkg.in/yaml.v2" + "gorm.io/gorm" ) var ( @@ -63,9 +64,10 @@ var ( mtlsAPIClientCert tls.Certificate ) -func TestMain(t *testing.T) { +func TestMainSuite(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "Main Suite") + suiteConfig, reporterConfig := GinkgoConfiguration() + RunSpecs(t, "Main Suite", suiteConfig, reporterConfig) } var _ = SynchronizedBeforeSuite( diff --git a/cmd/routing-api/testrunner/db.go b/cmd/routing-api/testrunner/db.go index 06803780..2cf0ec2a 100644 --- a/cmd/routing-api/testrunner/db.go +++ b/cmd/routing-api/testrunner/db.go @@ -10,9 +10,9 @@ import ( "code.cloudfoundry.org/routing-api/db" "code.cloudfoundry.org/routing-api/config" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/postgres" . "github.com/onsi/ginkgo/v2" + _ "gorm.io/driver/mysql" + _ "gorm.io/driver/postgres" ) type DbAllocator interface { diff --git a/db/client.go b/db/client.go index 56e1b8e3..05e63f6e 100644 --- a/db/client.go +++ b/db/client.go @@ -3,7 +3,7 @@ package db import ( "database/sql" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) //go:generate counterfeiter -o fakes/fake_client.go . Client @@ -13,7 +13,7 @@ type Client interface { Create(value interface{}) (int64, error) Delete(value interface{}, where ...interface{}) (int64, error) Save(value interface{}) (int64, error) - Update(attrs ...interface{}) (int64, error) + Update(column string, value interface{}) (int64, error) First(out interface{}, where ...interface{}) error Find(out interface{}, where ...interface{}) error AutoMigrate(values ...interface{}) error @@ -21,8 +21,8 @@ type Client interface { Rollback() error Commit() error HasTable(value interface{}) bool - AddUniqueIndex(indexName string, columns ...string) (Client, error) - RemoveIndex(indexName string) (Client, error) + AddUniqueIndex(indexName string, columns interface{}) error + RemoveIndex(indexName string, columns interface{}) error Model(value interface{}) Client Exec(query string, args ...interface{}) int64 Rows(tableName string) (*sql.Rows, error) @@ -37,21 +37,17 @@ func NewGormClient(db *gorm.DB) Client { return &gormClient{db: db} } func (c *gormClient) DropColumn(name string) error { - return c.db.DropColumn(name).Error + return c.DropColumn(name) } func (c *gormClient) Close() error { - return c.db.Close() + return c.Close() } -func (c *gormClient) AddUniqueIndex(indexName string, columns ...string) (Client, error) { - var newClient gormClient - newClient.db = c.db.AddUniqueIndex(indexName, columns...) - return &newClient, newClient.db.Error +func (c *gormClient) AddUniqueIndex(indexName string, columns interface{}) error { + return c.db.Migrator().CreateIndex(columns, indexName) } -func (c *gormClient) RemoveIndex(indexName string) (Client, error) { - var newClient gormClient - newClient.db = c.db.RemoveIndex(indexName) - return &newClient, newClient.db.Error +func (c *gormClient) RemoveIndex(indexName string, columns interface{}) error { + return c.db.Migrator().DropIndex(columns, indexName) } func (c *gormClient) Model(value interface{}) Client { @@ -80,8 +76,8 @@ func (c *gormClient) Save(value interface{}) (int64, error) { return newDb.RowsAffected, newDb.Error } -func (c *gormClient) Update(attrs ...interface{}) (int64, error) { - newDb := c.db.Update(attrs...) +func (c *gormClient) Update(column string, value interface{}) (int64, error) { + newDb := c.db.Update(column, value) return newDb.RowsAffected, newDb.Error } @@ -94,7 +90,7 @@ func (c *gormClient) Find(out interface{}, where ...interface{}) error { } func (c *gormClient) AutoMigrate(values ...interface{}) error { - return c.db.AutoMigrate(values...).Error + return c.db.AutoMigrate(values...) } func (c *gormClient) Begin() Client { @@ -112,7 +108,7 @@ func (c *gormClient) Commit() error { } func (c *gormClient) HasTable(value interface{}) bool { - return c.db.HasTable(value) + return c.db.Migrator().HasTable(value) } func (c *gormClient) Exec(query string, args ...interface{}) int64 { diff --git a/db/db_sql.go b/db/db_sql.go index 9c9ee458..c4d32bad 100644 --- a/db/db_sql.go +++ b/db/db_sql.go @@ -4,6 +4,9 @@ import ( "context" "errors" "fmt" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "io/ioutil" "os" "path/filepath" "runtime" @@ -16,9 +19,9 @@ import ( "code.cloudfoundry.org/routing-api/config" "code.cloudfoundry.org/routing-api/models" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/postgres" + _ "gorm.io/driver/mysql" + _ "gorm.io/driver/postgres" + "gorm.io/gorm" ) //go:generate counterfeiter -o fakes/fake_db.go . DB @@ -104,24 +107,36 @@ func NewSqlDB(cfg *config.SqlDB) (*SqlDB, error) { return nil, errors.New("SQL configuration cannot be nil") } - if cfg.Type != "mysql" && cfg.Type != "postgres" { - return &SqlDB{}, fmt.Errorf("Unknown type %s", cfg.Type) + connStr, err := ConnectionString(cfg) + if err != nil { + return nil, err } - connStr, err := ConnectionString(cfg) + var dialect gorm.Dialector + switch cfg.Type { + case "postgres": + dialect = postgres.Open(connStr) + case "mysql": + dialect = mysql.Open(connStr) + default: + return &SqlDB{}, errors.New(fmt.Sprintf("Unknown type %s", cfg.Type)) + } + + db, err := gorm.Open(dialect, &gorm.Config{}) if err != nil { return nil, err } - db, err := gorm.Open(cfg.Type, connStr) + // Use the connection pool and setup it + sqlDB, err := db.DB() if err != nil { return nil, err } - db.DB().SetMaxIdleConns(cfg.MaxIdleConns) - db.DB().SetMaxOpenConns(cfg.MaxOpenConns) + sqlDB.SetMaxIdleConns(cfg.MaxIdleConns) + sqlDB.SetMaxOpenConns(cfg.MaxOpenConns) connMaxLifetime := time.Duration(cfg.ConnMaxLifetime) * time.Second - db.DB().SetConnMaxLifetime(connMaxLifetime) + sqlDB.SetConnMaxLifetime(connMaxLifetime) tcpEventHub := eventhub.NewNonBlocking(1024) httpEventHub := eventhub.NewNonBlocking(1024) diff --git a/db/db_suite_test.go b/db/db_suite_test.go index 6d7038f6..46ffd53a 100644 --- a/db/db_suite_test.go +++ b/db/db_suite_test.go @@ -5,7 +5,7 @@ import ( "code.cloudfoundry.org/routing-api/cmd/routing-api/testrunner" "code.cloudfoundry.org/routing-api/config" - _ "github.com/lib/pq" + _ "github.com/jackc/pgx/v5/stdlib" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) diff --git a/db/fakes/fake_client.go b/db/fakes/fake_client.go index 1eb71b5d..b3cc0da7 100644 --- a/db/fakes/fake_client.go +++ b/db/fakes/fake_client.go @@ -9,19 +9,17 @@ import ( ) type FakeClient struct { - AddUniqueIndexStub func(string, ...string) (db.Client, error) + AddUniqueIndexStub func(string, interface{}) error addUniqueIndexMutex sync.RWMutex addUniqueIndexArgsForCall []struct { arg1 string - arg2 []string + arg2 interface{} } addUniqueIndexReturns struct { - result1 db.Client - result2 error + result1 error } addUniqueIndexReturnsOnCall map[int]struct { - result1 db.Client - result2 error + result1 error } AutoMigrateStub func(...interface{}) error autoMigrateMutex sync.RWMutex @@ -160,18 +158,17 @@ type FakeClient struct { modelReturnsOnCall map[int]struct { result1 db.Client } - RemoveIndexStub func(string) (db.Client, error) + RemoveIndexStub func(string, interface{}) error removeIndexMutex sync.RWMutex removeIndexArgsForCall []struct { arg1 string + arg2 interface{} } removeIndexReturns struct { - result1 db.Client - result2 error + result1 error } removeIndexReturnsOnCall map[int]struct { - result1 db.Client - result2 error + result1 error } RollbackStub func() error rollbackMutex sync.RWMutex @@ -209,10 +206,11 @@ type FakeClient struct { result1 int64 result2 error } - UpdateStub func(...interface{}) (int64, error) + UpdateStub func(string, interface{}) (int64, error) updateMutex sync.RWMutex updateArgsForCall []struct { - arg1 []interface{} + arg1 string + arg2 interface{} } updateReturns struct { result1 int64 @@ -238,23 +236,23 @@ type FakeClient struct { invocationsMutex sync.RWMutex } -func (fake *FakeClient) AddUniqueIndex(arg1 string, arg2 ...string) (db.Client, error) { +func (fake *FakeClient) AddUniqueIndex(arg1 string, arg2 interface{}) error { fake.addUniqueIndexMutex.Lock() ret, specificReturn := fake.addUniqueIndexReturnsOnCall[len(fake.addUniqueIndexArgsForCall)] fake.addUniqueIndexArgsForCall = append(fake.addUniqueIndexArgsForCall, struct { arg1 string - arg2 []string + arg2 interface{} }{arg1, arg2}) fake.recordInvocation("AddUniqueIndex", []interface{}{arg1, arg2}) fake.addUniqueIndexMutex.Unlock() if fake.AddUniqueIndexStub != nil { - return fake.AddUniqueIndexStub(arg1, arg2...) + return fake.AddUniqueIndexStub(arg1, arg2) } if specificReturn { - return ret.result1, ret.result2 + return ret.result1 } fakeReturns := fake.addUniqueIndexReturns - return fakeReturns.result1, fakeReturns.result2 + return fakeReturns.result1 } func (fake *FakeClient) AddUniqueIndexCallCount() int { @@ -263,43 +261,40 @@ func (fake *FakeClient) AddUniqueIndexCallCount() int { return len(fake.addUniqueIndexArgsForCall) } -func (fake *FakeClient) AddUniqueIndexCalls(stub func(string, ...string) (db.Client, error)) { +func (fake *FakeClient) AddUniqueIndexCalls(stub func(string, interface{}) error) { fake.addUniqueIndexMutex.Lock() defer fake.addUniqueIndexMutex.Unlock() fake.AddUniqueIndexStub = stub } -func (fake *FakeClient) AddUniqueIndexArgsForCall(i int) (string, []string) { +func (fake *FakeClient) AddUniqueIndexArgsForCall(i int) (string, interface{}) { fake.addUniqueIndexMutex.RLock() defer fake.addUniqueIndexMutex.RUnlock() argsForCall := fake.addUniqueIndexArgsForCall[i] return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakeClient) AddUniqueIndexReturns(result1 db.Client, result2 error) { +func (fake *FakeClient) AddUniqueIndexReturns(result1 error) { fake.addUniqueIndexMutex.Lock() defer fake.addUniqueIndexMutex.Unlock() fake.AddUniqueIndexStub = nil fake.addUniqueIndexReturns = struct { - result1 db.Client - result2 error - }{result1, result2} + result1 error + }{result1} } -func (fake *FakeClient) AddUniqueIndexReturnsOnCall(i int, result1 db.Client, result2 error) { +func (fake *FakeClient) AddUniqueIndexReturnsOnCall(i int, result1 error) { fake.addUniqueIndexMutex.Lock() defer fake.addUniqueIndexMutex.Unlock() fake.AddUniqueIndexStub = nil if fake.addUniqueIndexReturnsOnCall == nil { fake.addUniqueIndexReturnsOnCall = make(map[int]struct { - result1 db.Client - result2 error + result1 error }) } fake.addUniqueIndexReturnsOnCall[i] = struct { - result1 db.Client - result2 error - }{result1, result2} + result1 error + }{result1} } func (fake *FakeClient) AutoMigrate(arg1 ...interface{}) error { @@ -1008,22 +1003,23 @@ func (fake *FakeClient) ModelReturnsOnCall(i int, result1 db.Client) { }{result1} } -func (fake *FakeClient) RemoveIndex(arg1 string) (db.Client, error) { +func (fake *FakeClient) RemoveIndex(arg1 string, arg2 interface{}) error { fake.removeIndexMutex.Lock() ret, specificReturn := fake.removeIndexReturnsOnCall[len(fake.removeIndexArgsForCall)] fake.removeIndexArgsForCall = append(fake.removeIndexArgsForCall, struct { arg1 string - }{arg1}) + arg2 interface{} + }{arg1, arg2}) fake.recordInvocation("RemoveIndex", []interface{}{arg1}) fake.removeIndexMutex.Unlock() if fake.RemoveIndexStub != nil { - return fake.RemoveIndexStub(arg1) + return fake.RemoveIndexStub(arg1, arg2) } if specificReturn { - return ret.result1, ret.result2 + return ret.result1 } fakeReturns := fake.removeIndexReturns - return fakeReturns.result1, fakeReturns.result2 + return fakeReturns.result1 } func (fake *FakeClient) RemoveIndexCallCount() int { @@ -1032,7 +1028,7 @@ func (fake *FakeClient) RemoveIndexCallCount() int { return len(fake.removeIndexArgsForCall) } -func (fake *FakeClient) RemoveIndexCalls(stub func(string) (db.Client, error)) { +func (fake *FakeClient) RemoveIndexCalls(stub func(string, interface{}) error) { fake.removeIndexMutex.Lock() defer fake.removeIndexMutex.Unlock() fake.RemoveIndexStub = stub @@ -1045,30 +1041,27 @@ func (fake *FakeClient) RemoveIndexArgsForCall(i int) string { return argsForCall.arg1 } -func (fake *FakeClient) RemoveIndexReturns(result1 db.Client, result2 error) { +func (fake *FakeClient) RemoveIndexReturns(result1 error) { fake.removeIndexMutex.Lock() defer fake.removeIndexMutex.Unlock() fake.RemoveIndexStub = nil fake.removeIndexReturns = struct { - result1 db.Client - result2 error - }{result1, result2} + result1 error + }{result1} } -func (fake *FakeClient) RemoveIndexReturnsOnCall(i int, result1 db.Client, result2 error) { +func (fake *FakeClient) RemoveIndexReturnsOnCall(i int, result1 error) { fake.removeIndexMutex.Lock() defer fake.removeIndexMutex.Unlock() fake.RemoveIndexStub = nil if fake.removeIndexReturnsOnCall == nil { fake.removeIndexReturnsOnCall = make(map[int]struct { - result1 db.Client - result2 error + result1 error }) } fake.removeIndexReturnsOnCall[i] = struct { - result1 db.Client - result2 error - }{result1, result2} + result1 error + }{result1} } func (fake *FakeClient) Rollback() error { @@ -1249,16 +1242,17 @@ func (fake *FakeClient) SaveReturnsOnCall(i int, result1 int64, result2 error) { }{result1, result2} } -func (fake *FakeClient) Update(arg1 ...interface{}) (int64, error) { +func (fake *FakeClient) Update(arg1 string, arg2 interface{}) (int64, error) { fake.updateMutex.Lock() ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)] fake.updateArgsForCall = append(fake.updateArgsForCall, struct { - arg1 []interface{} - }{arg1}) - fake.recordInvocation("Update", []interface{}{arg1}) + arg1 string + arg2 interface{} + }{arg1, arg2}) + fake.recordInvocation("Update", []interface{}{arg1, arg2}) fake.updateMutex.Unlock() if fake.UpdateStub != nil { - return fake.UpdateStub(arg1...) + return fake.UpdateStub(arg1, arg2) } if specificReturn { return ret.result1, ret.result2 @@ -1273,17 +1267,17 @@ func (fake *FakeClient) UpdateCallCount() int { return len(fake.updateArgsForCall) } -func (fake *FakeClient) UpdateCalls(stub func(...interface{}) (int64, error)) { +func (fake *FakeClient) UpdateCalls(stub func(string, interface{}) (int64, error)) { fake.updateMutex.Lock() defer fake.updateMutex.Unlock() fake.UpdateStub = stub } -func (fake *FakeClient) UpdateArgsForCall(i int) []interface{} { +func (fake *FakeClient) UpdateArgsForCall(i int) (string, interface{}) { fake.updateMutex.RLock() defer fake.updateMutex.RUnlock() argsForCall := fake.updateArgsForCall[i] - return argsForCall.arg1 + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakeClient) UpdateReturns(result1 int64, result2 error) { diff --git a/migration/V2_update_rg_migration.go b/migration/V2_update_rg_migration.go index 5bc1757d..99597557 100644 --- a/migration/V2_update_rg_migration.go +++ b/migration/V2_update_rg_migration.go @@ -18,6 +18,9 @@ func (v *V2UpdateRgMigration) Version() int { } func (v *V2UpdateRgMigration) Run(sqlDB *db.SqlDB) error { - _, err := sqlDB.Client.Model(&models.RouterGroup{}).AddUniqueIndex("idx_rg_name", "name") - return err + type routerGroup struct { + models.Model + Name string `gorm:"size:255;index:idx_rg_name,unique" json:"name"` + } + return sqlDB.Client.AddUniqueIndex("idx_rg_name", &routerGroup{}) } diff --git a/migration/V4_add_rg_uniq_idx_tcp_route_migration.go b/migration/V4_add_rg_uniq_idx_tcp_route_migration.go index dace6c00..e916f2f7 100644 --- a/migration/V4_add_rg_uniq_idx_tcp_route_migration.go +++ b/migration/V4_add_rg_uniq_idx_tcp_route_migration.go @@ -18,10 +18,11 @@ func (v *V4AddRgUniqIdxTCPRoute) Version() int { } func (v *V4AddRgUniqIdxTCPRoute) Run(sqlDB *db.SqlDB) error { - _, err := sqlDB.Client.Model(&models.TcpRouteMapping{}).RemoveIndex("idx_tcp_route") + err := sqlDB.Client.RemoveIndex("idx_tcp_route", &models.TcpRouteMapping{}) if err != nil { return err } - _, err = sqlDB.Client.Model(&models.TcpRouteMapping{}).AddUniqueIndex("idx_tcp_route", "router_group_guid", "host_port", "host_ip", "external_port") + + err = sqlDB.Client.AddUniqueIndex("idx_tcp_route", &models.TcpRouteMapping{}) return err } diff --git a/migration/V5_sni_hostname_migration.go b/migration/V5_sni_hostname_migration.go index f7626643..ba2bea1c 100644 --- a/migration/V5_sni_hostname_migration.go +++ b/migration/V5_sni_hostname_migration.go @@ -18,7 +18,7 @@ func (v *V5SniHostnameMigration) Version() int { } func (v *V5SniHostnameMigration) Run(sqlDB *db.SqlDB) error { - _, err := sqlDB.Client.Model(&models.TcpRouteMapping{}).RemoveIndex("idx_tcp_route") + err := sqlDB.Client.RemoveIndex("idx_tcp_route", &models.TcpRouteMapping{}) if err != nil { return err } diff --git a/migration/migration.go b/migration/migration.go index 92674f5a..fb83f3a7 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -5,7 +5,7 @@ import ( "code.cloudfoundry.org/lager/v3" "code.cloudfoundry.org/routing-api/db" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) const MigrationKey = "routing-api-migration"