Skip to content

Commit dbc8833

Browse files
committed
* Supported json.Marshaller query parameter in database/sql driver
1 parent 4ebb2e4 commit dbc8833

File tree

4 files changed

+302
-14
lines changed

4 files changed

+302
-14
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Supported `json.Marshaller` query parameter in `database/sql` driver
2+
13
## v3.108.0
24
* Added `query.EmptyTxControl()` for empty transaction control (server-side defines transaction control by internal logic)
35
* Marked as deprecated `query.NoTx()` because this is wrong name for server-side transaction control inference

internal/bind/bind.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package bind
22

33
import (
4+
"database/sql/driver"
45
"sort"
56

67
"github.com/ydb-platform/ydb-go-sdk/v3/internal/params"
78
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
9+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xslices"
810
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring"
911
)
1012

@@ -29,15 +31,25 @@ type Bind interface {
2931
type Bindings []Bind
3032

3133
func (bindings Bindings) ToYdb(sql string, args ...any) (
32-
yql string, params params.Params, err error,
34+
yql string, pp params.Params, err error,
3335
) {
3436
if len(bindings) == 0 {
35-
params, err = Params(args...)
37+
pp, err = Params(args...)
3638
if err != nil {
3739
return "", nil, xerrors.WithStackTrace(err)
3840
}
3941

40-
return sql, params, nil
42+
return sql, pp, nil
43+
}
44+
45+
if len(args) == 1 {
46+
if nv, has := args[0].(driver.NamedValue); has {
47+
if pp, has := nv.Value.(*params.Params); has {
48+
args = xslices.Transform(*pp, func(v *params.Parameter) any {
49+
return v
50+
})
51+
}
52+
}
4153
}
4254

4355
buffer := xstring.Buffer()
@@ -51,12 +63,12 @@ func (bindings Bindings) ToYdb(sql string, args ...any) (
5163
}
5264
}
5365

54-
params, err = Params(args...)
66+
pp, err = Params(args...)
5567
if err != nil {
5668
return "", nil, xerrors.WithStackTrace(err)
5769
}
5870

59-
return sql, params, nil
71+
return sql, pp, nil
6072
}
6173

6274
func Sort(bindings []Bind) []Bind {

internal/bind/params.go

+19-9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package bind
33
import (
44
"database/sql"
55
"database/sql/driver"
6+
"encoding/json"
67
"errors"
78
"fmt"
89
"net/url"
@@ -16,6 +17,7 @@ import (
1617
"github.com/ydb-platform/ydb-go-sdk/v3/internal/types"
1718
"github.com/ydb-platform/ydb-go-sdk/v3/internal/value"
1819
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
20+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring"
1921
)
2022

2123
var (
@@ -164,7 +166,7 @@ func toValue(v any) (_ value.Value, err error) {
164166
if valuer, ok := v.(driver.Valuer); ok {
165167
v, err = valuer.Value()
166168
if err != nil {
167-
return nil, fmt.Errorf("ydb: driver.Valuer error: %w", err)
169+
return nil, xerrors.WithStackTrace(fmt.Errorf("driver.Valuer error: %w", err))
168170
}
169171
}
170172

@@ -255,6 +257,13 @@ func toValue(v any) (_ value.Value, err error) {
255257
return value.TimestampValueFromTime(x), nil
256258
case time.Duration:
257259
return value.IntervalValueFromDuration(x), nil
260+
case json.Marshaler:
261+
bytes, err := x.MarshalJSON()
262+
if err != nil {
263+
return nil, xerrors.WithStackTrace(err)
264+
}
265+
266+
return value.JSONValue(xstring.FromBytes(bytes)), nil
258267
default:
259268
kind := reflect.TypeOf(x).Kind()
260269
switch kind {
@@ -266,7 +275,7 @@ func toValue(v any) (_ value.Value, err error) {
266275
list[i], err = toValue(v.Index(i).Interface())
267276
if err != nil {
268277
return nil, xerrors.WithStackTrace(
269-
fmt.Errorf("cannot parse %d item of slice %T: %w",
278+
fmt.Errorf("cannot parse item #%d of slice %T: %w",
270279
i, x, err,
271280
),
272281
)
@@ -307,8 +316,8 @@ func toValue(v any) (_ value.Value, err error) {
307316
kk, has := v.Type().Field(i).Tag.Lookup("sql")
308317
if !has {
309318
return nil, xerrors.WithStackTrace(
310-
fmt.Errorf("cannot parse %q as key field of struct: %w",
311-
v.Type().Field(i).Name, errUnsupportedType,
319+
fmt.Errorf("cannot parse %q as key field of struct %T: %w",
320+
v.Type().Field(i).Name, x, errUnsupportedType,
312321
),
313322
)
314323
}
@@ -348,15 +357,16 @@ func supportNewTypeLink(x any) string {
348357
}
349358

350359
func toYdbParam(name string, value any) (*params.Parameter, error) {
351-
switch tv := value.(type) {
352-
case driver.NamedValue:
353-
n, v := tv.Name, tv.Value
360+
if nv, has := value.(driver.NamedValue); has {
361+
n, v := nv.Name, nv.Value
354362
if n != "" {
355363
name = n
356364
}
357365
value = v
358-
case *params.Parameter:
359-
return tv, nil
366+
}
367+
368+
if nv, ok := value.(params.NamedValue); ok {
369+
return params.Named(nv.Name(), nv.Value()), nil
360370
}
361371

362372
v, err := toValue(value)

0 commit comments

Comments
 (0)