Skip to content
Open
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Added support for `Result.RowsAffected()` for `database/sql`

## v3.117.1
* Fixed scan a column of type `Decimal(precision,scale)` into a struct field of type `types.Decimal{}` using `ScanStruct()`
* Fixed race in integration test `TestTopicWriterLogMessagesWithoutData`
Expand Down
40 changes: 35 additions & 5 deletions internal/xsql/xquery/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,38 @@ import (
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/common"
)

type resultNoRows struct{}
type resultWithStats struct {
stats stats.QueryStats
}

func (r *resultWithStats) LastInsertId() (int64, error) { return 0, ErrUnsupported }
func (r *resultWithStats) RowsAffected() (int64, error) {
if r.stats == nil {
return 0, ErrUnsupported
}

func (resultNoRows) LastInsertId() (int64, error) { return 0, ErrUnsupported }
func (resultNoRows) RowsAffected() (int64, error) { return 0, ErrUnsupported }
var rowsAffected uint64
for {
phase, ok := r.stats.NextPhase()
if !ok {
break
}

var _ driver.Result = resultNoRows{}
for {
tableAccess, ok := phase.NextTableAccess()
if !ok {
break
}

rowsAffected += tableAccess.Deletes.Rows
rowsAffected += tableAccess.Updates.Rows
}
}

return int64(rowsAffected), nil
}

var _ driver.Result = &resultWithStats{}

type Parent interface {
Query() *query.Client
Expand Down Expand Up @@ -55,8 +81,12 @@ func (c *Conn) Exec(ctx context.Context, sql string, params *params.Params) (
))
}

var st stats.QueryStats
opts := []options.Execute{
options.WithParameters(params),
options.WithStatsMode(options.StatsModeBasic, func(qs stats.QueryStats) {
st = qs
}),
}

if txControl := tx.ControlFromContext(ctx, nil); txControl != nil {
Expand All @@ -68,7 +98,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, params *params.Params) (
return nil, xerrors.WithStackTrace(err)
}

return resultNoRows{}, nil
return &resultWithStats{st}, nil
}

func (c *Conn) Query(ctx context.Context, sql string, params *params.Params) (
Expand Down
2 changes: 1 addition & 1 deletion internal/xsql/xquery/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (t *transaction) Exec(ctx context.Context, sql string, params *params.Param
return nil, xerrors.WithStackTrace(err)
}

return resultNoRows{}, nil
return &resultWithStats{}, nil
}

func (t *transaction) Query(ctx context.Context, sql string, params *params.Params) (driver.RowsNextResultSet, error) {
Expand Down
63 changes: 63 additions & 0 deletions tests/integration/database_sql_rows_affected_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//go:build integration
// +build integration

package integration

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ydb-platform/ydb-go-sdk/v3"
)

func TestDatabaseSQLRowsAffected(t *testing.T) {
tests := []struct {
sql string
rows int64
}{
{
sql: "INSERT INTO %s (id) values (1),(2),(3)",
rows: 3,
},
{
sql: "UPDATE %s SET val = 'test' where id > 1",
rows: 2,
},
{
sql: "DELETE FROM %s",
rows: 3,
},
{
sql: "INSERT INTO %s (id) values (1),(2),(3); INSERT INTO %[1]s (id) values (4),(5); ",
rows: 5,
},
{
sql: "UPDATE %s SET val = 'test' where id > 1; DELETE FROM %[1]s WHERE id < 4", // 4+3
rows: 7,
},
}

var (
scope = newScope(t)
db = scope.SQLDriverWithFolder(ydb.WithQueryService(true))
)

defer func() {
_ = db.Close()
}()

for _, test := range tests {
t.Run(test.sql, func(t *testing.T) {
result, err := db.Exec(fmt.Sprintf(test.sql, scope.TableName()))
require.NoError(t, err)

got, err := result.RowsAffected()
require.NoError(t, err)

assert.Equal(t, test.rows, got)
})
}
}
Loading