diff --git a/internal/conn/pool.go b/internal/conn/pool.go index 18974bc5e..f839d30f0 100644 --- a/internal/conn/pool.go +++ b/internal/conn/pool.go @@ -21,11 +21,12 @@ import ( ) type Pool struct { - usages int64 - config Config - dialOptions []grpc.DialOption - conns xsync.Map[string, *conn] - done chan struct{} + usages int64 + config Config + dialOptions []grpc.DialOption + conns xsync.Map[string, *conn] + done chan struct{} + parkerStoped chan struct{} } func (p *Pool) DialTimeout() time.Duration { @@ -160,6 +161,7 @@ func (p *Pool) Release(ctx context.Context) (finalErr error) { ) wg.Add(cap(errCh)) + p.conns.Range(func(_ string, c *conn) bool { go func(c closer.Closer) { defer wg.Done() @@ -170,6 +172,9 @@ func (p *Pool) Release(ctx context.Context) (finalErr error) { return true }) + + <-p.parkerStoped + wg.Wait() close(errCh) @@ -186,8 +191,11 @@ func (p *Pool) Release(ctx context.Context) (finalErr error) { } func (p *Pool) connParker(ctx context.Context, ttl, interval time.Duration) { + defer close(p.parkerStoped) + ticker := time.NewTicker(interval) defer ticker.Stop() + for { select { case <-p.done: @@ -216,10 +224,11 @@ func NewPool(ctx context.Context, config Config) *Pool { defer onDone() p := &Pool{ - usages: 1, - config: config, - dialOptions: config.GrpcDialOptions(), - done: make(chan struct{}), + usages: 1, + config: config, + dialOptions: config.GrpcDialOptions(), + done: make(chan struct{}), + parkerStoped: make(chan struct{}), } p.dialOptions = append(p.dialOptions, @@ -248,6 +257,8 @@ func NewPool(ctx context.Context, config Config) *Pool { if ttl := config.ConnectionTTL(); ttl > 0 { go p.connParker(xcontext.ValueOnly(ctx), ttl, ttl/2) //nolint:gomnd + } else { + close(p.parkerStoped) } return p diff --git a/internal/query/execute_query.go b/internal/query/execute_query.go index ceaa74245..a5c245d5d 100644 --- a/internal/query/execute_query.go +++ b/internal/query/execute_query.go @@ -2,6 +2,7 @@ package query import ( "context" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xruntime" "io" "time" @@ -146,6 +147,10 @@ func execute( return nil, xerrors.WithStackTrace(err) } + xruntime.AddCleanup(r, func(cancelStream context.CancelFunc) { + cancelStream() + }, executeCancel) + return r, nil } diff --git a/internal/query/result.go b/internal/query/result.go index cb2a194b4..1d6246a70 100644 --- a/internal/query/result.go +++ b/internal/query/result.go @@ -228,20 +228,20 @@ func nextPart(stream Ydb_Query_V1.QueryService_ExecuteQueryClient) ( func (r *streamResult) Close(ctx context.Context) (finalErr error) { defer r.closeOnce() - if r.trace != nil { - onDone := trace.QueryOnResultClose(r.trace, &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).Close"), - ) - defer func() { - onDone(finalErr) - }() - } - for { select { case <-r.closed: return nil default: + if r.trace != nil { + onDone := trace.QueryOnResultClose(r.trace, &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).Close"), + ) + defer func() { + onDone(finalErr) + }() + } + _, err := r.nextPart(ctx) if err != nil { if xerrors.Is(err, io.EOF) { diff --git a/internal/query/session.go b/internal/query/session.go index 2bc02d533..b58ba0669 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -2,7 +2,6 @@ package query import ( "context" - "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" diff --git a/internal/query/session_core.go b/internal/query/session_core.go index 04ec4a9cb..9c375e118 100644 --- a/internal/query/session_core.go +++ b/internal/query/session_core.go @@ -3,8 +3,6 @@ package query import ( "context" "os" - "runtime/pprof" - "strconv" "sync/atomic" "time" @@ -38,10 +36,11 @@ type ( SetStatus(code Status) } sessionCore struct { - cc grpc.ClientConnInterface - Client Ydb_Query_V1.QueryServiceClient - Trace *trace.Query - done chan struct{} + cc grpc.ClientConnInterface + Client Ydb_Query_V1.QueryServiceClient + Trace *trace.Query + done chan struct{} + attachStreamExited chan struct{} deleteTimeout time.Duration id string @@ -120,9 +119,10 @@ func Open( ctx context.Context, client Ydb_Query_V1.QueryServiceClient, opts ...Option, ) (_ *sessionCore, finalErr error) { core := &sessionCore{ - Client: client, - Trace: &trace.Query{}, - done: make(chan struct{}), + Client: client, + Trace: &trace.Query{}, + done: make(chan struct{}), + attachStreamExited: make(chan struct{}), } for _, opt := range opts { @@ -200,6 +200,8 @@ func (core *sessionCore) attach(ctx context.Context) (finalErr error) { core.closeOnce = xsync.OnceFunc(func(ctx context.Context) error { defer cancelAttach() + <-core.attachStreamExited + core.SetStatus(StatusClosing) defer core.SetStatus(StatusClosed) @@ -210,31 +212,18 @@ func (core *sessionCore) attach(ctx context.Context) (finalErr error) { return nil }) - if markGoroutineWithLabelNodeIDForAttachStream { - pprof.Do(ctx, pprof.Labels( - "node_id", strconv.Itoa(int(core.NodeID())), - ), func(context.Context) { - go core.listenAttachStream(attachStream) - }) - } else { - go core.listenAttachStream(attachStream) - } + go core.listenAttachStream(attachStream) return nil } func (core *sessionCore) listenAttachStream(attachStream Ydb_Query_V1.QueryService_AttachSessionClient) { defer func() { - select { - case <-core.done: - return - default: - close(core.done) - } + close(core.attachStreamExited) }() for core.IsAlive() { - if _, recvErr := attachStream.Recv(); recvErr != nil { + if s, recvErr := attachStream.Recv(); recvErr != nil || s.GetStatus() != Ydb.StatusIds_SUCCESS { return } } diff --git a/internal/repeater/repeater.go b/internal/repeater/repeater.go index 01f9a6ec8..2861cf719 100644 --- a/internal/repeater/repeater.go +++ b/internal/repeater/repeater.go @@ -8,7 +8,6 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/backoff" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) @@ -29,8 +28,8 @@ type repeater struct { // Task is a function that must be executed periodically. task func(context.Context) error - cancel context.CancelFunc - stopped chan struct{} + done chan struct{} + workerStopped chan struct{} force chan struct{} clock clockwork.Clock @@ -96,16 +95,14 @@ func New( task func(ctx context.Context) (err error), opts ...option, ) *repeater { - ctx, cancel := xcontext.WithCancel(ctx) - r := &repeater{ - interval: interval, - task: task, - cancel: cancel, - stopped: make(chan struct{}), - force: make(chan struct{}, 1), - clock: clockwork.NewRealClock(), - trace: &trace.Driver{}, + interval: interval, + task: task, + done: make(chan struct{}), + workerStopped: make(chan struct{}), + force: make(chan struct{}, 1), + clock: clockwork.NewRealClock(), + trace: &trace.Driver{}, } for _, opt := range opts { @@ -114,17 +111,17 @@ func New( } } - go r.worker(ctx, r.clock.NewTicker(interval)) + go r.worker(r.clock.NewTicker(interval)) return r } func (r *repeater) stop(onCancel func()) { - r.cancel() + close(r.done) if onCancel != nil { onCancel() } - <-r.stopped + <-r.workerStopped } // Stop stops to execute its task. @@ -162,11 +159,9 @@ func (r *repeater) wakeUp(e Event) (err error) { return r.task(ctx) } -func (r *repeater) worker(ctx context.Context, tick clockwork.Ticker) { - defer func() { - close(r.stopped) - tick.Stop() - }() +func (r *repeater) worker(tick clockwork.Ticker) { + defer close(r.workerStopped) + defer tick.Stop() // force returns backoff with delays [500ms...32s] force := backoff.New( @@ -187,7 +182,7 @@ func (r *repeater) worker(ctx context.Context, tick clockwork.Ticker) { defer force.Stop() select { - case <-ctx.Done(): + case <-r.done: return EventCancel case <-tick.Chan(): return EventTick @@ -210,7 +205,7 @@ func (r *repeater) worker(ctx context.Context, tick clockwork.Ticker) { for { select { - case <-ctx.Done(): + case <-r.done: return case <-tick.Chan(): diff --git a/internal/xtest/context.go b/internal/xtest/context.go index 318cac0ba..ab03cc80c 100644 --- a/internal/xtest/context.go +++ b/internal/xtest/context.go @@ -2,10 +2,9 @@ package xtest import ( "context" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "runtime/pprof" "testing" - - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" ) func Context(t testing.TB) context.Context { diff --git a/internal/xtest/logger.go b/internal/xtest/logger.go index a712cfb75..4dbedd0a0 100644 --- a/internal/xtest/logger.go +++ b/internal/xtest/logger.go @@ -1,11 +1,15 @@ package xtest import ( + "regexp" + "runtime" + "strings" "sync" "testing" + "time" ) -func MakeSyncedTest(t *testing.T) *SyncedTest { +func MakeSyncedTest(t *testing.T) (st *SyncedTest) { return &SyncedTest{ T: t, } @@ -16,6 +20,64 @@ type SyncedTest struct { *testing.T } +func (s *SyncedTest) checkGoroutinesLeak() { + var ( + bb = make([]byte, 2<<32) + currentGoroutine string + ) + + time.Sleep(time.Millisecond) + + if n := runtime.Stack(bb, false); n < len(bb) { + currentGoroutine = string(regexp.MustCompile("^goroutine \\d+ ").Find(bb[:n])) + } + + if n := runtime.Stack(bb, true); n < len(bb) { + bb = bb[:n] + } + + goroutines := strings.Split(string(bb), "\n\n") + unexpectedGoroutines := make([]string, 0, len(goroutines)) + + for _, g := range goroutines { + if strings.HasPrefix(g, currentGoroutine) { + continue + } + stack := strings.Split(g, "\n") + firstFunction := stack[1] + state := strings.Trim(regexp.MustCompile("\\[.*\\]").FindString(regexp.MustCompile("^goroutine \\d+ \\[.*\\]").FindString(stack[0])), "[]") + switch { + case strings.HasPrefix(firstFunction, "testing.RunTests"), + strings.HasPrefix(firstFunction, "testing.(*T).Run"), + strings.HasPrefix(firstFunction, "testing.(*T).Parallel"), + strings.HasPrefix(firstFunction, "testing.runFuzzing"), + strings.HasPrefix(firstFunction, "testing.runFuzzTests"): + if strings.Contains(state, "chan receive") { + continue + } + + case strings.HasPrefix(firstFunction, "runtime.goexit") && state == "syscall": + continue + + case strings.HasPrefix(firstFunction, "os/signal.signal_recv"), + strings.HasPrefix(firstFunction, "os/signal.loop"): + if strings.Contains(g, "runtime.ensureSigM") { + continue + } + + //case strings.HasPrefix(firstFunction, "syscall.syscall"): + // if strings.Contains(state, "syscall") { + // continue + // } + } + + unexpectedGoroutines = append(unexpectedGoroutines, g) + } + if l := len(unexpectedGoroutines); l > 0 { + s.T.Errorf("found %d unexpected goroutines:\n%s", l, strings.Join(unexpectedGoroutines, "\n")) + } +} + func (s *SyncedTest) Cleanup(f func()) { s.m.Lock() defer s.m.Unlock() @@ -80,13 +142,6 @@ func (s *SyncedTest) Fatalf(format string, args ...interface{}) { s.T.Fatalf(format, args...) } -// must direct called -// func (s *SyncedTest) Helper() { -// s.m.Lock() -// defer s.m.Unlock() -// s.T.Helper() -//} - func (s *SyncedTest) Log(args ...interface{}) { s.m.Lock() defer s.m.Unlock() @@ -125,6 +180,8 @@ func (s *SyncedTest) RunSynced(name string, f func(t *SyncedTest)) bool { s.T.Helper() return s.T.Run(name, func(t *testing.T) { + defer s.checkGoroutinesLeak() + syncedTest := MakeSyncedTest(t) f(syncedTest) }) diff --git a/tests/integration/basic_example_database_sql_bindings_test.go b/tests/integration/basic_example_database_sql_bindings_test.go index 9c191263f..c13282808 100644 --- a/tests/integration/basic_example_database_sql_bindings_test.go +++ b/tests/integration/basic_example_database_sql_bindings_test.go @@ -26,8 +26,6 @@ import ( ) func TestBasicExampleDatabaseSqlBindings(t *testing.T) { - defer simpleDetectGoroutineLeak(t) - folder := t.Name() ctx, cancel := context.WithTimeout(xtest.Context(t), 42*time.Second) @@ -63,10 +61,7 @@ func TestBasicExampleDatabaseSqlBindings(t *testing.T) { ) require.NoError(t, err) - defer func() { - // cleanup - _ = nativeDriver.Close(ctx) - }() + defer nativeDriver.Close(ctx) c, err := ydb.Connector(nativeDriver, ydb.WithTablePathPrefix(path.Join(nativeDriver.Name(), folder)), @@ -74,17 +69,10 @@ func TestBasicExampleDatabaseSqlBindings(t *testing.T) { ydb.WithPositionalArgs(), ) require.NoError(t, err) - - defer func() { - // cleanup - _ = c.Close() - }() + defer c.Close() db := sql.OpenDB(c) - defer func() { - // cleanup - _ = db.Close() - }() + defer db.Close() err = db.PingContext(ctx) require.NoError(t, err) diff --git a/tests/integration/basic_example_database_sql_test.go b/tests/integration/basic_example_database_sql_test.go index ed1b91b0a..4d61135ba 100644 --- a/tests/integration/basic_example_database_sql_test.go +++ b/tests/integration/basic_example_database_sql_test.go @@ -27,8 +27,6 @@ import ( ) func TestBasicExampleDatabaseSql(t *testing.T) { - defer simpleDetectGoroutineLeak(t) - folder := t.Name() ctx, cancel := context.WithTimeout(xtest.Context(t), 42*time.Second) @@ -90,24 +88,16 @@ func TestBasicExampleDatabaseSql(t *testing.T) { ydb.WithDiscoveryInterval(time.Second), ) require.NoError(t, err) - - defer func() { - require.NoError(t, nativeDriver.Close(ctx)) - }() + defer nativeDriver.Close(ctx) c, err := ydb.Connector(nativeDriver, ydb.WithQueryService(tt.useQueryService), ) require.NoError(t, err) - - defer func() { - require.NoError(t, c.Close()) - }() + defer c.Close() db := sql.OpenDB(c) - defer func() { - require.NoError(t, db.Close()) - }() + defer db.Close() require.NoError(t, db.PingContext(ctx)) diff --git a/tests/integration/basic_example_native_test.go b/tests/integration/basic_example_native_test.go index be4270400..2fb3ecf96 100644 --- a/tests/integration/basic_example_native_test.go +++ b/tests/integration/basic_example_native_test.go @@ -103,11 +103,7 @@ func TestBasicExampleNative(sourceTest *testing.T) { //nolint:gocyclo if err != nil { t.Fatal(err) } - - defer func() { - // cleanup - _ = db.Close(ctx) - }() + defer db.Close(ctx) if err = db.Table().Do(ctx, func(ctx context.Context, _ table.Session) error { // hack for wait pool initializing diff --git a/tests/integration/connection_with_compression_test.go b/tests/integration/connection_with_compression_test.go index 5a677ab16..ec97fa64a 100644 --- a/tests/integration/connection_with_compression_test.go +++ b/tests/integration/connection_with_compression_test.go @@ -114,12 +114,8 @@ func TestConnectionWithCompression(sourceTest *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("close failed: %+v", e) - } - }() + defer db.Close(ctx) + t.Run("discovery.WhoAmI", func(t *testing.T) { if err = retry.Retry(ctx, func(ctx context.Context) (err error) { discoveryClient := Ydb_Discovery_V1.NewDiscoveryServiceClient(ydb.GRPCConn(db)) diff --git a/tests/integration/database_sql_containers_test.go b/tests/integration/database_sql_containers_test.go index 5e8ae71f8..d8198da31 100644 --- a/tests/integration/database_sql_containers_test.go +++ b/tests/integration/database_sql_containers_test.go @@ -28,9 +28,7 @@ func TestDatabaseSqlContainers(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) - defer func() { - _ = nativeDriver.Close(ctx) - }() + defer nativeDriver.Close(ctx) connector, err := ydb.Connector(nativeDriver) require.NoError(t, err) @@ -39,6 +37,7 @@ func TestDatabaseSqlContainers(t *testing.T) { }() db := sql.OpenDB(connector) + defer db.Close() err = retry.Do(ctx, db, func(ctx context.Context, cc *sql.Conn) error { rows, err := cc.QueryContext(ctx, ` diff --git a/tests/integration/database_sql_ddl_in_transaction_test.go b/tests/integration/database_sql_ddl_in_transaction_test.go index 7c2057e12..906a2ccf7 100644 --- a/tests/integration/database_sql_ddl_in_transaction_test.go +++ b/tests/integration/database_sql_ddl_in_transaction_test.go @@ -20,10 +20,6 @@ func TestDatabaseSqlDDLInTransaction(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - f := func(ctx context.Context, tx *sql.Tx) (err error) { _, err = tx.ExecContext( ydb.WithQueryMode(ctx, ydb.SchemeQueryMode), @@ -80,6 +76,7 @@ func TestDatabaseSqlDDLInTransaction(t *testing.T) { require.NoError(t, err) db := sql.OpenDB(connector) + defer db.Close() err = db.PingContext(scope.Ctx) require.NoError(t, err) diff --git a/tests/integration/database_sql_get_column_type_test.go b/tests/integration/database_sql_get_column_type_test.go index 009ef41fe..c5090bda7 100644 --- a/tests/integration/database_sql_get_column_type_test.go +++ b/tests/integration/database_sql_get_column_type_test.go @@ -24,10 +24,6 @@ func TestDatabaseSqlGetColumnType(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - t.Run("create-tables", func(t *testing.T) { err := retry.Do(scope.Ctx, db, func(ctx context.Context, cc *sql.Conn) (err error) { _, err = cc.ExecContext( @@ -130,10 +126,6 @@ func TestDatabaseSqlColumnTypes(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - columns := []struct { YQL string Nullable bool diff --git a/tests/integration/database_sql_get_columns_test.go b/tests/integration/database_sql_get_columns_test.go index cdf85a2a2..ba386c894 100644 --- a/tests/integration/database_sql_get_columns_test.go +++ b/tests/integration/database_sql_get_columns_test.go @@ -21,10 +21,6 @@ func TestDatabaseSqlGetColumns(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - t.Run("create-tables", func(t *testing.T) { err := retry.Do(scope.Ctx, db, func(ctx context.Context, cc *sql.Conn) (err error) { _, err = cc.ExecContext( diff --git a/tests/integration/database_sql_get_index_columns_test.go b/tests/integration/database_sql_get_index_columns_test.go index 4d376fb7e..263b16822 100644 --- a/tests/integration/database_sql_get_index_columns_test.go +++ b/tests/integration/database_sql_get_index_columns_test.go @@ -21,10 +21,6 @@ func TestDatabaseSqlGetIndexColumns(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - t.Run("create-tables", func(t *testing.T) { err := retry.Do(scope.Ctx, db, func(ctx context.Context, cc *sql.Conn) (err error) { _, err = cc.ExecContext( diff --git a/tests/integration/database_sql_get_indexes_test.go b/tests/integration/database_sql_get_indexes_test.go index 343a3f811..6383d4c3f 100644 --- a/tests/integration/database_sql_get_indexes_test.go +++ b/tests/integration/database_sql_get_indexes_test.go @@ -21,10 +21,6 @@ func TestDatabaseSqlGetIndexes(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - t.Run("create-tables", func(t *testing.T) { err := retry.Do(scope.Ctx, db, func(ctx context.Context, cc *sql.Conn) (err error) { _, err = cc.ExecContext( diff --git a/tests/integration/database_sql_get_primary_keys_test.go b/tests/integration/database_sql_get_primary_keys_test.go index 6e3fc33e6..0e0f1b05f 100644 --- a/tests/integration/database_sql_get_primary_keys_test.go +++ b/tests/integration/database_sql_get_primary_keys_test.go @@ -21,10 +21,6 @@ func TestDatabaseSqlGetPrimaryKeys(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - t.Run("create-tables", func(t *testing.T) { err := retry.Do(scope.Ctx, db, func(ctx context.Context, cc *sql.Conn) (err error) { _, err = cc.ExecContext( diff --git a/tests/integration/database_sql_get_tables_test.go b/tests/integration/database_sql_get_tables_test.go index 219a1acee..6069af9fb 100644 --- a/tests/integration/database_sql_get_tables_test.go +++ b/tests/integration/database_sql_get_tables_test.go @@ -25,10 +25,6 @@ func TestDatabaseSqlGetTables(t *testing.T) { folder = t.Name() ) - defer func() { - _ = db.Close() - }() - t.Run("prepare-sub-folder", func(t *testing.T) { cc, err := ydb.Unwrap(db) require.NoError(t, err) @@ -193,10 +189,6 @@ func TestDatabaseSqlGetTablesRecursive(t *testing.T) { folder = t.Name() ) - defer func() { - _ = db.Close() - }() - t.Run("prepare-sub-folder", func(t *testing.T) { cc, err := ydb.Unwrap(db) require.NoError(t, err) diff --git a/tests/integration/database_sql_is_column_exists_test.go b/tests/integration/database_sql_is_column_exists_test.go index 74bab9bb4..6389ff33f 100644 --- a/tests/integration/database_sql_is_column_exists_test.go +++ b/tests/integration/database_sql_is_column_exists_test.go @@ -21,10 +21,6 @@ func TestDatabaseSqlIsColumnExists(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - t.Run("create-tables", func(t *testing.T) { err := retry.Do(scope.Ctx, db, func(ctx context.Context, cc *sql.Conn) (err error) { _, err = cc.ExecContext( diff --git a/tests/integration/database_sql_is_primary_key_test.go b/tests/integration/database_sql_is_primary_key_test.go index 4e2228434..98afe1b88 100644 --- a/tests/integration/database_sql_is_primary_key_test.go +++ b/tests/integration/database_sql_is_primary_key_test.go @@ -21,10 +21,6 @@ func TestDatabaseSqlIsPrimaryKey(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - t.Run("create-tables", func(t *testing.T) { err := retry.Do(scope.Ctx, db, func(ctx context.Context, cc *sql.Conn) (err error) { _, err = cc.ExecContext( diff --git a/tests/integration/database_sql_is_table_exists_test.go b/tests/integration/database_sql_is_table_exists_test.go index 6676e01df..a68a8d328 100644 --- a/tests/integration/database_sql_is_table_exists_test.go +++ b/tests/integration/database_sql_is_table_exists_test.go @@ -21,10 +21,6 @@ func TestDatabaseSqlIsTableExists(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - t.Run("drop-if-exists", func(t *testing.T) { err := retry.Do(scope.Ctx, db, func(ctx context.Context, cc *sql.Conn) (err error) { exists := true diff --git a/tests/integration/database_sql_scanner_test.go b/tests/integration/database_sql_scanner_test.go index 7cb1adedd..15797be4e 100644 --- a/tests/integration/database_sql_scanner_test.go +++ b/tests/integration/database_sql_scanner_test.go @@ -37,9 +37,7 @@ func TestDatabaseSqlScanner(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) - defer func() { - _ = nativeDriver.Close(ctx) - }() + defer nativeDriver.Close(ctx) connector, err := ydb.Connector(nativeDriver, ydb.WithQueryService(false), @@ -51,6 +49,7 @@ func TestDatabaseSqlScanner(t *testing.T) { }() db1 = sql.OpenDB(connector) + defer db1.Close() } { nativeDriver, err := ydb.Open(ctx, @@ -58,9 +57,7 @@ func TestDatabaseSqlScanner(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) - defer func() { - _ = nativeDriver.Close(ctx) - }() + defer nativeDriver.Close(ctx) connector, err := ydb.Connector(nativeDriver, ydb.WithQueryService(true), @@ -72,6 +69,7 @@ func TestDatabaseSqlScanner(t *testing.T) { }() db2 = sql.OpenDB(connector) + defer db2.Close() } for _, ttt := range []struct { name string diff --git a/tests/integration/database_sql_scheme_test.go b/tests/integration/database_sql_scheme_test.go index cb16114d8..45bb27025 100644 --- a/tests/integration/database_sql_scheme_test.go +++ b/tests/integration/database_sql_scheme_test.go @@ -19,10 +19,6 @@ func TestDatabaseSqlScheme(t *testing.T) { db = scope.SQLDriverWithFolder() ) - defer func() { - _ = db.Close() - }() - t.Run("drop-tables", func(t *testing.T) { cc, err := db.Conn(scope.Ctx) require.NoError(t, err) diff --git a/tests/integration/database_sql_static_credentials_test.go b/tests/integration/database_sql_static_credentials_test.go index 2badc214c..eef49f0b3 100644 --- a/tests/integration/database_sql_static_credentials_test.go +++ b/tests/integration/database_sql_static_credentials_test.go @@ -71,10 +71,7 @@ func TestDatabaseSqlStaticCredentials(t *testing.T) { ) require.NoError(t, err) - defer func() { - // cleanup - _ = cc.Close(ctx) - }() + defer cc.Close(ctx) c, err := ydb.Connector(cc) require.NoError(t, err) @@ -85,10 +82,7 @@ func TestDatabaseSqlStaticCredentials(t *testing.T) { }() db := sql.OpenDB(c) - defer func() { - // cleanup - _ = db.Close() - }() + defer db.Close() err = db.PingContext(ctx) require.NoError(t, err) diff --git a/tests/integration/database_sql_with_tx_control_test.go b/tests/integration/database_sql_with_tx_control_test.go index 42157808d..09c1a79ff 100644 --- a/tests/integration/database_sql_with_tx_control_test.go +++ b/tests/integration/database_sql_with_tx_control_test.go @@ -44,7 +44,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { table.SerializableReadWriteTxControl(), ), db, func(ctx context.Context, cc *sql.Conn) error { - _, err := db.QueryContext(ctx, "SELECT 1") + rows, err := db.QueryContext(ctx, "SELECT 1") + defer rows.Close() + return err }, )) @@ -62,7 +64,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { table.SerializableReadWriteTxControl(), ), db, func(ctx context.Context, cc *sql.Conn) error { - _, err := db.QueryContext(ctx, "SELECT 1") + rows, err := db.QueryContext(ctx, "SELECT 1") + defer rows.Close() + return err }, )) @@ -80,7 +84,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { table.SnapshotReadOnlyTxControl(), ), db, func(ctx context.Context, cc *sql.Conn) error { - _, err := db.QueryContext(ctx, "SELECT 1") + rows, err := db.QueryContext(ctx, "SELECT 1") + defer rows.Close() + return err }, )) @@ -98,7 +104,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { table.StaleReadOnlyTxControl(), ), db, func(ctx context.Context, cc *sql.Conn) error { - _, err := db.QueryContext(ctx, "SELECT 1") + rows, err := db.QueryContext(ctx, "SELECT 1") + defer rows.Close() + return err }, )) @@ -116,7 +124,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { table.OnlineReadOnlyTxControl(), ), db, func(ctx context.Context, cc *sql.Conn) error { - _, err := db.QueryContext(ctx, "SELECT 1") + rows, err := db.QueryContext(ctx, "SELECT 1") + defer rows.Close() + return err }, )) @@ -134,7 +144,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { table.OnlineReadOnlyTxControl(table.WithInconsistentReads()), ), db, func(ctx context.Context, cc *sql.Conn) error { - _, err := db.QueryContext(ctx, "SELECT 1") + rows, err := db.QueryContext(ctx, "SELECT 1") + defer rows.Close() + return err }, )) diff --git a/tests/integration/discovery_test.go b/tests/integration/discovery_test.go index 3a315a7a0..41eb9ddf7 100644 --- a/tests/integration/discovery_test.go +++ b/tests/integration/discovery_test.go @@ -106,12 +106,8 @@ func TestDiscovery(sourceTest *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("db close failed: %+v", e) - } - }() + defer db.Close(ctx) + t.Run("discovery.Discover", func(t *testing.T) { if endpoints, err := db.Discovery().Discover(ctx); err != nil { t.Fatal(err) diff --git a/tests/integration/driver_secure_test.go b/tests/integration/driver_secure_test.go index 9eec250bb..0030a6289 100644 --- a/tests/integration/driver_secure_test.go +++ b/tests/integration/driver_secure_test.go @@ -58,12 +58,8 @@ func TestConnectionSecure(sourceTest *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("close failed: %+v", e) - } - }() + defer db.Close(ctx) + t.Run("discovery.WhoAmI", func(t *testing.T) { if err = retry.Retry(ctx, func(ctx context.Context) (err error) { discoveryClient := Ydb_Discovery_V1.NewDiscoveryServiceClient(ydb.GRPCConn(db)) diff --git a/tests/integration/driver_test.go b/tests/integration/driver_test.go index 590e9abb4..49de09661 100644 --- a/tests/integration/driver_test.go +++ b/tests/integration/driver_test.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path" + "strings" "testing" "time" @@ -41,11 +42,34 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) -//nolint:gocyclo -func TestDriver(sourceTest *testing.T) { +func TestNew(sourceTest *testing.T) { t := xtest.MakeSyncedTest(sourceTest) - const sumColumn = "sum" + + ctx := xtest.Context(t) + + db1, err := ydb.New(ctx, //nolint:gocritic + ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ydb.With( + config.WithOperationTimeout(time.Second*2), + config.WithOperationCancelAfter(time.Second*2), + ), + ydb.WithConnectionTTL(time.Millisecond*10000), + ydb.WithMinTLSVersion(tls.VersionTLS10), + ydb.WithLogger( + newLogger(t), + trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`), + ), + ) + require.NoError(t, err) + require.NoError(t, db1.Close(ctx)) +} + +func TestOpen(sourceTest *testing.T) { var ( + t = xtest.MakeSyncedTest(sourceTest) userAgent = "connection user agent" requestType = "connection request type" traceParentID = "test-traceparent-id" @@ -89,373 +113,387 @@ func TestDriver(sourceTest *testing.T) { ctx = meta.WithTraceParent(xtest.Context(t), traceParentID) ) - t.RunSynced("ydb.New", func(t *xtest.SyncedTest) { - db, err := ydb.New(ctx, //nolint:gocritic - ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")), - ydb.WithAccessTokenCredentials( - os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), - ), - ydb.With( - config.WithOperationTimeout(time.Second*2), - config.WithOperationCancelAfter(time.Second*2), - ), - ydb.WithConnectionTTL(time.Millisecond*10000), - ydb.WithMinTLSVersion(tls.VersionTLS10), - ydb.WithLogger( - newLogger(t), - trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`), + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ydb.With( + config.WithOperationTimeout(time.Second*2), + config.WithOperationCancelAfter(time.Second*2), + ), + ydb.WithConnectionTTL(time.Millisecond*10000), + ydb.WithMinTLSVersion(tls.VersionTLS10), + ydb.WithLogger( + newLoggerWithMinLevel(t, log.WARN), + trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`), + ), + ydb.WithApplicationName(userAgent), + ydb.WithRequestsType(requestType), + ydb.With( + config.WithGrpcOptions( + grpc.WithUnaryInterceptor(func( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + checkMetadata(ctx) + return invoker(ctx, method, req, reply, cc, opts...) + }), + grpc.WithStreamInterceptor(func( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, + streamer grpc.Streamer, + opts ...grpc.CallOption, + ) (grpc.ClientStream, error) { + checkMetadata(ctx) + return streamer(ctx, desc, cc, method, opts...) + }), ), + ), + ) + require.NoError(t, err) + require.NoError(t, db.Close(ctx)) +} + +func TestWithStaticCredentialsInConnectionString(sourceTest *testing.T) { + t := xtest.MakeSyncedTest(sourceTest) + + ctx := xtest.Context(t) + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithDialTimeout(time.Second), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ) + require.NoError(t, err) + defer db.Close(ctx) + + userName := strings.ToLower(t.Name()) + + err = db.Query().Exec(ctx, `DROP USER IF EXISTS `+userName) + require.NoError(t, err) + err = db.Query().Exec(ctx, "CREATE USER "+userName+" PASSWORD 'password'; ALTER GROUP `ADMINS` ADD USER "+userName+";") + require.NoError(t, err) + defer func() { + _ = db.Query().Exec(ctx, `DROP USER `+userName) + }() + + t.RunSynced("HappyWay", func(t *xtest.SyncedTest) { + u, err := url.Parse(os.Getenv("YDB_CONNECTION_STRING")) + require.NoError(t, err) + + u.User = url.UserPassword(userName, "password") + t.Log(u.String()) + + test, err := ydb.Open(ctx, u.String()) + require.NoError(t, err) + defer test.Close(ctx) + + row, err := test.Query().QueryRow(ctx, `SELECT 1`) + require.NoError(t, err) + + var v int + err = row.Scan(&v) + require.NoError(t, err) + + tableName := path.Join(test.Name(), t.Name(), "test") + + err = test.Query().Exec(ctx, fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id Uint64, + value Utf8, + PRIMARY KEY (id) + )`, "`"+tableName+"`"), ) - if err != nil { - t.Fatal(err) - } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("close failed: %+v", e) + require.NoError(t, err) + + var d options.Description + err = test.Table().Do(ctx, func(ctx context.Context, s table.Session) error { + d, err = s.DescribeTable(ctx, tableName) + if err != nil { + return err } - }() + + return nil + }) + require.NoError(t, err) + + require.Equal(t, "test", d.Name) + require.Equal(t, 2, len(d.Columns)) + require.Equal(t, "id", d.Columns[0].Name) + require.Equal(t, "value", d.Columns[1].Name) + require.Equal(t, []string{"id"}, d.PrimaryKey) }) - t.RunSynced("ydb.Open", func(t *xtest.SyncedTest) { - db, err := ydb.Open(ctx, - os.Getenv("YDB_CONNECTION_STRING"), - ydb.WithAccessTokenCredentials( - os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), - ), - ydb.With( - config.WithOperationTimeout(time.Second*2), - config.WithOperationCancelAfter(time.Second*2), - ), - ydb.WithConnectionTTL(time.Millisecond*10000), - ydb.WithMinTLSVersion(tls.VersionTLS10), - ydb.WithLogger( - newLoggerWithMinLevel(t, log.WARN), - trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`), - ), - ydb.WithApplicationName(userAgent), - ydb.WithRequestsType(requestType), - ydb.With( - config.WithGrpcOptions( - grpc.WithUnaryInterceptor(func( - ctx context.Context, - method string, - req, reply interface{}, - cc *grpc.ClientConn, - invoker grpc.UnaryInvoker, - opts ...grpc.CallOption, - ) error { - checkMetadata(ctx) - return invoker(ctx, method, req, reply, cc, opts...) - }), - grpc.WithStreamInterceptor(func( - ctx context.Context, - desc *grpc.StreamDesc, - cc *grpc.ClientConn, - method string, - streamer grpc.Streamer, - opts ...grpc.CallOption, - ) (grpc.ClientStream, error) { - checkMetadata(ctx) - return streamer(ctx, desc, cc, method, opts...) - }), - ), - ), - ) - if err != nil { - t.Fatal(err) - } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("close failed: %+v", e) - } - }() + t.RunSynced("WrongLogin", func(t *xtest.SyncedTest) { + u, err := url.Parse(os.Getenv("YDB_CONNECTION_STRING")) + require.NoError(t, err) + u.User = url.UserPassword("wrong_login", "password") + test, err := ydb.Open(ctx, u.String()) + require.Error(t, err) + require.Nil(t, test) + require.True(t, credentials.IsAccessError(err)) + }) + t.RunSynced("WrongPassword", func(t *xtest.SyncedTest) { + u, err := url.Parse(os.Getenv("YDB_CONNECTION_STRING")) + require.NoError(t, err) + u.User = url.UserPassword("test", "wrong_password") + test, err := ydb.Open(ctx, u.String()) + require.Error(t, err) + require.Nil(t, test) + require.True(t, credentials.IsAccessError(err)) + }) +} + +func TestWithStaticCredentialsExplicitOption(sourceTest *testing.T) { + t := xtest.MakeSyncedTest(sourceTest) + + ctx := xtest.Context(t) + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ) + require.NoError(t, err) + defer db.Close(ctx) + + err = db.Query().Exec(ctx, `DROP USER IF EXISTS test`) + require.NoError(t, err) + err = db.Query().Exec(ctx, "CREATE USER test PASSWORD 'password'; ALTER GROUP `ADMINS` ADD USER test;") + require.NoError(t, err) + defer func() { + _ = db.Query().Exec(ctx, `DROP USER test`) + }() + + t.RunSynced("HappyWay", func(t *xtest.SyncedTest) { t.RunSynced("WithStaticCredentials", func(t *xtest.SyncedTest) { - db, err := ydb.Open(ctx, + db3, err := ydb.Open(ctx, os.Getenv("YDB_CONNECTION_STRING"), - ydb.WithAccessTokenCredentials( - os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), - ), + ydb.WithStaticCredentials("test", "password"), ) require.NoError(t, err) - defer func() { - _ = db.Close(ctx) - }() - err = db.Query().Exec(ctx, `DROP USER IF EXISTS test`) - require.NoError(t, err) - err = db.Query().Exec(ctx, "CREATE USER test PASSWORD 'password'; ALTER GROUP `ADMINS` ADD USER test;") - require.NoError(t, err) - defer func() { - _ = db.Query().Exec(ctx, `DROP USER test`) - }() - t.RunSynced("UsingConnectionString", func(t *xtest.SyncedTest) { - t.RunSynced("HappyWay", func(t *xtest.SyncedTest) { - u, err := url.Parse(os.Getenv("YDB_CONNECTION_STRING")) - require.NoError(t, err) - u.User = url.UserPassword("test", "password") - t.Log(u.String()) - db, err := ydb.Open(ctx, u.String()) - require.NoError(t, err) - defer func() { - _ = db.Close(ctx) - }() - row, err := db.Query().QueryRow(ctx, `SELECT 1`) - require.NoError(t, err) - var v int - err = row.Scan(&v) - require.NoError(t, err) - tableName := path.Join(db.Name(), t.Name(), "test") - t.RunSynced("CreateTable", func(t *xtest.SyncedTest) { - err := db.Query().Exec(ctx, fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s ( - id Uint64, - value Utf8, - PRIMARY KEY (id) - )`, "`"+tableName+"`"), - ) - require.NoError(t, err) - }) - t.RunSynced("DescribeTable", func(t *xtest.SyncedTest) { - var d options.Description - err := db.Table().Do(ctx, func(ctx context.Context, s table.Session) error { - d, err = s.DescribeTable(ctx, tableName) - if err != nil { - return err - } - - return nil - }) - require.NoError(t, err) - require.Equal(t, "test", d.Name) - require.Equal(t, 2, len(d.Columns)) - require.Equal(t, "id", d.Columns[0].Name) - require.Equal(t, "value", d.Columns[1].Name) - require.Equal(t, []string{"id"}, d.PrimaryKey) - }) - }) - t.RunSynced("WrongLogin", func(t *xtest.SyncedTest) { - u, err := url.Parse(os.Getenv("YDB_CONNECTION_STRING")) - require.NoError(t, err) - u.User = url.UserPassword("wrong_login", "password") - db, err := ydb.Open(ctx, u.String()) - require.Error(t, err) - require.Nil(t, db) - require.True(t, credentials.IsAccessError(err)) - }) - t.RunSynced("WrongPassword", func(t *xtest.SyncedTest) { - u, err := url.Parse(os.Getenv("YDB_CONNECTION_STRING")) - require.NoError(t, err) - u.User = url.UserPassword("test", "wrong_password") - db, err := ydb.Open(ctx, u.String()) - require.Error(t, err) - require.Nil(t, db) - require.True(t, credentials.IsAccessError(err)) - }) - }) - t.RunSynced("UsingExplicitStaticCredentials", func(t *xtest.SyncedTest) { - t.RunSynced("HappyWay", func(t *xtest.SyncedTest) { - t.RunSynced("WithStaticCredentials", func(t *xtest.SyncedTest) { - db, err := ydb.Open(ctx, - os.Getenv("YDB_CONNECTION_STRING"), - ydb.WithStaticCredentials("test", "password"), - ) - require.NoError(t, err) - defer func() { - _ = db.Close(ctx) - }() - tableName := path.Join(db.Name(), t.Name(), "test") - t.RunSynced("CreateTable", func(t *xtest.SyncedTest) { - err := db.Query().Exec(ctx, fmt.Sprintf(` + defer db3.Close(ctx) + + tableName := path.Join(db3.Name(), t.Name(), "test") + t.RunSynced("CreateTable", func(t *xtest.SyncedTest) { + err := db3.Query().Exec(ctx, fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id Uint64, value Utf8, PRIMARY KEY (id) )`, "`"+tableName+"`"), - ) - require.NoError(t, err) - }) - t.RunSynced("Query", func(t *xtest.SyncedTest) { - row, err := db.Query().QueryRow(ctx, `SELECT 1`) - require.NoError(t, err) - var v int - err = row.Scan(&v) - require.NoError(t, err) - }) - t.RunSynced("DescribeTable", func(t *xtest.SyncedTest) { - var d options.Description - err := db.Table().Do(ctx, func(ctx context.Context, s table.Session) error { - d, err = s.DescribeTable(ctx, tableName) - if err != nil { - return err - } - - return nil - }) - require.NoError(t, err) - require.Equal(t, "test", d.Name) - require.Equal(t, 2, len(d.Columns)) - require.Equal(t, "id", d.Columns[0].Name) - require.Equal(t, "value", d.Columns[1].Name) - require.Equal(t, []string{"id"}, d.PrimaryKey) - }) - }) - t.RunSynced("WithStaticCredentialsLogin+WithStaticCredentialsPassword", - func(t *xtest.SyncedTest) { - db, err := ydb.Open(ctx, - os.Getenv("YDB_CONNECTION_STRING"), - ydb.WithStaticCredentialsLogin("test"), - ydb.WithStaticCredentialsPassword("password"), - ) - require.NoError(t, err) - defer func() { - _ = db.Close(ctx) - }() - tableName := path.Join(db.Name(), t.Name(), "test") - t.RunSynced("CreateTable", func(t *xtest.SyncedTest) { - err := db.Query().Exec(ctx, fmt.Sprintf(` + ) + require.NoError(t, err) + }) + t.RunSynced("Query", func(t *xtest.SyncedTest) { + row, err := db3.Query().QueryRow(ctx, `SELECT 1`) + require.NoError(t, err) + var v int + err = row.Scan(&v) + require.NoError(t, err) + }) + t.RunSynced("DescribeTable", func(t *xtest.SyncedTest) { + var d options.Description + err := db3.Table().Do(ctx, func(ctx context.Context, s table.Session) error { + d, err = s.DescribeTable(ctx, tableName) + if err != nil { + return err + } + + return nil + }) + require.NoError(t, err) + require.Equal(t, "test", d.Name) + require.Equal(t, 2, len(d.Columns)) + require.Equal(t, "id", d.Columns[0].Name) + require.Equal(t, "value", d.Columns[1].Name) + require.Equal(t, []string{"id"}, d.PrimaryKey) + }) + }) + t.RunSynced("WithStaticCredentialsLogin+WithStaticCredentialsPassword", + func(t *xtest.SyncedTest) { + db3, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithStaticCredentialsLogin("test"), + ydb.WithStaticCredentialsPassword("password"), + ) + require.NoError(t, err) + defer db3.Close(ctx) + + tableName := path.Join(db3.Name(), t.Name(), "test") + t.RunSynced("CreateTable", func(t *xtest.SyncedTest) { + err := db3.Query().Exec(ctx, fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id Uint64, value Utf8, PRIMARY KEY (id) )`, "`"+tableName+"`"), - ) - require.NoError(t, err) - }) - t.RunSynced("Query", func(t *xtest.SyncedTest) { - row, err := db.Query().QueryRow(ctx, `SELECT 1`) - require.NoError(t, err) - var v int - err = row.Scan(&v) - require.NoError(t, err) - }) - t.RunSynced("DescribeTable", func(t *xtest.SyncedTest) { - var d options.Description - err := db.Table().Do(ctx, func(ctx context.Context, s table.Session) error { - d, err = s.DescribeTable(ctx, tableName) - if err != nil { - return err - } - - return nil - }) - require.NoError(t, err) - require.Equal(t, "test", d.Name) - require.Equal(t, 2, len(d.Columns)) - require.Equal(t, "id", d.Columns[0].Name) - require.Equal(t, "value", d.Columns[1].Name) - require.Equal(t, []string{"id"}, d.PrimaryKey) - }) - }) - }) - t.RunSynced("WrongLogin", func(t *xtest.SyncedTest) { - db, err := ydb.Open(ctx, - os.Getenv("YDB_CONNECTION_STRING"), - ydb.WithStaticCredentials("wrong_user", "password"), ) - require.Error(t, err) - require.Nil(t, db) - require.True(t, credentials.IsAccessError(err)) + require.NoError(t, err) }) - t.RunSynced("WrongPassword", func(t *xtest.SyncedTest) { - db, err := ydb.Open(ctx, - os.Getenv("YDB_CONNECTION_STRING"), - ydb.WithStaticCredentials("test", "wrong_password"), - ) - require.Error(t, err) - require.Nil(t, db) - require.True(t, credentials.IsAccessError(err)) + t.RunSynced("Query", func(t *xtest.SyncedTest) { + row, err := db3.Query().QueryRow(ctx, `SELECT 1`) + require.NoError(t, err) + var v int + err = row.Scan(&v) + require.NoError(t, err) + }) + t.RunSynced("DescribeTable", func(t *xtest.SyncedTest) { + var d options.Description + err := db3.Table().Do(ctx, func(ctx context.Context, s table.Session) error { + d, err = s.DescribeTable(ctx, tableName) + if err != nil { + return err + } + + return nil + }) + require.NoError(t, err) + require.Equal(t, "test", d.Name) + require.Equal(t, 2, len(d.Columns)) + require.Equal(t, "id", d.Columns[0].Name) + require.Equal(t, "value", d.Columns[1].Name) + require.Equal(t, []string{"id"}, d.PrimaryKey) }) }) - }) - t.RunSynced("With", func(t *xtest.SyncedTest) { - t.Run("WithSharedBalancer", func(t *testing.T) { - child, err := db.With(ctx, ydb.WithSharedBalancer(db)) - require.NoError(t, err) - result, err := child.Scripting().Execute(ctx, `SELECT 1`, nil) - require.NoError(t, err) - require.NoError(t, result.NextResultSetErr(ctx)) - require.True(t, result.NextRow()) - var value int32 - err = result.Scan(indexed.Required(&value)) - require.NoError(t, err) - require.EqualValues(t, 1, value) - err = child.Close(ctx) + }) + t.RunSynced("WrongLogin", func(t *xtest.SyncedTest) { + db3, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithStaticCredentials("wrong_user", "password"), + ) + require.Error(t, err) + require.Nil(t, db3) + require.True(t, credentials.IsAccessError(err)) + }) + t.RunSynced("WrongPassword", func(t *xtest.SyncedTest) { + db3, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithStaticCredentials("test", "wrong_password"), + ) + require.Error(t, err) + require.Nil(t, db3) + require.True(t, credentials.IsAccessError(err)) + }) +} + +func TestWithSharedBalancer(sourceTest *testing.T) { + t := xtest.MakeSyncedTest(sourceTest) + + ctx := xtest.Context(t) + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ) + require.NoError(t, err) + defer db.Close(ctx) + + child, err := db.With(ctx, + ydb.WithSharedBalancer(db), + ) + require.NoError(t, err) + result, err := child.Scripting().Execute(ctx, `SELECT 1`, nil) + require.NoError(t, err) + require.NoError(t, result.NextResultSetErr(ctx)) + require.True(t, result.NextRow()) + var value int32 + err = result.Scan(indexed.Required(&value)) + require.NoError(t, err) + require.EqualValues(t, 1, value) + err = child.Close(ctx) + require.NoError(t, err) +} + +func TestExportToS3(sourceTest *testing.T) { + t := xtest.MakeSyncedTest(sourceTest) + + ctx := xtest.Context(t) + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ) + require.NoError(t, err) + defer db.Close(ctx) + + if err := retry.Retry(ctx, func(ctx context.Context) (err error) { + exportClient := Ydb_Export_V1.NewExportServiceClient(ydb.GRPCConn(db)) + response, err := exportClient.ExportToS3( + ctx, + &Ydb_Export.ExportToS3Request{ + OperationParams: &Ydb_Operations.OperationParams{ + OperationTimeout: durationpb.New(time.Second), + CancelAfter: durationpb.New(time.Second), + }, + Settings: &Ydb_Export.ExportToS3Settings{}, + }, + ) + if err != nil { + return err + } + if response.GetOperation().GetStatus() != Ydb.StatusIds_BAD_REQUEST { + return fmt.Errorf( + "operation must be BAD_REQUEST: %s", + response.GetOperation().GetStatus().String(), + ) + } + return nil + }, retry.WithIdempotent(true)); err != nil { + t.Fatalf("check export failed: %v", err) + } +} + +func TestScriptingStreamExecuteYql(sourceTest *testing.T) { + t := xtest.MakeSyncedTest(sourceTest) + + ctx := xtest.Context(t) + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ) + require.NoError(t, err) + defer db.Close(ctx) + + const sumColumn = "sum" + + for _, tt := range []struct { + name string + db *ydb.Driver + }{ + { + name: "parent", + db: db, + }, + { + name: "child", + db: func() *ydb.Driver { + child, err := db.With(ctx, ydb.WithDialTimeout(time.Second*5)) require.NoError(t, err) - }) - }) - t.Run("discovery.WhoAmI", func(t *testing.T) { - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - discoveryClient := Ydb_Discovery_V1.NewDiscoveryServiceClient(ydb.GRPCConn(db)) - response, err := discoveryClient.WhoAmI( - ctx, - &Ydb_Discovery.WhoAmIRequest{IncludeGroups: true}, - ) - if err != nil { - return err - } - var result Ydb_Discovery.WhoAmIResult - err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result) - if err != nil { - return - } - return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("Execute failed: %v", err) - } - }) - t.Run("scripting.ExecuteYql", func(t *testing.T) { - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(db)) - response, err := scriptingClient.ExecuteYql( - ctx, - &Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"}, - ) - if err != nil { - return err - } - var result Ydb_Scripting.ExecuteYqlResult - err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result) - if err != nil { - return - } - if len(result.GetResultSets()) != 1 { - return fmt.Errorf( - "unexpected result sets count: %d", - len(result.GetResultSets()), - ) - } - if len(result.GetResultSets()[0].GetColumns()) != 1 { - return fmt.Errorf( - "unexpected colums count: %d", - len(result.GetResultSets()[0].GetColumns()), - ) - } - if result.GetResultSets()[0].GetColumns()[0].GetName() != sumColumn { - return fmt.Errorf( - "unexpected colum name: %s", - result.GetResultSets()[0].GetColumns()[0].GetName(), - ) - } - if len(result.GetResultSets()[0].GetRows()) != 1 { - return fmt.Errorf( - "unexpected rows count: %d", - len(result.GetResultSets()[0].GetRows()), - ) - } - if result.GetResultSets()[0].GetRows()[0].GetItems()[0].GetInt32Value() != 101 { - return fmt.Errorf( - "unexpected result of select: %d", - result.GetResultSets()[0].GetRows()[0].GetInt64Value(), - ) - } - return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("Execute failed: %v", err) - } - }) - t.Run("scripting.StreamExecuteYql", func(t *testing.T) { - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(db)) + + return child + }(), + }, + } { + t.Run(tt.name, func(t *testing.T) { + err = retry.Retry(ctx, func(ctx context.Context) (err error) { + scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(tt.db)) client, err := scriptingClient.StreamExecuteYql( ctx, &Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"}, @@ -492,92 +530,152 @@ func TestDriver(sourceTest *testing.T) { ) } return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("Stream execute failed: %v", err) - } + }, retry.WithIdempotent(true)) + require.NoError(t, err) }) - t.Run("with.scripting.StreamExecuteYql", func(t *testing.T) { - var childDB *ydb.Driver - childDB, err = db.With( - ctx, - ydb.WithDialTimeout(time.Second*5), - ) - if err != nil { - t.Fatalf("failed to open sub-connection: %v", err) - } - defer func() { - _ = childDB.Close(ctx) - }() - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(childDB)) - client, err := scriptingClient.StreamExecuteYql( + } +} + +func TestScriptingExecuteYql(sourceTest *testing.T) { + t := xtest.MakeSyncedTest(sourceTest) + + ctx := xtest.Context(t) + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ) + require.NoError(t, err) + defer db.Close(ctx) + + const sumColumn = "sum" + + for _, tt := range []struct { + name string + db *ydb.Driver + }{ + { + name: "parent", + db: db, + }, + { + name: "child", + db: func() *ydb.Driver { + child, err := db.With(ctx, ydb.WithDialTimeout(time.Second*5)) + require.NoError(t, err) + + return child + }(), + }, + } { + t.Run(tt.name, func(t *testing.T) { + err = retry.Retry(ctx, func(ctx context.Context) (err error) { + scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(tt.db)) + response, err := scriptingClient.ExecuteYql( ctx, &Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"}, ) if err != nil { return err } - response, err := client.Recv() + var result Ydb_Scripting.ExecuteYqlResult + err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result) if err != nil { - return err + return } - if len(response.GetResult().GetResultSet().GetColumns()) != 1 { + if len(result.GetResultSets()) != 1 { + return fmt.Errorf( + "unexpected result sets count: %d", + len(result.GetResultSets()), + ) + } + if len(result.GetResultSets()[0].GetColumns()) != 1 { return fmt.Errorf( "unexpected colums count: %d", - len(response.GetResult().GetResultSet().GetColumns()), + len(result.GetResultSets()[0].GetColumns()), ) } - if response.GetResult().GetResultSet().GetColumns()[0].GetName() != sumColumn { + if result.GetResultSets()[0].GetColumns()[0].GetName() != sumColumn { return fmt.Errorf( "unexpected colum name: %s", - response.GetResult().GetResultSet().GetColumns()[0].GetName(), + result.GetResultSets()[0].GetColumns()[0].GetName(), ) } - if len(response.GetResult().GetResultSet().GetRows()) != 1 { + if len(result.GetResultSets()[0].GetRows()) != 1 { return fmt.Errorf( "unexpected rows count: %d", - len(response.GetResult().GetResultSet().GetRows()), + len(result.GetResultSets()[0].GetRows()), ) } - if response.GetResult().GetResultSet().GetRows()[0].GetItems()[0].GetInt32Value() != 101 { + if result.GetResultSets()[0].GetRows()[0].GetItems()[0].GetInt32Value() != 101 { return fmt.Errorf( "unexpected result of select: %d", - response.GetResult().GetResultSet().GetRows()[0].GetInt64Value(), + result.GetResultSets()[0].GetRows()[0].GetInt64Value(), ) } return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("Stream execute failed: %v", err) - } + }, retry.WithIdempotent(true)) + require.NoError(t, err) }) - t.Run("export.ExportToS3", func(t *testing.T) { - if err = retry.Retry(ctx, func(ctx context.Context) (err error) { - exportClient := Ydb_Export_V1.NewExportServiceClient(ydb.GRPCConn(db)) - response, err := exportClient.ExportToS3( + } +} + +func TestDiscoveryWhoAmI(sourceTest *testing.T) { + t := xtest.MakeSyncedTest(sourceTest) + + ctx := xtest.Context(t) + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials( + os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"), + ), + ) + require.NoError(t, err) + defer db.Close(ctx) + + const sumColumn = "sum" + + for _, tt := range []struct { + name string + db *ydb.Driver + }{ + { + name: "parent", + db: db, + }, + { + name: "child", + db: func() *ydb.Driver { + child, err := db.With(ctx, ydb.WithDialTimeout(time.Second*5)) + require.NoError(t, err) + + return child + }(), + }, + } { + t.Run(tt.name, func(t *testing.T) { + err = retry.Retry(ctx, func(ctx context.Context) (err error) { + discoveryClient := Ydb_Discovery_V1.NewDiscoveryServiceClient(ydb.GRPCConn(tt.db)) + response, err := discoveryClient.WhoAmI( ctx, - &Ydb_Export.ExportToS3Request{ - OperationParams: &Ydb_Operations.OperationParams{ - OperationTimeout: durationpb.New(time.Second), - CancelAfter: durationpb.New(time.Second), - }, - Settings: &Ydb_Export.ExportToS3Settings{}, - }, + &Ydb_Discovery.WhoAmIRequest{IncludeGroups: true}, ) if err != nil { return err } - if response.GetOperation().GetStatus() != Ydb.StatusIds_BAD_REQUEST { - return fmt.Errorf( - "operation must be BAD_REQUEST: %s", - response.GetOperation().GetStatus().String(), - ) + var result Ydb_Discovery.WhoAmIResult + err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result) + if err != nil { + return } return nil - }, retry.WithIdempotent(true)); err != nil { - t.Fatalf("check export failed: %v", err) - } + }, retry.WithIdempotent(true)) + require.NoError(t, err) }) - }) + } } func TestZeroDialTimeout(t *testing.T) { diff --git a/tests/integration/helpers_test.go b/tests/integration/helpers_test.go index f8c5b903e..e82956a38 100644 --- a/tests/integration/helpers_test.go +++ b/tests/integration/helpers_test.go @@ -10,7 +10,6 @@ import ( "fmt" "os" "path" - "runtime" "strings" "testing" "text/template" @@ -48,9 +47,7 @@ func newScope(t *testing.T) *scopeT { at := require.New(st) fEnv := fixenv.New(st) ctx, ctxCancel := context.WithCancel(context.Background()) - st.Cleanup(func() { - ctxCancel() - }) + st.Cleanup(ctxCancel) res := &scopeT{ Ctx: ctx, Env: fEnv, @@ -156,10 +153,18 @@ func (scope *scopeT) SQLDriver(opts ...ydb.ConnectorOption) *sql.DB { scope.Logf("Ping db") err = db.PingContext(scope.Ctx) if err != nil { + _ = db.Close() + return nil, err } - return fixenv.NewGenericResult(db), nil + clean := func() { + if db != nil { + scope.Require.NoError(db.Close()) + } + } + + return fixenv.NewGenericResultWithCleanup(db, clean), err } return fixenv.CacheResult(scope.Env, f) } @@ -469,16 +474,3 @@ func driverEngine(db *sql.DB) (engine xsql.Engine) { return engine } - -func simpleDetectGoroutineLeak(t *testing.T) { - // 1) testing.go => main.main() - // 2) current test - const expectedGoroutinesCount = 2 - if num := runtime.NumGoroutine(); num > expectedGoroutinesCount { - bb := make([]byte, 2<<32) - if n := runtime.Stack(bb, true); n < len(bb) { - bb = bb[:n] - } - t.Error(fmt.Sprintf("unexpected goroutines:\n%s\n", string(bb[runtime.Stack(bb, false)+1:]))) - } -} diff --git a/tests/integration/monitoring_test.go b/tests/integration/monitoring_test.go index 5220c1b25..a15d35a3c 100644 --- a/tests/integration/monitoring_test.go +++ b/tests/integration/monitoring_test.go @@ -24,12 +24,8 @@ func TestMonitoring(t *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("close failed: %+v", e) - } - }() + defer db.Close(ctx) + t.Run("monitoring.SelfCheck", func(t *testing.T) { err = retry.Retry(ctx, func(ctx context.Context) (err error) { client := Ydb_Monitoring_V1.NewMonitoringServiceClient(ydb.GRPCConn(db)) diff --git a/tests/integration/operation_test.go b/tests/integration/operation_test.go index bc09d81a1..03c5dac02 100644 --- a/tests/integration/operation_test.go +++ b/tests/integration/operation_test.go @@ -35,6 +35,8 @@ func TestOperationList(t *testing.T) { ), ) require.NoError(t, err) + defer db.Close(ctx) + operations, err := db.Operation().ListBuildIndex(ctx) require.NoError(t, err) diff --git a/tests/integration/query_execute_script_test.go b/tests/integration/query_execute_script_test.go index 913767364..b616b9273 100644 --- a/tests/integration/query_execute_script_test.go +++ b/tests/integration/query_execute_script_test.go @@ -41,10 +41,7 @@ func TestQueryExecuteScript(sourceTest *testing.T) { if err != nil { t.Fatal(err) } - defer func(db *ydb.Driver) { - // cleanup - _ = db.Close(ctx) - }(db) + defer db.Close(ctx) err = db.Query().Exec(ctx, "CREATE TABLE IF NOT EXISTS `"+path.Join(db.Name(), folder, tableName)+"` (val Int64, PRIMARY KEY (val))", diff --git a/tests/integration/query_execute_test.go b/tests/integration/query_execute_test.go index 113da45fc..8bacd1993 100644 --- a/tests/integration/query_execute_test.go +++ b/tests/integration/query_execute_test.go @@ -47,6 +47,8 @@ func TestQueryExecute(t *testing.T) { ), ) require.NoError(t, err) + defer db.Close(ctx) + t.Run("Query", func(t *testing.T) { var ( p1 string @@ -301,6 +303,7 @@ func TestIssue1456TooManyUnknownTransactions(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) const ( tableSize = 10000 @@ -428,6 +431,7 @@ SELECT * FROM AS_TABLE($arg); if err != nil { return err } + defer rs.Close(ctx) for i := 0; i < targetCount; i++ { row, err := rs.NextRow(ctx) @@ -454,6 +458,7 @@ SELECT * FROM AS_TABLE($arg); if err != nil { return err } + defer rs.Close(ctx) _, err = rs.NextRow(ctx) if err != nil { diff --git a/tests/integration/query_multi_result_sets_test.go b/tests/integration/query_multi_result_sets_test.go index f09857b9d..6d0253c52 100644 --- a/tests/integration/query_multi_result_sets_test.go +++ b/tests/integration/query_multi_result_sets_test.go @@ -24,6 +24,8 @@ func TestQueryMultiResultSets(t *testing.T) { if err != nil { return fmt.Errorf("can't get result: %w", err) } + defer res.Close(ctx) + set, err := res.NextResultSet(ctx) if err != nil { return fmt.Errorf("set 0: get next result set: %w", err) diff --git a/tests/integration/query_range_test.go b/tests/integration/query_range_test.go index 03505e30e..8c4b0782d 100644 --- a/tests/integration/query_range_test.go +++ b/tests/integration/query_range_test.go @@ -28,6 +28,8 @@ func TestQueryRange(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) + t.Run("Execute", func(t *testing.T) { listItems := make([]value.Value, 1000) for i := range make([]struct{}, 1000) { diff --git a/tests/integration/query_read_row_test.go b/tests/integration/query_read_row_test.go index 94793a874..050f4326e 100644 --- a/tests/integration/query_read_row_test.go +++ b/tests/integration/query_read_row_test.go @@ -38,6 +38,7 @@ func TestQueryReadRow(t *testing.T) { ), ) require.NoError(t, err) + defer db.Close(ctx) row, err := db.Query().QueryRow(ctx, ` DECLARE $p1 AS Text; diff --git a/tests/integration/ratelimiter_test.go b/tests/integration/ratelimiter_test.go index b5929098b..dbc332f57 100644 --- a/tests/integration/ratelimiter_test.go +++ b/tests/integration/ratelimiter_test.go @@ -45,12 +45,8 @@ func TestRatelimiter(sourceTest *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("db close failed: %+v", e) - } - }() + defer db.Close(ctx) + // drop node err = db.Coordination().DropNode(ctx, testCoordinationNodePath) if err != nil { diff --git a/tests/integration/register_dsn_parser_test.go b/tests/integration/register_dsn_parser_test.go index 080a1afa4..47277d987 100644 --- a/tests/integration/register_dsn_parser_test.go +++ b/tests/integration/register_dsn_parser_test.go @@ -30,9 +30,7 @@ func TestRegisterDsnParser(t *testing.T) { db, err := ydb.Open(context.Background(), os.Getenv("YDB_CONNECTION_STRING")) require.NoError(t, err) require.True(t, visited) - defer func() { - _ = db.Close(context.Background()) - }() + defer db.Close(context.Background()) }) t.Run("database/sql", func(t *testing.T) { var visited bool diff --git a/tests/integration/retry_budget_test.go b/tests/integration/retry_budget_test.go index 92fd2531e..209a10fbb 100644 --- a/tests/integration/retry_budget_test.go +++ b/tests/integration/retry_budget_test.go @@ -39,10 +39,7 @@ func TestRetryBudget(t *testing.T) { ) require.NoError(t, err) - defer func() { - // cleanup - _ = nativeDriver.Close(ctx) - }() + defer nativeDriver.Close(ctx) c, err := ydb.Connector(nativeDriver) require.NoError(t, err) @@ -53,10 +50,7 @@ func TestRetryBudget(t *testing.T) { }() db := sql.OpenDB(c) - defer func() { - // cleanup - _ = db.Close() - }() + defer db.Close() retryBudget := noQuota{} diff --git a/tests/integration/scripting_test.go b/tests/integration/scripting_test.go index 8b2bdf4ad..7eb301fec 100644 --- a/tests/integration/scripting_test.go +++ b/tests/integration/scripting_test.go @@ -46,12 +46,8 @@ func TestScripting(sourceTest *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("db close failed: %+v", e) - } - }() + defer db.Close(ctx) + // Execute if err = retry.Retry(ctx, func(ctx context.Context) (err error) { res, err := db.Scripting().Execute( diff --git a/tests/integration/static_credentials_test.go b/tests/integration/static_credentials_test.go index 2013eace7..b23a70cd5 100644 --- a/tests/integration/static_credentials_test.go +++ b/tests/integration/static_credentials_test.go @@ -65,12 +65,8 @@ func TestStaticCredentials(t *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("close failed: %+v", e) - } - }() + defer db.Close(ctx) + _, err = db.Discovery().WhoAmI(ctx) if err != nil { t.Fatal(err) diff --git a/tests/integration/sugar_result_test.go b/tests/integration/sugar_result_test.go index 3564eb704..1637fcb19 100644 --- a/tests/integration/sugar_result_test.go +++ b/tests/integration/sugar_result_test.go @@ -28,6 +28,8 @@ func TestSugarResult(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) + t.Run("Scan", func(t *testing.T) { t.Run("Table", func(t *testing.T) { var ( diff --git a/tests/integration/table_bulk_upsert_test.go b/tests/integration/table_bulk_upsert_test.go index fc0906d40..f885b01e5 100644 --- a/tests/integration/table_bulk_upsert_test.go +++ b/tests/integration/table_bulk_upsert_test.go @@ -201,6 +201,8 @@ func assertIdValueImpl(ctx context.Context, t *testing.T, tableName string, id i // ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) + err = db.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) (err error) { res, err := tx.Execute(ctx, fmt.Sprintf("SELECT val FROM `%s` WHERE id = %d", tableName, id), nil) if err != nil { diff --git a/tests/integration/table_containers_test.go b/tests/integration/table_containers_test.go index dad0e986d..a9981281d 100644 --- a/tests/integration/table_containers_test.go +++ b/tests/integration/table_containers_test.go @@ -25,6 +25,8 @@ func TestContainers(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) + err = db.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) (err error) { res, err := tx.Execute(ctx, ` SELECT diff --git a/tests/integration/table_create_table_description_test.go b/tests/integration/table_create_table_description_test.go index b7b6f5b3d..018d0d91a 100644 --- a/tests/integration/table_create_table_description_test.go +++ b/tests/integration/table_create_table_description_test.go @@ -35,9 +35,8 @@ func TestCreateTableDescription(sourceTest *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - _ = db.Close(ctx) - }() + defer db.Close(ctx) + for _, tt := range []struct { opts []options.CreateTableOption description options.Description diff --git a/tests/integration/table_create_table_partitions_test.go b/tests/integration/table_create_table_partitions_test.go index 16c486e4c..a3c9e82a6 100644 --- a/tests/integration/table_create_table_partitions_test.go +++ b/tests/integration/table_create_table_partitions_test.go @@ -34,9 +34,7 @@ func TestTableCreateTablePartitions(sourceTest *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - _ = db.Close(ctx) - }() + defer db.Close(ctx) t.Run("uniform partitions", func(t *testing.T) { err := db.Table().Do(ctx, diff --git a/tests/integration/table_data_query_issue_row_col_test.go b/tests/integration/table_data_query_issue_row_col_test.go index f8dcb682b..8ab73675b 100644 --- a/tests/integration/table_data_query_issue_row_col_test.go +++ b/tests/integration/table_data_query_issue_row_col_test.go @@ -28,6 +28,7 @@ func TestDataQueryIssueRowCol(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) exists, err := sugar.IsTableExists(ctx, db.Scheme(), path.Join(db.Name(), "users")) require.NoError(t, err) if exists { diff --git a/tests/integration/table_data_query_with_compression_test.go b/tests/integration/table_data_query_with_compression_test.go index 927dbe303..bc846e1d1 100644 --- a/tests/integration/table_data_query_with_compression_test.go +++ b/tests/integration/table_data_query_with_compression_test.go @@ -29,6 +29,8 @@ func TestDataQueryWithCompression(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) + err = db.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) (err error) { res, err := tx.Execute(ctx, `SELECT 1 as abc, 2 as def;`, nil, options.WithCallOptions( grpc.UseCompressor(gzip.Name), diff --git a/tests/integration/table_interval_from_duration_test.go b/tests/integration/table_interval_from_duration_test.go index e636158e4..aa021325c 100644 --- a/tests/integration/table_interval_from_duration_test.go +++ b/tests/integration/table_interval_from_duration_test.go @@ -27,6 +27,7 @@ func TestIssue259IntervalFromDuration(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) t.Run("Check about interval work with microseconds", func(t *testing.T) { err := db.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) error { diff --git a/tests/integration/table_long_stream_test.go b/tests/integration/table_long_stream_test.go index 4f668d9c8..43df1ce6c 100644 --- a/tests/integration/table_long_stream_test.go +++ b/tests/integration/table_long_stream_test.go @@ -50,10 +50,7 @@ func TestLongStream(sourceTest *testing.T) { if err != nil { t.Fatal(err) } - defer func(db *ydb.Driver) { - // cleanup - _ = db.Close(ctx) - }(db) + defer db.Close(ctx) t.Run("creating", func(t *testing.T) { t.Run("stream", func(t *testing.T) { diff --git a/tests/integration/table_multiple_result_sets_test.go b/tests/integration/table_multiple_result_sets_test.go index 43992f9cd..e66646807 100644 --- a/tests/integration/table_multiple_result_sets_test.go +++ b/tests/integration/table_multiple_result_sets_test.go @@ -48,11 +48,7 @@ func TestTableMultipleResultSets(sourceTest *testing.T) { ), ) require.NoError(t, err) - - defer func() { - err = db.Close(ctx) - require.NoError(t, err) - }() + defer db.Close(ctx) t.Run("create", func(t *testing.T) { t.Run("table", func(t *testing.T) { diff --git a/tests/integration/table_null_type_test.go b/tests/integration/table_null_type_test.go index 60b49c864..d3d18d1f0 100644 --- a/tests/integration/table_null_type_test.go +++ b/tests/integration/table_null_type_test.go @@ -27,6 +27,8 @@ func TestNullType(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) + err = db.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) (err error) { res, err := tx.Execute(ctx, `SELECT NULL AS reschedule_due;`, nil) if err != nil { diff --git a/tests/integration/table_scan_error_test.go b/tests/integration/table_scan_error_test.go index ec18e55c6..d075ff896 100644 --- a/tests/integration/table_scan_error_test.go +++ b/tests/integration/table_scan_error_test.go @@ -27,6 +27,8 @@ func TestIssue415ScanError(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) + err = db.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) (err error) { res, err := tx.Execute(ctx, `SELECT 1 as abc, 2 as def;`, nil) if err != nil { @@ -67,6 +69,8 @@ func TestIssue847ScanError(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) + err = db.Table().Do(ctx, func(ctx context.Context, s table.Session) (err error) { res, err := s.StreamExecuteScanQuery(ctx, `SELICT 1;`, nil) if err != nil { diff --git a/tests/integration/table_scan_query_with_compression_test.go b/tests/integration/table_scan_query_with_compression_test.go index 7c805f388..4912e667a 100644 --- a/tests/integration/table_scan_query_with_compression_test.go +++ b/tests/integration/table_scan_query_with_compression_test.go @@ -29,6 +29,8 @@ func TestScanQueryWithCompression(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) + defer db.Close(ctx) + err = db.Table().Do(ctx, func(ctx context.Context, s table.Session) (err error) { res, err := s.StreamExecuteScanQuery(ctx, `SELECT 1 as abc, 2 as def;`, nil, options.WithCallOptions( grpc.UseCompressor(gzip.Name), diff --git a/tests/integration/table_split_ranges_and_read_test.go b/tests/integration/table_split_ranges_and_read_test.go index f1bb4f4ba..d789486ca 100644 --- a/tests/integration/table_split_ranges_and_read_test.go +++ b/tests/integration/table_split_ranges_and_read_test.go @@ -41,10 +41,7 @@ func TestSplitRangesAndRead(t *testing.T) { if err != nil { t.Fatal(err) } - defer func(db *ydb.Driver) { - // cleanup - _ = db.Close(ctx) - }(db) + defer db.Close(ctx) t.Run("creating table", func(t *testing.T) { if err = db.Table().Do(ctx, diff --git a/tests/integration/table_truncated_err_test.go b/tests/integration/table_truncated_err_test.go index 26a64ed79..9b84b3170 100644 --- a/tests/integration/table_truncated_err_test.go +++ b/tests/integration/table_truncated_err_test.go @@ -223,6 +223,8 @@ func TestIssue798TruncatedError(t *testing.T) { scope.Require.NoError(err) db = sql.OpenDB(ydb.MustConnector(driver)) + defer db.Close() + err = retry.Do(ctx, db, func(ctx context.Context, cc *sql.Conn) error { rows, err := cc.QueryContext(ctx, fmt.Sprintf("SELECT * FROM `%s`;", tablePath)) if err != nil { @@ -436,6 +438,8 @@ func TestIssue798NoTruncatedErrorOverQueryService(t *testing.T) { scope.Require.NoError(err) db = sql.OpenDB(ydb.MustConnector(driver)) + defer db.Close() + err = retry.Do(ctx, db, func(ctx context.Context, cc *sql.Conn) error { rows, err := cc.QueryContext(ctx, fmt.Sprintf("SELECT * FROM `%s`;", tablePath)) if err != nil { diff --git a/tests/integration/table_tx_lazy_test.go b/tests/integration/table_tx_lazy_test.go index 1aaf549a4..8dfecb909 100644 --- a/tests/integration/table_tx_lazy_test.go +++ b/tests/integration/table_tx_lazy_test.go @@ -29,6 +29,7 @@ func TestTableTxLazy(t *testing.T) { }) require.NotNil(t, db) + defer db.Close(ctx) t.Run("tx", func(t *testing.T) { t.Run("lazy", func(t *testing.T) { diff --git a/tests/integration/table_tz_timestamp_test.go b/tests/integration/table_tz_timestamp_test.go index 8d4ab24bc..63fd292cf 100644 --- a/tests/integration/table_tz_timestamp_test.go +++ b/tests/integration/table_tz_timestamp_test.go @@ -22,9 +22,8 @@ func TestTzTimestamp(t *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - _ = db.Close(ctx) - }() + defer db.Close(ctx) + err = db.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) (err error) { microseconds := int64(1680021427000000) res, err := tx.Execute(ctx, fmt.Sprintf(`SELECT CAST(CAST(%d AS Timestamp) AS TzTimestamp);`, microseconds), nil) diff --git a/tests/integration/table_unexpected_null_while_parse_nil_json_document_value_test.go b/tests/integration/table_unexpected_null_while_parse_nil_json_document_value_test.go index d62987e56..53f385da9 100644 --- a/tests/integration/table_unexpected_null_while_parse_nil_json_document_value_test.go +++ b/tests/integration/table_unexpected_null_while_parse_nil_json_document_value_test.go @@ -32,10 +32,8 @@ func TestIssue229UnexpectedNullWhileParseNilJsonDocumentValue(t *testing.T) { ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), ) require.NoError(t, err) - defer func(db *ydb.Driver) { - // cleanup - _ = db.Close(ctx) - }(db) + defer db.Close(ctx) + var val issue229Struct err = db.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) (err error) { res, err := tx.Execute(ctx, `SELECT Nothing(JsonDocument?) AS r`, nil) diff --git a/tests/integration/with_trace_retry_test.go b/tests/integration/with_trace_retry_test.go index 1cb5b6564..1f0dc94bb 100644 --- a/tests/integration/with_trace_retry_test.go +++ b/tests/integration/with_trace_retry_test.go @@ -70,6 +70,8 @@ func TestWithTraceRetry(t *testing.T) { ) db = sql.OpenDB(ydb.MustConnector(nativeDb)) ) + defer db.Close() + require.NoError(t, retry.Do(ctx, db, func(ctx context.Context, cc *sql.Conn) error { return nil