diff --git a/pkg/crosscluster/logical/BUILD.bazel b/pkg/crosscluster/logical/BUILD.bazel index 97ddff091132..5c00af24948c 100644 --- a/pkg/crosscluster/logical/BUILD.bazel +++ b/pkg/crosscluster/logical/BUILD.bazel @@ -86,6 +86,7 @@ go_library( "//pkg/sql/sem/tree/treecmp", "//pkg/sql/sessiondata", "//pkg/sql/sessiondatapb", + "//pkg/sql/sessionmutator", "//pkg/sql/sqlclustersettings", "//pkg/sql/stats", "//pkg/sql/syntheticprivilege", diff --git a/pkg/crosscluster/logical/batch_handler_test.go b/pkg/crosscluster/logical/batch_handler_test.go index b38455a74c01..6e478182c98c 100644 --- a/pkg/crosscluster/logical/batch_handler_test.go +++ b/pkg/crosscluster/logical/batch_handler_test.go @@ -244,6 +244,7 @@ func testBatchHandlerExhaustive(t *testing.T, factory batchHandlerFactory) { } handler, desc := factory(t, s, "test_table") + defer handler.Close(ctx) defer handler.ReleaseLeases(ctx) // TODO(jeffswenson): test the other handler types. diff --git a/pkg/crosscluster/logical/logical_replication_job_test.go b/pkg/crosscluster/logical/logical_replication_job_test.go index 324db9aa9ab1..44683d903831 100644 --- a/pkg/crosscluster/logical/logical_replication_job_test.go +++ b/pkg/crosscluster/logical/logical_replication_job_test.go @@ -895,6 +895,9 @@ func TestRandomTables(t *testing.T) { tc, s, runnerA, runnerB := setupLogicalTestServer(t, ctx, testClusterBaseClusterArgs, 1) defer tc.Stopper().Stop(ctx) + // TODO(#148303): Remove this once the crud writer supports tables with array primary keys. + runnerA.Exec(t, "SET CLUSTER SETTING logical_replication.consumer.immediate_mode_writer = 'legacy-kv'") + sqlA := s.SQLConn(t, serverutils.DBName("a")) var tableName, streamStartStmt string diff --git a/pkg/crosscluster/logical/logical_replication_writer_processor.go b/pkg/crosscluster/logical/logical_replication_writer_processor.go index 3e8fc1e2faff..e3cf953bbd3a 100644 --- a/pkg/crosscluster/logical/logical_replication_writer_processor.go +++ b/pkg/crosscluster/logical/logical_replication_writer_processor.go @@ -144,9 +144,10 @@ type logicalReplicationWriterProcessor struct { purgatory purgatory - seenKeys map[uint64]int64 - dupeCount int64 - seenEvery log.EveryN + seenKeys map[uint64]int64 + dupeCount int64 + seenEvery log.EveryN + retryEvery log.EveryN } var ( @@ -224,9 +225,10 @@ func newLogicalReplicationWriterProcessor( StreamID: streampb.StreamID(spec.StreamID), ProcessorID: processorID, }, - dlqClient: InitDeadLetterQueueClient(dlqDbExec, destTableBySrcID), - metrics: flowCtx.Cfg.JobRegistry.MetricsStruct().JobSpecificMetrics[jobspb.TypeLogicalReplication].(*Metrics), - seenEvery: log.Every(1 * time.Minute), + dlqClient: InitDeadLetterQueueClient(dlqDbExec, destTableBySrcID), + metrics: flowCtx.Cfg.JobRegistry.MetricsStruct().JobSpecificMetrics[jobspb.TypeLogicalReplication].(*Metrics), + seenEvery: log.Every(1 * time.Minute), + retryEvery: log.Every(1 * time.Minute), } lrw.purgatory = purgatory{ deadline: func() time.Duration { return retryQueueAgeLimit.Get(&flowCtx.Cfg.Settings.SV) }, @@ -1039,6 +1041,9 @@ func (lrw *logicalReplicationWriterProcessor) flushChunk( } stats.processed.dlq++ } else { + if lrw.retryEvery.ShouldLog() { + log.Dev.Warningf(ctx, "retrying failed apply: %+v", err) + } stats.notProcessed.count++ stats.notProcessed.bytes += int64(batch[i].Size()) } diff --git a/pkg/crosscluster/logical/replication_statements.go b/pkg/crosscluster/logical/replication_statements.go index 3f287b2295a9..d326ff6e933c 100644 --- a/pkg/crosscluster/logical/replication_statements.go +++ 
b/pkg/crosscluster/logical/replication_statements.go @@ -19,6 +19,7 @@ import ( type columnSchema struct { column catalog.Column + columnType *types.T isPrimaryKey bool isComputed bool } @@ -58,6 +59,7 @@ func getColumnSchema(table catalog.TableDescriptor) []columnSchema { result = append(result, columnSchema{ column: col, + columnType: col.GetType().Canonical(), isPrimaryKey: isPrimaryKey[col.GetID()], isComputed: isComputed, }) @@ -86,12 +88,14 @@ func newTypedPlaceholder(idx int, col catalog.Column) (*tree.CastExpr, error) { // in the table. Parameters are ordered by column ID. func newInsertStatement( table catalog.TableDescriptor, -) (statements.Statement[tree.Statement], error) { +) (statements.Statement[tree.Statement], []*types.T, error) { columns := getColumnSchema(table) - columnNames := make(tree.NameList, 0, len(columns)) parameters := make(tree.Exprs, 0, len(columns)) + paramTypes := make([]*types.T, 0, len(columns)) for i, col := range columns { + paramTypes = append(paramTypes, col.columnType) + // NOTE: this consumes a placeholder ID because it's part of the tree.Datums, // but it doesn't show up in the query because computed columns are not // needed for insert statements. @@ -102,7 +106,7 @@ func newInsertStatement( var err error parameter, err := newTypedPlaceholder(i+1, col.column) if err != nil { - return statements.Statement[tree.Statement]{}, err + return statements.Statement[tree.Statement]{}, nil, err } columnNames = append(columnNames, tree.Name(col.column.GetName())) @@ -129,7 +133,11 @@ func newInsertStatement( Returning: tree.AbsentReturningClause, } - return toParsedStatement(insert) + stmt, err := toParsedStatement(insert) + if err != nil { + return statements.Statement[tree.Statement]{}, nil, err + } + return stmt, paramTypes, nil } // newMatchesLastRow creates a WHERE clause for matching all columns of a row. @@ -183,16 +191,16 @@ func newMatchesLastRow(columns []columnSchema, startParamIdx int) (tree.Expr, er // Parameters are ordered by column ID. func newUpdateStatement( table catalog.TableDescriptor, -) (statements.Statement[tree.Statement], error) { +) (statements.Statement[tree.Statement], []*types.T, error) { columns := getColumnSchema(table) - // Create WHERE clause for matching the previous row values whereClause, err := newMatchesLastRow(columns, 1) if err != nil { - return statements.Statement[tree.Statement]{}, err + return statements.Statement[tree.Statement]{}, nil, err } exprs := make(tree.UpdateExprs, 0, len(columns)) + paramTypes := make([]*types.T, 0, 2*len(columns)) for i, col := range columns { if col.isComputed { // Skip computed columns since they are not needed to fully specify the @@ -208,7 +216,7 @@ func newUpdateStatement( // are for the where clause. 
placeholder, err := newTypedPlaceholder(len(columns)+i+1, col.column) if err != nil { - return statements.Statement[tree.Statement]{}, err + return statements.Statement[tree.Statement]{}, nil, err } exprs = append(exprs, &tree.UpdateExpr{ @@ -217,6 +225,15 @@ func newUpdateStatement( }) } + // Add parameter types for WHERE clause (previous values) + for _, col := range columns { + paramTypes = append(paramTypes, col.columnType) + } + // Add parameter types for SET clause (new values) + for _, col := range columns { + paramTypes = append(paramTypes, col.columnType) + } + // Create the final update statement update := &tree.Update{ Table: &tree.TableRef{ @@ -228,7 +245,12 @@ func newUpdateStatement( Returning: tree.AbsentReturningClause, } - return toParsedStatement(update) + stmt, err := toParsedStatement(update) + if err != nil { + return statements.Statement[tree.Statement]{}, nil, err + } + + return stmt, paramTypes, nil } // newDeleteStatement returns a statement that can be used to delete a row from @@ -239,13 +261,19 @@ func newUpdateStatement( // Parameters are ordered by column ID. func newDeleteStatement( table catalog.TableDescriptor, -) (statements.Statement[tree.Statement], error) { +) (statements.Statement[tree.Statement], []*types.T, error) { columns := getColumnSchema(table) // Create WHERE clause for matching the row to delete whereClause, err := newMatchesLastRow(columns, 1) if err != nil { - return statements.Statement[tree.Statement]{}, err + return statements.Statement[tree.Statement]{}, nil, err + } + + // Create parameter types for WHERE clause + paramTypes := make([]*types.T, 0, len(columns)) + for _, col := range columns { + paramTypes = append(paramTypes, col.columnType) } // Create the final delete statement @@ -261,7 +289,11 @@ func newDeleteStatement( Returning: &tree.ReturningExprs{tree.StarSelectExpr()}, } - return toParsedStatement(delete) + stmt, err := toParsedStatement(delete) + if err != nil { + return statements.Statement[tree.Statement]{}, nil, err + } + return stmt, paramTypes, nil } // newBulkSelectStatement returns a statement that can be used to query @@ -289,26 +321,32 @@ func newDeleteStatement( // AND replication_target.secondary_id = key_list.key2 func newBulkSelectStatement( table catalog.TableDescriptor, -) (statements.Statement[tree.Statement], error) { +) (statements.Statement[tree.Statement], []*types.T, error) { cols := getColumnSchema(table) - primaryKeyColumns := make([]catalog.Column, 0, len(cols)) + primaryKeyColumns := make([]columnSchema, 0, len(cols)) for _, col := range cols { if col.isPrimaryKey { - primaryKeyColumns = append(primaryKeyColumns, col.column) + primaryKeyColumns = append(primaryKeyColumns, col) } } + // Create parameter types for primary key arrays + paramTypes := make([]*types.T, 0, len(primaryKeyColumns)) + for _, pkCol := range primaryKeyColumns { + paramTypes = append(paramTypes, types.MakeArray(pkCol.columnType)) + } + // keyListName is the name of the CTE that contains the primary keys supplied // via array parameters. keyListName, err := tree.NewUnresolvedObjectName(1, [3]string{"key_list"}, tree.NoAnnotation) if err != nil { - return statements.Statement[tree.Statement]{}, err + return statements.Statement[tree.Statement]{}, nil, err } // targetName is used to name the user's table. 
targetName, err := tree.NewUnresolvedObjectName(1, [3]string{"replication_target"}, tree.NoAnnotation) if err != nil { - return statements.Statement[tree.Statement]{}, err + return statements.Statement[tree.Statement]{}, nil, err } // Create the `SELECT unnest($1::[]INT, $2::[]INT) WITH ORDINALITY AS key_list(key1, key2, index)` table expression. @@ -320,7 +358,7 @@ func newBulkSelectStatement( }) primaryKeyExprs = append(primaryKeyExprs, &tree.CastExpr{ Expr: &tree.Placeholder{Idx: tree.PlaceholderIdx(i)}, - Type: types.MakeArray(pkCol.GetType()), + Type: types.MakeArray(pkCol.columnType), SyntaxMode: tree.CastShort, }) } @@ -379,7 +417,7 @@ func newBulkSelectStatement( // Construct the JOIN clause for the final query. var joinCond tree.Expr for i, pkCol := range primaryKeyColumns { - colName := tree.Name(pkCol.GetName()) + colName := tree.Name(pkCol.column.GetName()) keyColName := fmt.Sprintf("key%d", i+1) eqExpr := &tree.ComparisonExpr{ @@ -430,7 +468,11 @@ func newBulkSelectStatement( }, } - return toParsedStatement(selectStmt) + stmt, err := toParsedStatement(selectStmt) + if err != nil { + return statements.Statement[tree.Statement]{}, nil, err + } + return stmt, paramTypes, nil } func toParsedStatement(stmt tree.Statement) (statements.Statement[tree.Statement], error) { diff --git a/pkg/crosscluster/logical/replication_statements_test.go b/pkg/crosscluster/logical/replication_statements_test.go index ae50f0951e39..44a2dcd7b262 100644 --- a/pkg/crosscluster/logical/replication_statements_test.go +++ b/pkg/crosscluster/logical/replication_statements_test.go @@ -10,7 +10,6 @@ import ( gosql "database/sql" "fmt" "math/rand" - "slices" "testing" "github.com/cockroachdb/cockroach/pkg/base" @@ -68,16 +67,7 @@ func TestReplicationStatements(t *testing.T) { asSql := tree.Serialize(&p) _, err := db.Exec(asSql) - require.NoError(t, err) - } - - getTypes := func(desc catalog.TableDescriptor) []*types.T { - columns := getColumnSchema(desc) - types := make([]*types.T, len(columns)) - for i, col := range columns { - types[i] = col.column.GetType() - } - return types + require.NoError(t, err, "statement: %s", asSql) } datadriven.Walk(t, datapathutils.TestDataPath(t), func(t *testing.T, path string) { @@ -95,10 +85,10 @@ func TestReplicationStatements(t *testing.T) { desc := getTableDesc(tableName) - insertStmt, err := newInsertStatement(desc) + insertStmt, types, err := newInsertStatement(desc) require.NoError(t, err) - prepareStatement(t, sqlDB, getTypes(desc), insertStmt) + prepareStatement(t, sqlDB, types, insertStmt) return insertStmt.SQL case "show-update": @@ -107,12 +97,11 @@ func TestReplicationStatements(t *testing.T) { desc := getTableDesc(tableName) - updateStmt, err := newUpdateStatement(desc) + updateStmt, types, err := newUpdateStatement(desc) require.NoError(t, err) // update expects previous and current values to be passed as // parameters. - types := slices.Concat(getTypes(desc), getTypes(desc)) prepareStatement(t, sqlDB, types, updateStmt) return updateStmt.SQL @@ -124,10 +113,9 @@ func TestReplicationStatements(t *testing.T) { // delete expects previous and current values to be passed as // parameters. 
- deleteStmt, err := newDeleteStatement(desc) + deleteStmt, types, err := newDeleteStatement(desc) require.NoError(t, err) - types := slices.Concat(getTypes(desc), getTypes(desc)) prepareStatement(t, sqlDB, types, deleteStmt) return deleteStmt.SQL @@ -137,18 +125,10 @@ desc := getTableDesc(tableName) - stmt, err := newBulkSelectStatement(desc) + stmt, types, err := newBulkSelectStatement(desc) require.NoError(t, err) - allColumns := getColumnSchema(desc) - var primaryKeyTypes []*types.T - for _, col := range allColumns { - if col.isPrimaryKey { - primaryKeyTypes = append(primaryKeyTypes, types.MakeArray(col.column.GetType())) - } - } - - prepareStatement(t, sqlDB, primaryKeyTypes, stmt) + prepareStatement(t, sqlDB, types, stmt) return stmt.SQL default: diff --git a/pkg/crosscluster/logical/sql_crud_writer.go b/pkg/crosscluster/logical/sql_crud_writer.go index 0c577c72c413..008bffa26966 100644 --- a/pkg/crosscluster/logical/sql_crud_writer.go +++ b/pkg/crosscluster/logical/sql_crud_writer.go @@ -44,13 +44,21 @@ func newCrudSqlWriter( discard jobspb.LogicalReplicationDetails_Discard, procConfigByDestID map[descpb.ID]sqlProcessorTableConfig, jobID jobspb.JobID, -) (BatchHandler, error) { +) (_ BatchHandler, err error) { decoder, err := newEventDecoder(ctx, cfg.DB, evalCtx.Settings, procConfigByDestID) if err != nil { return nil, err } handlers := make(map[descpb.ID]*tableHandler) + defer func() { + if err != nil { + for _, handler := range handlers { + handler.Close(ctx) + } + } + }() + for dstDescID := range procConfigByDestID { handler, err := newTableHandler( ctx, @@ -126,6 +134,9 @@ func eventsByTable(events []decodedEvent) func(yield func(descpb.ID, []decodedEv // Close implements BatchHandler. func (c *sqlCrudWriter) Close(ctx context.Context) { + for _, handler := range c.handlers { + handler.Close(ctx) + } } // GetLastRow implements BatchHandler. diff --git a/pkg/crosscluster/logical/sql_row_reader.go b/pkg/crosscluster/logical/sql_row_reader.go index cd4fff5860cb..c58a3492c874 100644 --- a/pkg/crosscluster/logical/sql_row_reader.go +++ b/pkg/crosscluster/logical/sql_row_reader.go @@ -10,16 +10,16 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog" "github.com/cockroachdb/cockroach/pkg/sql/isql" - "github.com/cockroachdb/cockroach/pkg/sql/parser/statements" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" - "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/errors" ) type sqlRowReader struct { - selectStatement statements.Statement[tree.Statement] - sessionOverride sessiondata.InternalExecutorOverride + session isql.Session + + selectStatement isql.PreparedStatement + + // keyColumnIndices are the indices of the datums that are part of the primary key. 
keyColumnIndices []int columns []columnSchema @@ -41,7 +41,7 @@ type priorRow struct { } func newSQLRowReader( - table catalog.TableDescriptor, sessionOverride sessiondata.InternalExecutorOverride, + ctx context.Context, table catalog.TableDescriptor, session isql.Session, ) (*sqlRowReader, error) { cols := getColumnSchema(table) keyColumns := make([]int, 0, len(cols)) @@ -51,14 +51,18 @@ func newSQLRowReader( } } - selectStatement, err := newBulkSelectStatement(table) + selectStatementRaw, types, err := newBulkSelectStatement(table) + if err != nil { + return nil, err + } + selectStatement, err := session.Prepare(ctx, "replication-read-refresh", selectStatementRaw, types) if err != nil { return nil, err } return &sqlRowReader{ + session: session, selectStatement: selectStatement, - sessionOverride: sessionOverride, keyColumnIndices: keyColumns, columns: cols, }, nil @@ -69,9 +73,7 @@ func newSQLRowReader( // the input is the key to the output map. // // E.g. result[i] and rows[i] are the same row. -func (r *sqlRowReader) ReadRows( - ctx context.Context, txn isql.Txn, rows []tree.Datums, -) (map[int]priorRow, error) { +func (r *sqlRowReader) ReadRows(ctx context.Context, rows []tree.Datums) (map[int]priorRow, error) { // TODO(jeffswenson): optimize allocations. It may require a change to the // API. For now, this probably isn't a performance bottleneck because: // 1. Many of the allocations are one per batch instead of one per row. @@ -82,7 +84,7 @@ func (r *sqlRowReader) ReadRows( return nil, nil } - params := make([]any, 0, len(r.keyColumnIndices)) + params := make([]tree.Datum, 0, len(r.keyColumnIndices)) for _, index := range r.keyColumnIndices { array := tree.NewDArray(r.columns[index].column.GetType()) for _, row := range rows { @@ -93,14 +95,10 @@ func (r *sqlRowReader) ReadRows( params = append(params, array) } - // Execute the query using QueryBufferedEx which returns all rows at once. + // Execute the query using QueryPrepared which returns all rows at once. // This is okay since we already know the batch is small enough to fit in // memory. 
- rows, err := txn.QueryBufferedEx(ctx, "replication-read-refresh", txn.KV(), - r.sessionOverride, - r.selectStatement.SQL, - params..., - ) + rows, err := r.session.QueryPrepared(ctx, r.selectStatement, params) if err != nil { return nil, err } diff --git a/pkg/crosscluster/logical/sql_row_reader_test.go b/pkg/crosscluster/logical/sql_row_reader_test.go index 33d6be691420..c5f336c5c968 100644 --- a/pkg/crosscluster/logical/sql_row_reader_test.go +++ b/pkg/crosscluster/logical/sql_row_reader_test.go @@ -17,7 +17,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/desctestutils" "github.com/cockroachdb/cockroach/pkg/sql/isql" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" - "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" @@ -53,12 +52,16 @@ func TestSQLRowReader(t *testing.T) { // Create sqlRowReader for source table srcDesc := desctestutils.TestingGetPublicTableDescriptor(s.DB(), s.Codec(), "a", "tab") - srcReader, err := newSQLRowReader(srcDesc, sessiondata.InternalExecutorOverride{}) + srcSession := newInternalSession(t, s) + defer srcSession.Close(ctx) + srcReader, err := newSQLRowReader(ctx, srcDesc, srcSession) require.NoError(t, err) // Create sqlRowReader for destination table dstDesc := desctestutils.TestingGetPublicTableDescriptor(s.DB(), s.Codec(), "b", "tab") - dstReader, err := newSQLRowReader(dstDesc, sessiondata.InternalExecutorOverride{}) + dstSession := newInternalSession(t, server.Server(0)) + defer dstSession.Close(ctx) + dstReader, err := newSQLRowReader(ctx, dstDesc, dstSession) require.NoError(t, err) // Create test rows to look up @@ -69,7 +72,7 @@ func TestSQLRowReader(t *testing.T) { readRows := func(t *testing.T, db isql.DB, rows []tree.Datums, reader *sqlRowReader) map[int]priorRow { var result map[int]priorRow require.NoError(t, db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { - result, err = reader.ReadRows(ctx, txn, rows) + result, err = reader.ReadRows(ctx, rows) require.NoError(t, err) return err })) diff --git a/pkg/crosscluster/logical/sql_row_writer.go b/pkg/crosscluster/logical/sql_row_writer.go index 85ee19638892..37a0210c22d4 100644 --- a/pkg/crosscluster/logical/sql_row_writer.go +++ b/pkg/crosscluster/logical/sql_row_writer.go @@ -10,9 +10,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog" "github.com/cockroachdb/cockroach/pkg/sql/isql" - "github.com/cockroachdb/cockroach/pkg/sql/parser/statements" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" - "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/sql/sessionmutator" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/errors" ) @@ -25,42 +24,40 @@ var errStalePreviousValue = errors.New("stale previous value") // sqlRowWriter is configured to write rows to a specific table and descriptor // version. 
type sqlRowWriter struct { - insert statements.Statement[tree.Statement] - update statements.Statement[tree.Statement] - delete statements.Statement[tree.Statement] - sessionOverride sessiondata.InternalExecutorOverride + session isql.Session - scratchDatums []any + insert isql.PreparedStatement + update isql.PreparedStatement + delete isql.PreparedStatement + + scratchDatums tree.Datums columns []string } -func (s *sqlRowWriter) getExecutorOverride( - originTimestamp hlc.Timestamp, -) sessiondata.InternalExecutorOverride { - session := s.sessionOverride - session.OriginTimestampForLogicalDataReplication = originTimestamp - session.OriginIDForLogicalDataReplication = 1 - return session +func (s *sqlRowWriter) setOriginTimestamp( + ctx context.Context, originTimestamp hlc.Timestamp, +) error { + return s.session.ModifySession(ctx, func(m sessionmutator.SessionDataMutator) { + m.Data.OriginTimestampForLogicalDataReplication = originTimestamp + }) } // DeleteRow deletes a row from the table. It returns errStalePreviousValue // if the oldRow argument does not match the value in the local database. func (s *sqlRowWriter) DeleteRow( - ctx context.Context, txn isql.Txn, originTimestamp hlc.Timestamp, oldRow tree.Datums, + ctx context.Context, originTimestamp hlc.Timestamp, oldRow tree.Datums, ) error { s.scratchDatums = s.scratchDatums[:0] + s.scratchDatums = append(s.scratchDatums, oldRow...) - for _, d := range oldRow { - s.scratchDatums = append(s.scratchDatums, d) + err := s.setOriginTimestamp(ctx, originTimestamp) + if err != nil { + return err } - rowsAffected, err := txn.ExecParsed(ctx, "replicated-delete", txn.KV(), - s.getExecutorOverride(originTimestamp), - s.delete, - s.scratchDatums..., - ) + rowsAffected, err := s.session.ExecutePrepared(ctx, s.delete, s.scratchDatums) if err != nil { - return err + return errors.Wrap(err, "deleting row") } if rowsAffected != 1 { return errStalePreviousValue @@ -71,20 +68,20 @@ func (s *sqlRowWriter) DeleteRow( // InsertRow inserts a row into the table. It will return an error if the row // already exists. func (s *sqlRowWriter) InsertRow( - ctx context.Context, txn isql.Txn, originTimestamp hlc.Timestamp, row tree.Datums, + ctx context.Context, originTimestamp hlc.Timestamp, row tree.Datums, ) error { s.scratchDatums = s.scratchDatums[:0] - for _, d := range row { - s.scratchDatums = append(s.scratchDatums, d) - } - rowsImpacted, err := txn.ExecParsed(ctx, "replicated-insert", txn.KV(), - s.getExecutorOverride(originTimestamp), - s.insert, - s.scratchDatums..., - ) + s.scratchDatums = append(s.scratchDatums, row...) + + err := s.setOriginTimestamp(ctx, originTimestamp) if err != nil { return err } + + rowsImpacted, err := s.session.ExecutePrepared(ctx, s.insert, s.scratchDatums) + if err != nil { + return errors.Wrap(err, "inserting row") + } if rowsImpacted != 1 { return errors.AssertionFailedf("expected 1 row impacted, got %d", rowsImpacted) } @@ -94,28 +91,20 @@ func (s *sqlRowWriter) InsertRow( // UpdateRow updates a row in the table. It returns errStalePreviousValue // if the oldRow argument does not match the value in the local database. func (s *sqlRowWriter) UpdateRow( - ctx context.Context, - txn isql.Txn, - originTimestamp hlc.Timestamp, - oldRow tree.Datums, - newRow tree.Datums, + ctx context.Context, originTimestamp hlc.Timestamp, oldRow tree.Datums, newRow tree.Datums, ) error { s.scratchDatums = s.scratchDatums[:0] + s.scratchDatums = append(s.scratchDatums, oldRow...) + s.scratchDatums = append(s.scratchDatums, newRow...) 
- for _, d := range oldRow { - s.scratchDatums = append(s.scratchDatums, d) - } - for _, d := range newRow { - s.scratchDatums = append(s.scratchDatums, d) + err := s.setOriginTimestamp(ctx, originTimestamp) + if err != nil { + return err } - rowsAffected, err := txn.ExecParsed(ctx, "replicated-update", txn.KV(), - s.getExecutorOverride(originTimestamp), - s.update, - s.scratchDatums..., - ) + rowsAffected, err := s.session.ExecutePrepared(ctx, s.update, s.scratchDatums) if err != nil { - return err + return errors.Wrap(err, "updating row") } if rowsAffected != 1 { return errStalePreviousValue @@ -124,7 +113,7 @@ func (s *sqlRowWriter) UpdateRow( } func newSQLRowWriter( - table catalog.TableDescriptor, sessionOverride sessiondata.InternalExecutorOverride, + ctx context.Context, table catalog.TableDescriptor, session isql.Session, ) (*sqlRowWriter, error) { columnsToDecode := getColumnSchema(table) columns := make([]string, len(columnsToDecode)) @@ -132,33 +121,38 @@ func newSQLRowWriter( columns[i] = col.column.GetName() } - // TODO(jeffswenson): figure out how to manage prepared statements and - // transactions in an internal executor. The original plan was to prepare - // statements on initialization then reuse them, but the internal executor - // is scoped to a single transaction and I couldn't figure out how to - // maintain prepared statements across different instances of the internal - // executor. - - insert, err := newInsertStatement(table) + insert, insertParamTypes, err := newInsertStatement(table) if err != nil { return nil, err } + preparedInsert, err := session.Prepare(ctx, "insert", insert, insertParamTypes) + if err != nil { + return nil, errors.Wrap(err, "unable to prepare insert statement") + } - update, err := newUpdateStatement(table) + update, updateParamTypes, err := newUpdateStatement(table) if err != nil { return nil, err } + preparedUpdate, err := session.Prepare(ctx, "update", update, updateParamTypes) + if err != nil { + return nil, errors.Wrap(err, "unable to prepare update statement") + } - delete, err := newDeleteStatement(table) + delete, deleteParamTypes, err := newDeleteStatement(table) if err != nil { return nil, err } + preparedDelete, err := session.Prepare(ctx, "delete", delete, deleteParamTypes) + if err != nil { + return nil, errors.Wrap(err, "unable to prepare delete statement") + } return &sqlRowWriter{ - insert: insert, - update: update, - delete: delete, - sessionOverride: sessionOverride, - columns: columns, + session: session, + insert: preparedInsert, + update: preparedUpdate, + delete: preparedDelete, + columns: columns, }, nil } diff --git a/pkg/crosscluster/logical/sql_row_writer_test.go b/pkg/crosscluster/logical/sql_row_writer_test.go index 3e89d3db032f..2bd0efcae1e8 100644 --- a/pkg/crosscluster/logical/sql_row_writer_test.go +++ b/pkg/crosscluster/logical/sql_row_writer_test.go @@ -11,12 +11,14 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/cdctest" + "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/catalog" "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" - "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/leaktest" 
"github.com/stretchr/testify/require" ) @@ -30,6 +32,18 @@ func makeTestRow(t *testing.T, _ catalog.TableDescriptor, id int64, name string) } } +func newInternalSession(t *testing.T, s serverutils.ApplicationLayerInterface) isql.Session { + sd := tableHandlerSessionSettings(sql.NewInternalSessionData(context.Background(), s.ClusterSettings(), "")) + session, err := s.InternalDB().(isql.DB).Session(context.Background(), "test_session", isql.WithSessionData(sd)) + require.NoError(t, err) + return session +} + +// hlcToString converts an HLC timestamp to the format used by crdb_internal_origin_timestamp +func hlcToString(ts hlc.Timestamp) string { + return eval.TimestampToDecimalDatum(ts).String() +} + func TestSQLRowWriter(t *testing.T) { defer leaktest.AfterTest(t)() @@ -38,7 +52,6 @@ func TestSQLRowWriter(t *testing.T) { defer s.Stopper().Stop(ctx) sqlDB := sqlutils.MakeSQLRunner(db) - internalDB := s.InternalDB().(isql.DB) // Create a test table sqlDB.Exec(t, ` @@ -49,47 +62,58 @@ func TestSQLRowWriter(t *testing.T) { ) `) + session := newInternalSession(t, s) + defer session.Close(ctx) + // Create a row writer desc := cdctest.GetHydratedTableDescriptor(t, s.ApplicationLayer().ExecutorConfig(), "test_table") - writer, err := newSQLRowWriter(desc, sessiondata.InternalExecutorOverride{}) + writer, err := newSQLRowWriter(ctx, desc, session) require.NoError(t, err) // Test InsertRow insertRow := makeTestRow(t, desc, 1, "test") - require.NoError(t, internalDB.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { - return writer.InsertRow(ctx, txn, s.Clock().Now(), insertRow) - })) + insertTimestamp := s.Clock().Now() + require.NoError(t, writer.InsertRow(ctx, insertTimestamp, insertRow)) require.Equal(t, - [][]string{{"1", "test", "NULL"}}, - sqlDB.QueryStr(t, "SELECT id, name, is_always_null FROM test_table WHERE id = 1")) + [][]string{{"1", "test", "NULL", hlcToString(insertTimestamp), "1"}}, + sqlDB.QueryStr(t, "SELECT id, name, is_always_null, crdb_internal_origin_timestamp, crdb_internal_origin_id FROM test_table WHERE id = 1")) // Test UpdateRow updateRow := makeTestRow(t, desc, 1, "updated") - require.NoError(t, internalDB.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { - return writer.UpdateRow(ctx, txn, s.Clock().Now(), insertRow, updateRow) - })) + updateTimestamp := s.Clock().Now() + require.NoError(t, writer.UpdateRow(ctx, updateTimestamp, insertRow, updateRow)) require.Equal(t, - [][]string{{"1", "updated", "NULL"}}, - sqlDB.QueryStr(t, "SELECT id, name, is_always_null FROM test_table WHERE id = 1")) + [][]string{{"1", "updated", "NULL", hlcToString(updateTimestamp), "1"}}, + sqlDB.QueryStr(t, "SELECT id, name, is_always_null, crdb_internal_origin_timestamp, crdb_internal_origin_id FROM test_table WHERE id = 1")) // Test UpdateRow with stale previous value staleRow := makeTestRow(t, desc, 1, "test") // Using old value - err = internalDB.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { - return writer.UpdateRow(ctx, txn, s.Clock().Now(), staleRow, updateRow) - }) + err = writer.UpdateRow(ctx, s.Clock().Now(), staleRow, updateRow) require.ErrorIs(t, err, errStalePreviousValue) - // Test DeleteRow - require.NoError(t, internalDB.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { - return writer.DeleteRow(ctx, txn, s.Clock().Now(), updateRow) - })) + // Test DeleteRow - first insert a test row to verify origin data before delete + insertRow2 := makeTestRow(t, desc, 2, "to_delete") + insertTimestamp2 := s.Clock().Now() + require.NoError(t, 
writer.InsertRow(ctx, insertTimestamp2, insertRow2)) + + // Verify the row exists with correct data and origin metadata + require.Equal(t, + [][]string{{"2", "to_delete", "NULL", hlcToString(insertTimestamp2), "1"}}, + sqlDB.QueryStr(t, "SELECT id, name, is_always_null, crdb_internal_origin_timestamp, crdb_internal_origin_id FROM test_table WHERE id = 2")) + + // Now delete the row + require.NoError(t, writer.DeleteRow(ctx, s.Clock().Now(), insertRow2)) + require.Equal(t, + [][]string{}, + sqlDB.QueryStr(t, "SELECT id, name, is_always_null, crdb_internal_origin_timestamp, crdb_internal_origin_id FROM test_table WHERE id = 2")) + + // Also delete the original row to clean up + require.NoError(t, writer.DeleteRow(ctx, s.Clock().Now(), updateRow)) require.Equal(t, [][]string{}, - sqlDB.QueryStr(t, "SELECT id, name, is_always_null FROM test_table WHERE id = 1")) + sqlDB.QueryStr(t, "SELECT id, name, is_always_null, crdb_internal_origin_timestamp, crdb_internal_origin_id FROM test_table WHERE id = 1")) // Test DeleteRow with stale value - err = internalDB.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { - return writer.DeleteRow(ctx, txn, s.Clock().Now(), staleRow) - }) + err = writer.DeleteRow(ctx, s.Clock().Now(), staleRow) require.ErrorIs(t, err, errStalePreviousValue) } diff --git a/pkg/crosscluster/logical/table_batch_handler.go b/pkg/crosscluster/logical/table_batch_handler.go index 4374e2c4ebcf..1ef212832b4d 100644 --- a/pkg/crosscluster/logical/table_batch_handler.go +++ b/pkg/crosscluster/logical/table_batch_handler.go @@ -19,6 +19,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/isql" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" "github.com/cockroachdb/errors" ) @@ -27,6 +28,7 @@ import ( type tableHandler struct { sqlReader *sqlRowReader sqlWriter *sqlRowWriter + session isql.Session db descs.DB tombstoneUpdater *tombstoneUpdater } @@ -74,6 +76,17 @@ func (t *tableBatchStats) AddTo(bs *batchStats) { } } +func tableHandlerSessionSettings(sd *sessiondata.SessionData) *sessiondata.SessionData { + sd = sd.Clone() + sd.PlanCacheMode = sessiondatapb.PlanCacheModeForceGeneric + sd.VectorizeMode = sessiondatapb.VectorizeOff + // TODO(jeffswenson): enable swap mutation once update swap merges + //sd.UseSwapMutations = true + sd.BufferedWritesEnabled = false + sd.OriginIDForLogicalDataReplication = 1 + return sd +} + // newTableHandler creates a new tableHandler for the given table descriptor ID. // It internally constructs the sqlReader and sqlWriter components. func newTableHandler( @@ -85,13 +98,23 @@ func newTableHandler( jobID jobspb.JobID, leaseMgr *lease.Manager, settings *cluster.Settings, -) (*tableHandler, error) { +) (_ *tableHandler, err error) { var table catalog.TableDescriptor + sd = tableHandlerSessionSettings(sd) + session, err := db.Session(ctx, "logical-data-replication", isql.WithSessionData(sd)) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + session.Close(ctx) + } + }() // NOTE: we don't hold a lease on the table descriptor, but validation // prevents users from changing the set of columns or the primary key of an // LDR replicated table. 
- err := db.DescsTxn(ctx, func(ctx context.Context, txn descs.Txn) error { + err = db.DescsTxn(ctx, func(ctx context.Context, txn descs.Txn) error { var err error table, err = txn.Descriptors().GetLeasedImmutableTableByID(ctx, txn.KV(), tableID) return err @@ -105,12 +128,12 @@ sessionOverride := ieOverrideBase sessionOverride.ApplicationName = fmt.Sprintf("%s-logical-replication-%d", sd.ApplicationName, jobID) - reader, err := newSQLRowReader(table, sessionOverride) + reader, err := newSQLRowReader(ctx, table, session) if err != nil { return nil, err } - writer, err := newSQLRowWriter(table, sessionOverride) + writer, err := newSQLRowWriter(ctx, table, session) if err != nil { return nil, err } @@ -122,9 +145,14 @@ sqlWriter: writer, db: db, tombstoneUpdater: tombstoneUpdater, + session: session, }, nil } +func (t *tableHandler) Close(ctx context.Context) { + t.session.Close(ctx) +} + func (t *tableHandler) handleDecodedBatch( ctx context.Context, batch []decodedEvent, ) (tableBatchStats, error) { @@ -152,26 +180,25 @@ func (t *tableHandler) attemptBatch( ctx context.Context, batch []decodedEvent, ) (tableBatchStats, error) { var stats tableBatchStats - err := t.db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + + var hasTombstoneUpdates bool + session := t.sqlWriter.session + err := session.Txn(ctx, func(ctx context.Context) error { for _, event := range batch { switch { case event.isDelete && len(event.prevRow) != 0: stats.deletes++ - err := t.sqlWriter.DeleteRow(ctx, txn, event.originTimestamp, event.prevRow) + err := t.sqlWriter.DeleteRow(ctx, event.originTimestamp, event.prevRow) if err != nil { return err } case event.isDelete && len(event.prevRow) == 0: - stats.tombstoneUpdates++ - tombstoneUpdateStats, err := t.tombstoneUpdater.updateTombstone(ctx, txn, event.originTimestamp, event.row) - if err != nil { - return err - } - stats.kvLwwLosers += tombstoneUpdateStats.kvWriteTooOld + hasTombstoneUpdates = true + // Skip: handled in its own transaction. case event.prevRow == nil: stats.inserts++ - err := withSavepoint(ctx, txn.KV(), func() error { - return t.sqlWriter.InsertRow(ctx, txn, event.originTimestamp, event.row) + err := session.Savepoint(ctx, func(ctx context.Context) error { + return t.sqlWriter.InsertRow(ctx, event.originTimestamp, event.row) }) if isLwwLoser(err) { // Insert may observe a LWW failure if it attempts to write over a tombstone. @@ -183,7 +210,7 @@ func (t *tableHandler) attemptBatch( } case event.prevRow != nil: stats.updates++ - err := t.sqlWriter.UpdateRow(ctx, txn, event.originTimestamp, event.prevRow, event.row) + err := t.sqlWriter.UpdateRow(ctx, event.originTimestamp, event.prevRow, event.row) if err != nil { return err } @@ -196,6 +223,30 @@ if err != nil { return tableBatchStats{}, err } + + if hasTombstoneUpdates { + // TODO(jeffswenson): once we have a way to expose the transaction used by + // the Session, we should bundle this with the other txn. The purpose of + // these transactions is that batching writes in a transaction increases + // efficiency. The transactions are not needed for correctness. 
+ err = t.db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + for _, event := range batch { + if event.isDelete && len(event.prevRow) == 0 { + stats.tombstoneUpdates++ + tombstoneUpdateStats, err := t.tombstoneUpdater.updateTombstone(ctx, txn, event.originTimestamp, event.row) + if err != nil { + return err + } + stats.kvLwwLosers += tombstoneUpdateStats.kvWriteTooOld + } + } + return nil + }) + if err != nil { + return tableBatchStats{}, err + } + } + return stats, nil } @@ -213,14 +264,7 @@ func (t *tableHandler) refreshPrevRows( rows = append(rows, event.row) } - var refreshedRows map[int]priorRow - err := t.db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { - var err error - // TODO(jeffswenson): should we apply the batch in the same transaction - // that we perform the read refresh? We could maybe even use locking reads. - refreshedRows, err = t.sqlReader.ReadRows(ctx, txn, rows) - return err - }) + refreshedRows, err := t.sqlReader.ReadRows(ctx, rows) if err != nil { return nil, tableBatchStats{}, err } diff --git a/pkg/crosscluster/logical/table_batch_handler_test.go b/pkg/crosscluster/logical/table_batch_handler_test.go index 4d5171840481..73f354c1d802 100644 --- a/pkg/crosscluster/logical/table_batch_handler_test.go +++ b/pkg/crosscluster/logical/table_batch_handler_test.go @@ -16,8 +16,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/catalog" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" - "github.com/cockroachdb/cockroach/pkg/sql/catalog/descs" - "github.com/cockroachdb/cockroach/pkg/sql/execinfra" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" @@ -34,14 +32,12 @@ func newCrudBatchHandler( ctx := context.Background() desc := cdctest.GetHydratedTableDescriptor(t, s.ExecutorConfig(), tree.Name(tableName)) sd := sql.NewInternalSessionData(ctx, s.ClusterSettings(), "" /* opName */) + + executorConfig := s.ExecutorConfig().(sql.ExecutorConfig) + handler, err := newCrudSqlWriter( ctx, - &execinfra.ServerConfig{ - DB: s.InternalDB().(descs.DB), - Codec: s.Codec(), - LeaseManager: s.LeaseManager(), - Settings: s.ClusterSettings(), - }, + &executorConfig.DistSQLSrv.ServerConfig, &eval.Context{ Codec: s.Codec(), Settings: s.ClusterSettings(), @@ -79,6 +75,7 @@ func TestBatchHandlerFastPath(t *testing.T) { `) handler, desc := newCrudBatchHandler(t, s, "test_table") + defer handler.Close(ctx) defer handler.ReleaseLeases(ctx) eb := newKvEventBuilder(t, desc.TableDesc()) @@ -140,6 +137,7 @@ func TestBatchHandlerSlowPath(t *testing.T) { `) handler, desc := newCrudBatchHandler(t, s, "test_table") + defer handler.Close(ctx) defer handler.ReleaseLeases(ctx) eb := newKvEventBuilder(t, desc.TableDesc()) @@ -200,11 +198,12 @@ func TestBatchHandlerDuplicateBatchEntries(t *testing.T) { runner.Exec(t, ` CREATE TABLE test_table ( id INT PRIMARY KEY, - value STRING + value VARCHAR(100) ) `) handler, desc := newCrudBatchHandler(t, s, "test_table") + defer handler.Close(ctx) defer handler.ReleaseLeases(ctx) eb := newKvEventBuilder(t, desc.TableDesc()) diff --git a/pkg/crosscluster/logical/testdata/ldr_statements b/pkg/crosscluster/logical/testdata/ldr_statements index a302c6c6e9c8..86f8fa804ef3 100644 --- a/pkg/crosscluster/logical/testdata/ldr_statements +++ b/pkg/crosscluster/logical/testdata/ldr_statements @@ -92,7 +92,7 @@ DELETE FROM [108 AS replication_target] WHERE ((((id = $1::INT8) AND (name IS NO # 
discount_price is included because it's part of the primary key. show-select table=products ---- -SELECT key_list.index, replication_target.crdb_internal_origin_timestamp, replication_target.crdb_internal_mvcc_timestamp, replication_target.id, replication_target.name, replication_target.unit_price, replication_target.quantity, replication_target.total_value, replication_target.last_updated FROM ROWS FROM (unnest($1::INT8[], $2::DECIMAL(10,2)[])) WITH ORDINALITY AS key_list (key1, key2, index) INNER LOOKUP JOIN [108 AS replication_target] ON (replication_target.id = key_list.key1) AND (replication_target.total_value = key_list.key2) +SELECT key_list.index, replication_target.crdb_internal_origin_timestamp, replication_target.crdb_internal_mvcc_timestamp, replication_target.id, replication_target.name, replication_target.unit_price, replication_target.quantity, replication_target.total_value, replication_target.last_updated FROM ROWS FROM (unnest($1::INT8[], $2::DECIMAL[])) WITH ORDINALITY AS key_list (key1, key2, index) INNER LOOKUP JOIN [108 AS replication_target] ON (replication_target.id = key_list.key1) AND (replication_target.total_value = key_list.key2) # Test a table with an expression, inverted index, and partial index. The # indexes are not expected to impact the INSERT/UPDATE/DELETE statements. diff --git a/pkg/sql/sqlclustersettings/clustersettings.go b/pkg/sql/sqlclustersettings/clustersettings.go index 93076406e2a8..1ec5db45518f 100644 --- a/pkg/sql/sqlclustersettings/clustersettings.go +++ b/pkg/sql/sqlclustersettings/clustersettings.go @@ -144,7 +144,8 @@ var LDRImmediateModeWriter = settings.RegisterStringSetting( settings.ApplicationLevel, "logical_replication.consumer.immediate_mode_writer", "the writer to use when in immediate mode", - metamorphic.ConstantWithTestChoice("logical_replication.consumer.immediate_mode_writer", string(LDRWriterTypeLegacyKV), string(LDRWriterTypeSQL)), + // TODO(jeffswenson): make the crud writer the default + metamorphic.ConstantWithTestChoice("logical_replication.consumer.immediate_mode_writer", string(LDRWriterTypeLegacyKV), string(LDRWriterTypeCRUD), string(LDRWriterTypeSQL)), settings.WithValidateString(func(sv *settings.Values, val string) error { if val != string(LDRWriterTypeSQL) && val != string(LDRWriterTypeLegacyKV) && val != string(LDRWriterTypeCRUD) { return errors.Newf("immediate mode writer must be either 'sql', 'legacy-kv', or 'crud', got '%s'", val)
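
The core change above is replacing per-statement internal-executor overrides with a long-lived `isql.Session` holding prepared statements. A condensed sketch of that write path, using only APIs that appear in this diff (`newInsertStatement`, `Session.Prepare`, `Session.ModifySession`, `Session.ExecutePrepared`); `writeRowSketch` is an illustrative helper, not a function from this change, and in the real writer the prepared handle is built once in `newSQLRowWriter` and reused:

```go
package logical

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/sql/catalog"
	"github.com/cockroachdb/cockroach/pkg/sql/isql"
	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
	"github.com/cockroachdb/cockroach/pkg/sql/sessionmutator"
	"github.com/cockroachdb/cockroach/pkg/util/hlc"
	"github.com/cockroachdb/errors"
)

// writeRowSketch shows the prepare-once, execute-many flow this diff
// introduces. Sketch only; error wrapping and caching follow sqlRowWriter.
func writeRowSketch(
	ctx context.Context,
	session isql.Session,
	table catalog.TableDescriptor,
	originTS hlc.Timestamp,
	row tree.Datums,
) error {
	// The builder now returns the placeholder types alongside the parsed
	// statement, ordered by column ID, so the session can prepare it up front.
	stmt, paramTypes, err := newInsertStatement(table)
	if err != nil {
		return err
	}
	prepared, err := session.Prepare(ctx, "insert-sketch", stmt, paramTypes)
	if err != nil {
		return errors.Wrap(err, "unable to prepare insert statement")
	}
	// The origin timestamp rides on the session rather than on a
	// per-statement executor override.
	err = session.ModifySession(ctx, func(m sessionmutator.SessionDataMutator) {
		m.Data.OriginTimestampForLogicalDataReplication = originTS
	})
	if err != nil {
		return err
	}
	rowsAffected, err := session.ExecutePrepared(ctx, prepared, row)
	if err != nil {
		return errors.Wrap(err, "inserting row")
	}
	if rowsAffected != 1 {
		return errors.AssertionFailedf("expected 1 row impacted, got %d", rowsAffected)
	}
	return nil
}
```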
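The read-refresh side batches lookups the same way. A sketch mirroring `sqlRowReader.ReadRows`, assuming the prepared statement is the unnest/LOOKUP JOIN query from `newBulkSelectStatement` (whose `WITH ORDINALITY` index maps results back to batch positions) and that `QueryPrepared` returns the buffered result rows; `readBatchSketch` is illustrative, not from the PR:

```go
// readBatchSketch builds one array parameter per primary-key column, where
// element i of every array carries row i's key, then runs the prepared bulk
// select in a single round trip.
func readBatchSketch(
	ctx context.Context,
	session isql.Session,
	prepared isql.PreparedStatement,
	cols []columnSchema,
	keyColumnIndices []int,
	rows []tree.Datums,
) ([]tree.Datums, error) {
	params := make([]tree.Datum, 0, len(keyColumnIndices))
	for _, idx := range keyColumnIndices {
		array := tree.NewDArray(cols[idx].column.GetType())
		for _, row := range rows {
			if err := array.Append(row[idx]); err != nil {
				return nil, err
			}
		}
		params = append(params, array)
	}
	// A buffered read is fine here: the caller already sized the batch to
	// fit in memory.
	return session.QueryPrepared(ctx, prepared, params)
}
```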
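The reworked `tableHandler.attemptBatch` splits the apply into a session transaction plus a follow-up transaction for tombstone updates. A condensed sketch of that control flow, using the types and helpers visible in the diff (`decodedEvent`, `isLwwLoser`, `Session.Txn`, `Session.Savepoint`); the tombstone pass is elided because it still needs an `isql.Txn`:

```go
// applyBatchSketch condenses attemptBatch: one session transaction for
// ordinary writes, with a savepoint around each insert so an LWW loss
// against a tombstone rolls back only that statement, not the whole batch.
func applyBatchSketch(
	ctx context.Context, session isql.Session, writer *sqlRowWriter, batch []decodedEvent,
) error {
	return session.Txn(ctx, func(ctx context.Context) error {
		for _, event := range batch {
			switch {
			case event.isDelete && len(event.prevRow) != 0:
				if err := writer.DeleteRow(ctx, event.originTimestamp, event.prevRow); err != nil {
					return err
				}
			case event.isDelete && len(event.prevRow) == 0:
				// Tombstone update: skipped here and handled afterwards in
				// its own transaction, as in the real attemptBatch.
			case event.prevRow == nil:
				err := session.Savepoint(ctx, func(ctx context.Context) error {
					return writer.InsertRow(ctx, event.originTimestamp, event.row)
				})
				if err != nil && !isLwwLoser(err) {
					return err
				}
			default:
				if err := writer.UpdateRow(ctx, event.originTimestamp, event.prevRow, event.row); err != nil {
					return err
				}
			}
		}
		return nil
	})
}
```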
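Because sessions are now owned resources, both `newCrudSqlWriter` and `newTableHandler` switch to a named error return so partially constructed state is released on failure. The idiom in isolation, with `openHandler` as a hypothetical stand-in for the per-table constructor:

```go
// newHandlersSketch shows the named-return cleanup idiom: the deferred
// closure observes the named err, so any error return below closes the
// sessions opened so far instead of leaking them.
func newHandlersSketch(
	ctx context.Context, ids []descpb.ID,
) (_ map[descpb.ID]*tableHandler, err error) {
	handlers := make(map[descpb.ID]*tableHandler)
	defer func() {
		if err != nil {
			for _, handler := range handlers {
				handler.Close(ctx)
			}
		}
	}()
	for _, id := range ids {
		handler, err := openHandler(ctx, id) // hypothetical constructor
		if err != nil {
			return nil, err
		}
		handlers[id] = handler
	}
	return handlers, nil
}
```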