@@ -19,6 +19,7 @@ import (
1919
2020type columnSchema struct {
2121 column catalog.Column
22+ columnType * types.T
2223 isPrimaryKey bool
2324 isComputed bool
2425}
@@ -58,6 +59,7 @@ func getColumnSchema(table catalog.TableDescriptor) []columnSchema {
5859
5960 result = append (result , columnSchema {
6061 column : col ,
62+ columnType : col .GetType ().Canonical (),
6163 isPrimaryKey : isPrimaryKey [col .GetID ()],
6264 isComputed : isComputed ,
6365 })
@@ -86,12 +88,14 @@ func newTypedPlaceholder(idx int, col catalog.Column) (*tree.CastExpr, error) {
8688// in the table. Parameters are ordered by column ID.
8789func newInsertStatement (
8890 table catalog.TableDescriptor ,
89- ) (statements.Statement [tree.Statement ], error ) {
91+ ) (statements.Statement [tree.Statement ], [] * types. T , error ) {
9092 columns := getColumnSchema (table )
91-
9293 columnNames := make (tree.NameList , 0 , len (columns ))
9394 parameters := make (tree.Exprs , 0 , len (columns ))
95+ paramTypes := make ([]* types.T , 0 , len (columns ))
9496 for i , col := range columns {
97+ paramTypes = append (paramTypes , col .columnType )
98+
9599 // NOTE: this consumes a placholder ID because its part of the tree.Datums,
96100 // but it doesn't show up in the query because computed columns are not
97101 // needed for insert statements.
@@ -102,7 +106,7 @@ func newInsertStatement(
102106 var err error
103107 parameter , err := newTypedPlaceholder (i + 1 , col .column )
104108 if err != nil {
105- return statements.Statement [tree.Statement ]{}, err
109+ return statements.Statement [tree.Statement ]{}, nil , err
106110 }
107111
108112 columnNames = append (columnNames , tree .Name (col .column .GetName ()))
@@ -129,7 +133,11 @@ func newInsertStatement(
129133 Returning : tree .AbsentReturningClause ,
130134 }
131135
132- return toParsedStatement (insert )
136+ stmt , err := toParsedStatement (insert )
137+ if err != nil {
138+ return statements.Statement [tree.Statement ]{}, nil , err
139+ }
140+ return stmt , paramTypes , nil
133141}
134142
135143// newMatchesLastRow creates a WHERE clause for matching all columns of a row.
@@ -183,16 +191,16 @@ func newMatchesLastRow(columns []columnSchema, startParamIdx int) (tree.Expr, er
183191// Parameters are ordered by column ID.
184192func newUpdateStatement (
185193 table catalog.TableDescriptor ,
186- ) (statements.Statement [tree.Statement ], error ) {
194+ ) (statements.Statement [tree.Statement ], [] * types. T , error ) {
187195 columns := getColumnSchema (table )
188-
189196 // Create WHERE clause for matching the previous row values
190197 whereClause , err := newMatchesLastRow (columns , 1 )
191198 if err != nil {
192- return statements.Statement [tree.Statement ]{}, err
199+ return statements.Statement [tree.Statement ]{}, nil , err
193200 }
194201
195202 exprs := make (tree.UpdateExprs , 0 , len (columns ))
203+ paramTypes := make ([]* types.T , 0 , 2 * len (columns ))
196204 for i , col := range columns {
197205 if col .isComputed {
198206 // Skip computed columns since they are not needed to fully specify the
@@ -208,7 +216,7 @@ func newUpdateStatement(
208216 // are for the where clause.
209217 placeholder , err := newTypedPlaceholder (len (columns )+ i + 1 , col .column )
210218 if err != nil {
211- return statements.Statement [tree.Statement ]{}, err
219+ return statements.Statement [tree.Statement ]{}, nil , err
212220 }
213221
214222 exprs = append (exprs , & tree.UpdateExpr {
@@ -217,6 +225,15 @@ func newUpdateStatement(
217225 })
218226 }
219227
228+ // Add parameter types for WHERE clause (previous values)
229+ for _ , col := range columns {
230+ paramTypes = append (paramTypes , col .columnType )
231+ }
232+ // Add parameter types for SET clause (new values)
233+ for _ , col := range columns {
234+ paramTypes = append (paramTypes , col .columnType )
235+ }
236+
220237 // Create the final update statement
221238 update := & tree.Update {
222239 Table : & tree.TableRef {
@@ -228,7 +245,12 @@ func newUpdateStatement(
228245 Returning : tree .AbsentReturningClause ,
229246 }
230247
231- return toParsedStatement (update )
248+ stmt , err := toParsedStatement (update )
249+ if err != nil {
250+ return statements.Statement [tree.Statement ]{}, nil , err
251+ }
252+
253+ return stmt , paramTypes , nil
232254}
233255
234256// newDeleteStatement returns a statement that can be used to delete a row from
@@ -239,13 +261,19 @@ func newUpdateStatement(
239261// Parameters are ordered by column ID.
240262func newDeleteStatement (
241263 table catalog.TableDescriptor ,
242- ) (statements.Statement [tree.Statement ], error ) {
264+ ) (statements.Statement [tree.Statement ], [] * types. T , error ) {
243265 columns := getColumnSchema (table )
244266
245267 // Create WHERE clause for matching the row to delete
246268 whereClause , err := newMatchesLastRow (columns , 1 )
247269 if err != nil {
248- return statements.Statement [tree.Statement ]{}, err
270+ return statements.Statement [tree.Statement ]{}, nil , err
271+ }
272+
273+ // Create parameter types for WHERE clause
274+ paramTypes := make ([]* types.T , 0 , len (columns ))
275+ for _ , col := range columns {
276+ paramTypes = append (paramTypes , col .columnType )
249277 }
250278
251279 // Create the final delete statement
@@ -261,7 +289,11 @@ func newDeleteStatement(
261289 Returning : & tree.ReturningExprs {tree .StarSelectExpr ()},
262290 }
263291
264- return toParsedStatement (delete )
292+ stmt , err := toParsedStatement (delete )
293+ if err != nil {
294+ return statements.Statement [tree.Statement ]{}, nil , err
295+ }
296+ return stmt , paramTypes , nil
265297}
266298
267299// newBulkSelectStatement returns a statement that can be used to query
@@ -289,26 +321,32 @@ func newDeleteStatement(
289321// AND replication_target.secondary_id = key_list.key2
290322func newBulkSelectStatement (
291323 table catalog.TableDescriptor ,
292- ) (statements.Statement [tree.Statement ], error ) {
324+ ) (statements.Statement [tree.Statement ], [] * types. T , error ) {
293325 cols := getColumnSchema (table )
294- primaryKeyColumns := make ([]catalog. Column , 0 , len (cols ))
326+ primaryKeyColumns := make ([]columnSchema , 0 , len (cols ))
295327 for _ , col := range cols {
296328 if col .isPrimaryKey {
297- primaryKeyColumns = append (primaryKeyColumns , col . column )
329+ primaryKeyColumns = append (primaryKeyColumns , col )
298330 }
299331 }
300332
333+ // Create parameter types for primary key arrays
334+ paramTypes := make ([]* types.T , 0 , len (primaryKeyColumns ))
335+ for _ , pkCol := range primaryKeyColumns {
336+ paramTypes = append (paramTypes , types .MakeArray (pkCol .columnType ))
337+ }
338+
301339 // keyListName is the name of the CTE that contains the primary keys supplied
302340 // via array parameters.
303341 keyListName , err := tree .NewUnresolvedObjectName (1 , [3 ]string {"key_list" }, tree .NoAnnotation )
304342 if err != nil {
305- return statements.Statement [tree.Statement ]{}, err
343+ return statements.Statement [tree.Statement ]{}, nil , err
306344 }
307345
308346 // targetName is used to name the user's table.
309347 targetName , err := tree .NewUnresolvedObjectName (1 , [3 ]string {"replication_target" }, tree .NoAnnotation )
310348 if err != nil {
311- return statements.Statement [tree.Statement ]{}, err
349+ return statements.Statement [tree.Statement ]{}, nil , err
312350 }
313351
314352 // Create the `SELECT unnest($1::[]INT, $2::[]INT) WITH ORDINALITY AS key_list(key1, key2, index)` table expression.
@@ -320,7 +358,7 @@ func newBulkSelectStatement(
320358 })
321359 primaryKeyExprs = append (primaryKeyExprs , & tree.CastExpr {
322360 Expr : & tree.Placeholder {Idx : tree .PlaceholderIdx (i )},
323- Type : types .MakeArray (pkCol .GetType () ),
361+ Type : types .MakeArray (pkCol .columnType ),
324362 SyntaxMode : tree .CastShort ,
325363 })
326364 }
@@ -379,7 +417,7 @@ func newBulkSelectStatement(
379417 // Construct the JOIN clause for the final query.
380418 var joinCond tree.Expr
381419 for i , pkCol := range primaryKeyColumns {
382- colName := tree .Name (pkCol .GetName ())
420+ colName := tree .Name (pkCol .column . GetName ())
383421 keyColName := fmt .Sprintf ("key%d" , i + 1 )
384422
385423 eqExpr := & tree.ComparisonExpr {
@@ -430,7 +468,11 @@ func newBulkSelectStatement(
430468 },
431469 }
432470
433- return toParsedStatement (selectStmt )
471+ stmt , err := toParsedStatement (selectStmt )
472+ if err != nil {
473+ return statements.Statement [tree.Statement ]{}, nil , err
474+ }
475+ return stmt , paramTypes , nil
434476}
435477
436478func toParsedStatement (stmt tree.Statement ) (statements.Statement [tree.Statement ], error ) {
0 commit comments