Skip to content

Commit 1341c39

Browse files
SNOW-978164: Fix stmt.Exec for DML (#978)
1 parent d56c0f2 commit 1341c39

5 files changed

+158
-12
lines changed

connection.go

+16
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ const (
6060
queryResultType resultType = "query"
6161
)
6262

63+
type execKey string
64+
65+
const (
66+
executionType execKey = "executionType"
67+
executionTypeStatement string = "statement"
68+
)
69+
6370
const privateLinkSuffix = "privatelink.snowflakecomputing.com"
6471

6572
type snowflakeConn struct {
@@ -333,8 +340,17 @@ func (sc *snowflakeConn) ExecContext(
333340
}, nil // last insert id is not supported by Snowflake
334341
} else if isMultiStmt(&data.Data) {
335342
return sc.handleMultiExec(ctx, data.Data)
343+
} else if isDql(&data.Data) {
344+
logger.WithContext(ctx).Debugf("DQL")
345+
if isStatementContext(ctx) {
346+
return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil
347+
}
348+
return driver.ResultNoRows, nil
336349
}
337350
logger.Debug("DDL")
351+
if isStatementContext(ctx) {
352+
return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil
353+
}
338354
return driver.ResultNoRows, nil
339355
}
340356

connection_util.go

+12
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,14 @@ func isDml(v int64) bool {
197197
return statementTypeIDDml <= v && v <= statementTypeIDMultiTableInsert
198198
}
199199

200+
func isDql(data *execResponseData) bool {
201+
return data.StatementTypeID == statementTypeIDSelect && !isMultiStmt(data)
202+
}
203+
200204
func updateRows(data execResponseData) (int64, error) {
205+
if data.RowSet == nil {
206+
return 0, nil
207+
}
201208
var count int64
202209
for i, n := 0, len(data.RowType); i < n; i++ {
203210
v, err := strconv.ParseInt(*data.RowSet[0][i], 10, 64)
@@ -292,3 +299,8 @@ func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error {
292299
}
293300
return nil
294301
}
302+
303+
func isStatementContext(ctx context.Context) bool {
304+
v := ctx.Value(executionType)
305+
return v == executionTypeStatement
306+
}

result.go

+18
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
package gosnowflake
44

5+
import "errors"
6+
57
type queryStatus string
68

79
const (
@@ -73,3 +75,19 @@ func (res *snowflakeResult) waitForAsyncExecStatus() error {
7375
}
7476
return nil
7577
}
78+
79+
type snowflakeResultNoRows struct {
80+
queryID string
81+
}
82+
83+
func (*snowflakeResultNoRows) LastInsertId() (int64, error) {
84+
return 0, errors.New("no LastInsertId available")
85+
}
86+
87+
func (*snowflakeResultNoRows) RowsAffected() (int64, error) {
88+
return 0, errors.New("no RowsAffected available")
89+
}
90+
91+
func (rnr *snowflakeResultNoRows) GetQueryID() string {
92+
return rnr.queryID
93+
}

statement.go

+16-12
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,7 @@ func (stmt *snowflakeStmt) NumInput() int {
3434

3535
func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
3636
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext")
37-
result, err := stmt.sc.ExecContext(ctx, stmt.query, args)
38-
if err != nil {
39-
stmt.setQueryIDFromError(err)
40-
return nil, err
41-
}
42-
r, ok := result.(SnowflakeResult)
43-
if !ok {
44-
return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result)
45-
}
46-
stmt.lastQueryID = r.GetQueryID()
47-
return result, err
37+
return stmt.execInternal(ctx, args)
4838
}
4939

5040
func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
@@ -64,11 +54,25 @@ func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.Named
6454

6555
func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) {
6656
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec")
67-
result, err := stmt.sc.Exec(stmt.query, args)
57+
return stmt.execInternal(context.Background(), toNamedValues(args))
58+
}
59+
60+
func (stmt *snowflakeStmt) execInternal(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
61+
logger.WithContext(stmt.sc.ctx).Debugln("Stmt.execInternal")
62+
if ctx == nil {
63+
ctx = context.Background()
64+
}
65+
stmtCtx := context.WithValue(ctx, executionType, executionTypeStatement)
66+
result, err := stmt.sc.ExecContext(stmtCtx, stmt.query, args)
6867
if err != nil {
6968
stmt.setQueryIDFromError(err)
7069
return nil, err
7170
}
71+
rnr, ok := result.(*snowflakeResultNoRows)
72+
if ok {
73+
stmt.lastQueryID = rnr.GetQueryID()
74+
return driver.ResultNoRows, nil
75+
}
7276
r, ok := result.(SnowflakeResult)
7377
if !ok {
7478
return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result)

statement_test.go

+96
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,102 @@ func openConn(t *testing.T) *sql.Conn {
4040
return conn
4141
}
4242

43+
func TestExecStmt(t *testing.T) {
44+
dqlQuery := "SELECT 1"
45+
dmlQuery := "INSERT INTO TestDDLExec VALUES (1)"
46+
ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)"
47+
multiStmtQuery := "DELETE FROM TestDDLExec;\n" +
48+
"SELECT 1;\n" +
49+
"SELECT 2;"
50+
ctx := context.Background()
51+
multiStmtCtx, err := WithMultiStatement(ctx, 3)
52+
if err != nil {
53+
t.Error(err)
54+
}
55+
runDBTest(t, func(dbt *DBTest) {
56+
dbt.mustExec(ddlQuery)
57+
defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec")
58+
testcases := []struct {
59+
name string
60+
query string
61+
f func(stmt driver.Stmt) (any, error)
62+
}{
63+
{
64+
name: "dql Exec",
65+
query: dqlQuery,
66+
f: func(stmt driver.Stmt) (any, error) {
67+
return stmt.Exec(nil)
68+
},
69+
},
70+
{
71+
name: "dql ExecContext",
72+
query: dqlQuery,
73+
f: func(stmt driver.Stmt) (any, error) {
74+
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
75+
},
76+
},
77+
{
78+
name: "ddl Exec",
79+
query: ddlQuery,
80+
f: func(stmt driver.Stmt) (any, error) {
81+
return stmt.Exec(nil)
82+
},
83+
},
84+
{
85+
name: "ddl ExecContext",
86+
query: ddlQuery,
87+
f: func(stmt driver.Stmt) (any, error) {
88+
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
89+
},
90+
},
91+
{
92+
name: "dml Exec",
93+
query: dmlQuery,
94+
f: func(stmt driver.Stmt) (any, error) {
95+
return stmt.Exec(nil)
96+
},
97+
},
98+
{
99+
name: "dml ExecContext",
100+
query: dmlQuery,
101+
f: func(stmt driver.Stmt) (any, error) {
102+
return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
103+
},
104+
},
105+
{
106+
name: "multistmt ExecContext",
107+
query: multiStmtQuery,
108+
f: func(stmt driver.Stmt) (any, error) {
109+
return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil)
110+
},
111+
},
112+
}
113+
for _, tc := range testcases {
114+
t.Run(tc.name, func(t *testing.T) {
115+
err := dbt.conn.Raw(func(x any) error {
116+
stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query)
117+
if err != nil {
118+
t.Error(err)
119+
}
120+
if stmt.(SnowflakeStmt).GetQueryID() != "" {
121+
t.Error("queryId should be empty before executing any query")
122+
}
123+
if _, err := tc.f(stmt); err != nil {
124+
t.Errorf("should have not failed to execute the query, err: %s\n", err)
125+
}
126+
if stmt.(SnowflakeStmt).GetQueryID() == "" {
127+
t.Error("should have set the query id")
128+
}
129+
return nil
130+
})
131+
if err != nil {
132+
t.Fatal(err)
133+
}
134+
})
135+
}
136+
})
137+
}
138+
43139
func TestFailedQueryIdInSnowflakeError(t *testing.T) {
44140
failingQuery := "SELECTT 1"
45141
failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE"

0 commit comments

Comments
 (0)