Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/crosscluster/logical/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ go_library(
"//pkg/sql/sem/tree/treecmp",
"//pkg/sql/sessiondata",
"//pkg/sql/sessiondatapb",
"//pkg/sql/sessionmutator",
"//pkg/sql/sqlclustersettings",
"//pkg/sql/stats",
"//pkg/sql/syntheticprivilege",
Expand Down
1 change: 1 addition & 0 deletions pkg/crosscluster/logical/batch_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ func testBatchHandlerExhaustive(t *testing.T, factory batchHandlerFactory) {
}

handler, desc := factory(t, s, "test_table")
defer handler.Close(ctx)
defer handler.ReleaseLeases(ctx)

// TODO(jeffswenson): test the other handler types.
Expand Down
3 changes: 3 additions & 0 deletions pkg/crosscluster/logical/logical_replication_job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,9 @@ func TestRandomTables(t *testing.T) {
tc, s, runnerA, runnerB := setupLogicalTestServer(t, ctx, testClusterBaseClusterArgs, 1)
defer tc.Stopper().Stop(ctx)

// TODO(#148303): Remove this once the crud writer supports tables with array primary keys.
runnerA.Exec(t, "SET CLUSTER SETTING logical_replication.consumer.immediate_mode_writer = 'legacy-kv'")

sqlA := s.SQLConn(t, serverutils.DBName("a"))

var tableName, streamStartStmt string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ type logicalReplicationWriterProcessor struct {

purgatory purgatory

seenKeys map[uint64]int64
dupeCount int64
seenEvery log.EveryN
seenKeys map[uint64]int64
dupeCount int64
seenEvery log.EveryN
retryEvery log.EveryN
}

var (
Expand Down Expand Up @@ -224,9 +225,10 @@ func newLogicalReplicationWriterProcessor(
StreamID: streampb.StreamID(spec.StreamID),
ProcessorID: processorID,
},
dlqClient: InitDeadLetterQueueClient(dlqDbExec, destTableBySrcID),
metrics: flowCtx.Cfg.JobRegistry.MetricsStruct().JobSpecificMetrics[jobspb.TypeLogicalReplication].(*Metrics),
seenEvery: log.Every(1 * time.Minute),
dlqClient: InitDeadLetterQueueClient(dlqDbExec, destTableBySrcID),
metrics: flowCtx.Cfg.JobRegistry.MetricsStruct().JobSpecificMetrics[jobspb.TypeLogicalReplication].(*Metrics),
seenEvery: log.Every(1 * time.Minute),
retryEvery: log.Every(1 * time.Minute),
}
lrw.purgatory = purgatory{
deadline: func() time.Duration { return retryQueueAgeLimit.Get(&flowCtx.Cfg.Settings.SV) },
Expand Down Expand Up @@ -1039,6 +1041,9 @@ func (lrw *logicalReplicationWriterProcessor) flushChunk(
}
stats.processed.dlq++
} else {
if lrw.retryEvery.ShouldLog() {
log.Dev.Warningf(ctx, "retrying failed apply: %+v", err)
}
stats.notProcessed.count++
stats.notProcessed.bytes += int64(batch[i].Size())
}
Expand Down
82 changes: 62 additions & 20 deletions pkg/crosscluster/logical/replication_statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

type columnSchema struct {
column catalog.Column
columnType *types.T
isPrimaryKey bool
isComputed bool
}
Expand Down Expand Up @@ -58,6 +59,7 @@ func getColumnSchema(table catalog.TableDescriptor) []columnSchema {

result = append(result, columnSchema{
column: col,
columnType: col.GetType().Canonical(),
isPrimaryKey: isPrimaryKey[col.GetID()],
isComputed: isComputed,
})
Expand Down Expand Up @@ -86,12 +88,14 @@ func newTypedPlaceholder(idx int, col catalog.Column) (*tree.CastExpr, error) {
// in the table. Parameters are ordered by column ID.
func newInsertStatement(
table catalog.TableDescriptor,
) (statements.Statement[tree.Statement], error) {
) (statements.Statement[tree.Statement], []*types.T, error) {
columns := getColumnSchema(table)

columnNames := make(tree.NameList, 0, len(columns))
parameters := make(tree.Exprs, 0, len(columns))
paramTypes := make([]*types.T, 0, len(columns))
for i, col := range columns {
paramTypes = append(paramTypes, col.columnType)

// NOTE: this consumes a placeholder ID because it's part of the tree.Datums,
// but it doesn't show up in the query because computed columns are not
// needed for insert statements.
Expand All @@ -102,7 +106,7 @@ func newInsertStatement(
var err error
parameter, err := newTypedPlaceholder(i+1, col.column)
if err != nil {
return statements.Statement[tree.Statement]{}, err
return statements.Statement[tree.Statement]{}, nil, err
}

columnNames = append(columnNames, tree.Name(col.column.GetName()))
Expand All @@ -129,7 +133,11 @@ func newInsertStatement(
Returning: tree.AbsentReturningClause,
}

return toParsedStatement(insert)
stmt, err := toParsedStatement(insert)
if err != nil {
return statements.Statement[tree.Statement]{}, nil, err
}
return stmt, paramTypes, nil
}

// newMatchesLastRow creates a WHERE clause for matching all columns of a row.
Expand Down Expand Up @@ -183,16 +191,16 @@ func newMatchesLastRow(columns []columnSchema, startParamIdx int) (tree.Expr, er
// Parameters are ordered by column ID.
func newUpdateStatement(
table catalog.TableDescriptor,
) (statements.Statement[tree.Statement], error) {
) (statements.Statement[tree.Statement], []*types.T, error) {
columns := getColumnSchema(table)

// Create WHERE clause for matching the previous row values
whereClause, err := newMatchesLastRow(columns, 1)
if err != nil {
return statements.Statement[tree.Statement]{}, err
return statements.Statement[tree.Statement]{}, nil, err
}

exprs := make(tree.UpdateExprs, 0, len(columns))
paramTypes := make([]*types.T, 0, 2*len(columns))
for i, col := range columns {
if col.isComputed {
// Skip computed columns since they are not needed to fully specify the
Expand All @@ -208,7 +216,7 @@ func newUpdateStatement(
// are for the where clause.
placeholder, err := newTypedPlaceholder(len(columns)+i+1, col.column)
if err != nil {
return statements.Statement[tree.Statement]{}, err
return statements.Statement[tree.Statement]{}, nil, err
}

exprs = append(exprs, &tree.UpdateExpr{
Expand All @@ -217,6 +225,15 @@ func newUpdateStatement(
})
}

// Add parameter types for WHERE clause (previous values)
for _, col := range columns {
paramTypes = append(paramTypes, col.columnType)
}
// Add parameter types for SET clause (new values)
for _, col := range columns {
paramTypes = append(paramTypes, col.columnType)
}

// Create the final update statement
update := &tree.Update{
Table: &tree.TableRef{
Expand All @@ -228,7 +245,12 @@ func newUpdateStatement(
Returning: tree.AbsentReturningClause,
}

return toParsedStatement(update)
stmt, err := toParsedStatement(update)
if err != nil {
return statements.Statement[tree.Statement]{}, nil, err
}

return stmt, paramTypes, nil
}

// newDeleteStatement returns a statement that can be used to delete a row from
Expand All @@ -239,13 +261,19 @@ func newUpdateStatement(
// Parameters are ordered by column ID.
func newDeleteStatement(
table catalog.TableDescriptor,
) (statements.Statement[tree.Statement], error) {
) (statements.Statement[tree.Statement], []*types.T, error) {
columns := getColumnSchema(table)

// Create WHERE clause for matching the row to delete
whereClause, err := newMatchesLastRow(columns, 1)
if err != nil {
return statements.Statement[tree.Statement]{}, err
return statements.Statement[tree.Statement]{}, nil, err
}

// Create parameter types for WHERE clause
paramTypes := make([]*types.T, 0, len(columns))
for _, col := range columns {
paramTypes = append(paramTypes, col.columnType)
}

// Create the final delete statement
Expand All @@ -261,7 +289,11 @@ func newDeleteStatement(
Returning: &tree.ReturningExprs{tree.StarSelectExpr()},
}

return toParsedStatement(delete)
stmt, err := toParsedStatement(delete)
if err != nil {
return statements.Statement[tree.Statement]{}, nil, err
}
return stmt, paramTypes, nil
}

// newBulkSelectStatement returns a statement that can be used to query
Expand Down Expand Up @@ -289,26 +321,32 @@ func newDeleteStatement(
// AND replication_target.secondary_id = key_list.key2
func newBulkSelectStatement(
table catalog.TableDescriptor,
) (statements.Statement[tree.Statement], error) {
) (statements.Statement[tree.Statement], []*types.T, error) {
cols := getColumnSchema(table)
primaryKeyColumns := make([]catalog.Column, 0, len(cols))
primaryKeyColumns := make([]columnSchema, 0, len(cols))
for _, col := range cols {
if col.isPrimaryKey {
primaryKeyColumns = append(primaryKeyColumns, col.column)
primaryKeyColumns = append(primaryKeyColumns, col)
}
}

// Create parameter types for primary key arrays
paramTypes := make([]*types.T, 0, len(primaryKeyColumns))
for _, pkCol := range primaryKeyColumns {
paramTypes = append(paramTypes, types.MakeArray(pkCol.columnType))
}

// keyListName is the name of the CTE that contains the primary keys supplied
// via array parameters.
keyListName, err := tree.NewUnresolvedObjectName(1, [3]string{"key_list"}, tree.NoAnnotation)
if err != nil {
return statements.Statement[tree.Statement]{}, err
return statements.Statement[tree.Statement]{}, nil, err
}

// targetName is used to name the user's table.
targetName, err := tree.NewUnresolvedObjectName(1, [3]string{"replication_target"}, tree.NoAnnotation)
if err != nil {
return statements.Statement[tree.Statement]{}, err
return statements.Statement[tree.Statement]{}, nil, err
}

// Create the `SELECT unnest($1::[]INT, $2::[]INT) WITH ORDINALITY AS key_list(key1, key2, index)` table expression.
Expand All @@ -320,7 +358,7 @@ func newBulkSelectStatement(
})
primaryKeyExprs = append(primaryKeyExprs, &tree.CastExpr{
Expr: &tree.Placeholder{Idx: tree.PlaceholderIdx(i)},
Type: types.MakeArray(pkCol.GetType()),
Type: types.MakeArray(pkCol.columnType),
SyntaxMode: tree.CastShort,
})
}
Expand Down Expand Up @@ -379,7 +417,7 @@ func newBulkSelectStatement(
// Construct the JOIN clause for the final query.
var joinCond tree.Expr
for i, pkCol := range primaryKeyColumns {
colName := tree.Name(pkCol.GetName())
colName := tree.Name(pkCol.column.GetName())
keyColName := fmt.Sprintf("key%d", i+1)

eqExpr := &tree.ComparisonExpr{
Expand Down Expand Up @@ -430,7 +468,11 @@ func newBulkSelectStatement(
},
}

return toParsedStatement(selectStmt)
stmt, err := toParsedStatement(selectStmt)
if err != nil {
return statements.Statement[tree.Statement]{}, nil, err
}
return stmt, paramTypes, nil
}

func toParsedStatement(stmt tree.Statement) (statements.Statement[tree.Statement], error) {
Expand Down
34 changes: 7 additions & 27 deletions pkg/crosscluster/logical/replication_statements_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
gosql "database/sql"
"fmt"
"math/rand"
"slices"
"testing"

"github.com/cockroachdb/cockroach/pkg/base"
Expand Down Expand Up @@ -68,16 +67,7 @@ func TestReplicationStatements(t *testing.T) {

asSql := tree.Serialize(&p)
_, err := db.Exec(asSql)
require.NoError(t, err)
}

getTypes := func(desc catalog.TableDescriptor) []*types.T {
columns := getColumnSchema(desc)
types := make([]*types.T, len(columns))
for i, col := range columns {
types[i] = col.column.GetType()
}
return types
require.NoError(t, err, "statement: %s", asSql)
}

datadriven.Walk(t, datapathutils.TestDataPath(t), func(t *testing.T, path string) {
Expand All @@ -95,10 +85,10 @@ func TestReplicationStatements(t *testing.T) {

desc := getTableDesc(tableName)

insertStmt, err := newInsertStatement(desc)
insertStmt, types, err := newInsertStatement(desc)
require.NoError(t, err)

prepareStatement(t, sqlDB, getTypes(desc), insertStmt)
prepareStatement(t, sqlDB, types, insertStmt)

return insertStmt.SQL
case "show-update":
Expand All @@ -107,12 +97,11 @@ func TestReplicationStatements(t *testing.T) {

desc := getTableDesc(tableName)

updateStmt, err := newUpdateStatement(desc)
updateStmt, types, err := newUpdateStatement(desc)
require.NoError(t, err)

// update expects previous and current values to be passed as
// parameters.
types := slices.Concat(getTypes(desc), getTypes(desc))
prepareStatement(t, sqlDB, types, updateStmt)

return updateStmt.SQL
Expand All @@ -124,10 +113,9 @@ func TestReplicationStatements(t *testing.T) {

// delete expects the row's previous values (for the WHERE clause) to be
// passed as parameters.
deleteStmt, err := newDeleteStatement(desc)
deleteStmt, types, err := newDeleteStatement(desc)
require.NoError(t, err)

types := slices.Concat(getTypes(desc), getTypes(desc))
prepareStatement(t, sqlDB, types, deleteStmt)

return deleteStmt.SQL
Expand All @@ -137,18 +125,10 @@ func TestReplicationStatements(t *testing.T) {

desc := getTableDesc(tableName)

stmt, err := newBulkSelectStatement(desc)
stmt, types, err := newBulkSelectStatement(desc)
require.NoError(t, err)

allColumns := getColumnSchema(desc)
var primaryKeyTypes []*types.T
for _, col := range allColumns {
if col.isPrimaryKey {
primaryKeyTypes = append(primaryKeyTypes, types.MakeArray(col.column.GetType()))
}
}

prepareStatement(t, sqlDB, primaryKeyTypes, stmt)
prepareStatement(t, sqlDB, types, stmt)

return stmt.SQL
default:
Expand Down
Loading