Skip to content

Commit 257f8c3

Browse files
committed
* Supported json.Marshaller query parameter in database/sql driver
1 parent 4ebb2e4 commit 257f8c3

File tree

4 files changed

+164
-1
lines changed

4 files changed

+164
-1
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Supported `json.Marshaller` query parameter in `database/sql` driver
2+
13
## v3.108.0
24
* Added `query.EmptyTxControl()` for empty transaction control (server-side defines transaction control by internal logic)
35
* Marked as deprecated `query.NoTx()` because this is wrong name for server-side transaction control inference

internal/bind/params.go

+31-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package bind
33
import (
44
"database/sql"
55
"database/sql/driver"
6+
"encoding/json"
67
"errors"
78
"fmt"
89
"net/url"
@@ -16,6 +17,7 @@ import (
1617
"github.com/ydb-platform/ydb-go-sdk/v3/internal/types"
1718
"github.com/ydb-platform/ydb-go-sdk/v3/internal/value"
1819
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
20+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring"
1921
)
2022

2123
var (
@@ -164,10 +166,19 @@ func toValue(v any) (_ value.Value, err error) {
164166
if valuer, ok := v.(driver.Valuer); ok {
165167
v, err = valuer.Value()
166168
if err != nil {
167-
return nil, fmt.Errorf("ydb: driver.Valuer error: %w", err)
169+
return nil, xerrors.WithStackTrace(fmt.Errorf("driver.Valuer error: %w", err))
168170
}
169171
}
170172

173+
if json, ok := v.(json.Marshaler); ok {
174+
bytes, err := json.MarshalJSON()
175+
if err != nil {
176+
return nil, xerrors.WithStackTrace(fmt.Errorf("json.Marshaller error: %w", err))
177+
}
178+
179+
return value.JSONValue(xstring.FromBytes(bytes)), err
180+
}
181+
171182
if x, ok := asUUID(v); ok {
172183
return x, nil
173184
}
@@ -255,6 +266,13 @@ func toValue(v any) (_ value.Value, err error) {
255266
return value.TimestampValueFromTime(x), nil
256267
case time.Duration:
257268
return value.IntervalValueFromDuration(x), nil
269+
case json.Marshaler:
270+
bytes, err := x.MarshalJSON()
271+
if err != nil {
272+
return nil, xerrors.WithStackTrace(err)
273+
}
274+
275+
return value.JSONValue(xstring.FromBytes(bytes)), nil
258276
default:
259277
kind := reflect.TypeOf(x).Kind()
260278
switch kind {
@@ -301,6 +319,18 @@ func toValue(v any) (_ value.Value, err error) {
301319
case reflect.Struct:
302320
v := reflect.ValueOf(x)
303321

322+
if v.CanAddr() {
323+
addr := v.Addr()
324+
if x, has := addr.Interface().(json.Marshaler); has {
325+
bytes, err := x.MarshalJSON()
326+
if err != nil {
327+
return nil, xerrors.WithStackTrace(err)
328+
}
329+
330+
return value.JSONValue(xstring.FromBytes(bytes)), nil
331+
}
332+
}
333+
304334
fields := make([]value.StructValueField, v.NumField())
305335

306336
for i := range fields {

internal/bind/positional_args.go

+4
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ func (m PositionalArgs) ToYdb(sql string, args ...any) (
5757
}
5858
}
5959

60+
if position == 0 {
61+
return sql, args, nil
62+
}
63+
6064
if len(args) != position {
6165
return "", nil, xerrors.WithStackTrace(
6266
fmt.Errorf("%w: (positional args %d, query args %d)", ErrInconsistentArgs, position, len(args)),
+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
//go:build integration
2+
// +build integration
3+
4+
package integration
5+
6+
import (
7+
"encoding/json"
8+
"fmt"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/ydb-platform/ydb-go-sdk/v3"
14+
"github.com/ydb-platform/ydb-go-sdk/v3/table"
15+
"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
16+
)
17+
18+
var (
19+
_ json.Marshaler = (*testJson)(nil)
20+
)
21+
22+
type testJson struct {
23+
a int64 `json:"a"`
24+
b string `json:"b"`
25+
}
26+
27+
func (t testJson) MarshalJSON() ([]byte, error) {
28+
return []byte(fmt.Sprintf(`{"a":%d,"b":"%s"}`, t.a, t.b)), nil
29+
}
30+
31+
func TestDatabaseSqlJson(t *testing.T) {
32+
var (
33+
scope = newScope(t)
34+
db = scope.SQLDriver(
35+
ydb.WithQueryService(true),
36+
ydb.WithPositionalArgs(),
37+
)
38+
)
39+
40+
t.Run("table/types", func(t *testing.T) {
41+
t.Run("named param", func(t *testing.T) {
42+
t.Run("check ydb type", func(t *testing.T) {
43+
row := db.QueryRowContext(scope.Ctx, "SELECT FormatType(TypeOf($a))",
44+
table.ValueParam("$a", types.JSONDocumentValue(`{"a":1,"b":"2"}`)),
45+
)
46+
var act string
47+
require.NoError(t, row.Scan(&act))
48+
require.NoError(t, row.Err())
49+
require.Equal(t, `JsonDocument`, act)
50+
})
51+
t.Run("get json value", func(t *testing.T) {
52+
row := db.QueryRowContext(scope.Ctx, "SELECT $a",
53+
table.ValueParam("$a", types.JSONDocumentValue(`{"a":1,"b":"2"}`)),
54+
)
55+
var act string
56+
require.NoError(t, row.Scan(&act))
57+
require.NoError(t, row.Err())
58+
require.Equal(t, `{"a":1,"b":"2"}`, act)
59+
})
60+
})
61+
t.Run("unnamed param", func(t *testing.T) {
62+
t.Run("check ydb type", func(t *testing.T) {
63+
row := db.QueryRowContext(scope.Ctx, "SELECT FormatType(TypeOf(?))",
64+
types.JSONDocumentValue(`{"a":1,"b":"2"}`),
65+
)
66+
var act string
67+
require.NoError(t, row.Scan(&act))
68+
require.NoError(t, row.Err())
69+
require.Equal(t, `JsonDocument`, act)
70+
})
71+
t.Run("get json value", func(t *testing.T) {
72+
row := db.QueryRowContext(scope.Ctx, "SELECT ?",
73+
types.JSONDocumentValue(`{"a":1,"b":"2"}`),
74+
)
75+
var act string
76+
require.NoError(t, row.Scan(&act))
77+
require.NoError(t, row.Err())
78+
require.Equal(t, `{"a":1,"b":"2"}`, act)
79+
})
80+
})
81+
})
82+
83+
t.Run("ydb.ParamsBuilder()", func(t *testing.T) {
84+
t.Run("check ydb type", func(t *testing.T) {
85+
row := db.QueryRowContext(scope.Ctx, "SELECT FormatType(TypeOf($a))",
86+
ydb.ParamsBuilder().Param("$a").JSON(`{"a":1,"b":"2"}`).Build(),
87+
)
88+
var act string
89+
require.NoError(t, row.Scan(&act))
90+
require.NoError(t, row.Err())
91+
require.Equal(t, `Json`, act)
92+
})
93+
t.Run("get json value", func(t *testing.T) {
94+
row := db.QueryRowContext(scope.Ctx, "SELECT $a",
95+
ydb.ParamsBuilder().Param("$a").JSON(`{"a":1,"b":"2"}`).Build(),
96+
)
97+
var act string
98+
require.NoError(t, row.Scan(&act))
99+
require.NoError(t, row.Err())
100+
require.Equal(t, `{"a":1,"b":"2"}`, act)
101+
})
102+
})
103+
104+
t.Run("json.Marshaler", func(t *testing.T) {
105+
t.Run("check ydb type", func(t *testing.T) {
106+
row := db.QueryRowContext(scope.Ctx, "SELECT FormatType(TypeOf(?))", testJson{a: 1, b: "2"})
107+
var act string
108+
require.NoError(t, row.Scan(&act))
109+
require.NoError(t, row.Err())
110+
require.Equal(t, `Json`, act)
111+
})
112+
t.Run("struct param", func(t *testing.T) {
113+
row := db.QueryRowContext(scope.Ctx, "SELECT ?", testJson{a: 1, b: "2"})
114+
var act string
115+
require.NoError(t, row.Scan(&act))
116+
require.NoError(t, row.Err())
117+
require.Equal(t, `{"a":1,"b":"2"}`, act)
118+
})
119+
t.Run("pointer to struct param", func(t *testing.T) {
120+
row := db.QueryRowContext(scope.Ctx, "SELECT ?", &testJson{a: 1, b: "2"})
121+
var act string
122+
require.NoError(t, row.Scan(&act))
123+
require.NoError(t, row.Err())
124+
require.Equal(t, `{"a":1,"b":"2"}`, act)
125+
})
126+
})
127+
}

0 commit comments

Comments
 (0)