Skip to content

Commit f007863

Browse files
authored
feat: add batch request support (#99)
Add batch support
1 parent 1fbbed8 commit f007863

File tree

3 files changed

+91
-20
lines changed

3 files changed

+91
-20
lines changed

handler.go

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
191191
cb(w)
192192
}
193193

194-
var req request
195194
// We read the entire request upfront in a buffer to be able to tell if the
196195
// client sent more than maxRequestSize and report it back as an explicit error,
197196
// instead of just silently truncating it and reporting a more vague parsing
@@ -205,11 +204,11 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
205204
if err != nil {
206205
// ReadFrom will discard EOF so any error here is unexpected and should
207206
// be reported.
208-
rpcError(wf, &req, rpcParseError, xerrors.Errorf("reading request: %w", err))
207+
rpcError(wf, nil, rpcParseError, xerrors.Errorf("reading request: %w", err))
209208
return
210209
}
211210
if reqSize > s.maxRequestSize {
212-
rpcError(wf, &req, rpcParseError,
211+
rpcError(wf, nil, rpcParseError,
213212
// rpcParseError is the closest we have from the standard errors defined
214213
// in [jsonrpc spec](https://www.jsonrpc.org/specification#error_object)
215214
// to report the maximum limit.
@@ -218,17 +217,56 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
218217
return
219218
}
220219

221-
if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil {
222-
rpcError(wf, &req, rpcParseError, xerrors.Errorf("unmarshaling request: %w", err))
223-
return
224-
}
220+
// Trim spaces to avoid issues with batch request detection.
221+
bufferedRequest = bytes.NewBuffer(bytes.TrimSpace(bufferedRequest.Bytes()))
222+
reqSize = int64(bufferedRequest.Len())
225223

226-
if req.ID, err = normalizeID(req.ID); err != nil {
227-
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
224+
if reqSize == 0 {
225+
rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request"))
228226
return
229227
}
230228

231-
s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
229+
if bufferedRequest.Bytes()[0] == '[' && bufferedRequest.Bytes()[reqSize-1] == ']' {
230+
var reqs []request
231+
232+
if err := json.NewDecoder(bufferedRequest).Decode(&reqs); err != nil {
233+
rpcError(wf, nil, rpcParseError, xerrors.New("Parse error"))
234+
return
235+
}
236+
237+
if len(reqs) == 0 {
238+
rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request"))
239+
return
240+
}
241+
242+
w.Write([]byte("["))
243+
for idx, req := range reqs {
244+
if req.ID, err = normalizeID(req.ID); err != nil {
245+
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
246+
return
247+
}
248+
249+
s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
250+
251+
if idx != len(reqs)-1 {
252+
w.Write([]byte(","))
253+
}
254+
}
255+
w.Write([]byte("]"))
256+
} else {
257+
var req request
258+
if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil {
259+
rpcError(wf, &req, rpcParseError, xerrors.New("Parse error"))
260+
return
261+
}
262+
263+
if req.ID, err = normalizeID(req.ID); err != nil {
264+
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
265+
return
266+
}
267+
268+
s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
269+
}
232270
}
233271

234272
func doCall(methodName string, f reflect.Value, params []reflect.Value) (out []reflect.Value, err error) {

rpc_test.go

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,22 @@ func TestRawRequests(t *testing.T) {
9090
testServ := httptest.NewServer(rpcServer)
9191
defer testServ.Close()
9292

93-
tc := func(req, resp string, n int32) func(t *testing.T) {
93+
removeSpaces := func(jsonStr string) (string, error) {
94+
var jsonObj interface{}
95+
err := json.Unmarshal([]byte(jsonStr), &jsonObj)
96+
if err != nil {
97+
return "", err
98+
}
99+
100+
compactJSONBytes, err := json.Marshal(jsonObj)
101+
if err != nil {
102+
return "", err
103+
}
104+
105+
return string(compactJSONBytes), nil
106+
}
107+
108+
tc := func(req, resp string, n int32, statusCode int) func(t *testing.T) {
94109
return func(t *testing.T) {
95110
rpcHandler.n = 0
96111

@@ -100,16 +115,29 @@ func TestRawRequests(t *testing.T) {
100115
b, err := ioutil.ReadAll(res.Body)
101116
require.NoError(t, err)
102117

103-
assert.Equal(t, resp, strings.TrimSpace(string(b)))
118+
expectedResp, err := removeSpaces(resp)
119+
require.NoError(t, err)
120+
121+
responseBody, err := removeSpaces(string(b))
122+
require.NoError(t, err)
123+
124+
assert.Equal(t, expectedResp, responseBody)
104125
require.Equal(t, n, rpcHandler.n)
126+
require.Equal(t, statusCode, res.StatusCode)
105127
}
106128
}
107129

108-
t.Run("inc", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": [], "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1))
109-
t.Run("inc-null", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": null, "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1))
110-
t.Run("inc-noparam", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "id": 2}`, `{"jsonrpc":"2.0","id":2}`, 1))
111-
t.Run("add", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [10], "id": 4}`, `{"jsonrpc":"2.0","id":4}`, 10))
112-
130+
t.Run("inc", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": [], "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1, 200))
131+
t.Run("inc-null", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": null, "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1, 200))
132+
t.Run("inc-noparam", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "id": 2}`, `{"jsonrpc":"2.0","id":2}`, 1, 200))
133+
t.Run("add", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [10], "id": 4}`, `{"jsonrpc":"2.0","id":4}`, 10, 200))
134+
// Batch requests
135+
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 5}`, `{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"Parse error"}}`, 0, 500))
136+
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 6}]`, `[{"jsonrpc":"2.0","id":6}]`, 123, 200))
137+
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 7},{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [-122], "id": 8}]`, `[{"jsonrpc":"2.0","id":7},{"jsonrpc":"2.0","id":8}]`, 1, 200))
138+
t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 9},{"jsonrpc": "2.0", "params": [-122], "id": 10}]`, `[{"jsonrpc":"2.0","id":9},{"error":{"code":-32601,"message":"method '' not found"},"id":10,"jsonrpc":"2.0"}]`, 123, 200))
139+
t.Run("add", tc(` [{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [-1], "id": 11}] `, `[{"jsonrpc":"2.0","id":11}]`, -1, 200))
140+
t.Run("add", tc(``, `{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid request"}}`, 0, 400))
113141
}
114142

115143
func TestReconnection(t *testing.T) {

server.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
const (
1717
rpcParseError = -32700
18+
rpcInvalidRequest = -32600
1819
rpcMethodNotFound = -32601
1920
rpcInvalidParams = -32602
2021
)
@@ -107,13 +108,17 @@ func rpcError(wf func(func(io.Writer)), req *request, code ErrorCode, err error)
107108
log.Errorf("RPC Error: %s", err)
108109
wf(func(w io.Writer) {
109110
if hw, ok := w.(http.ResponseWriter); ok {
110-
hw.WriteHeader(500)
111+
if code == rpcInvalidRequest {
112+
hw.WriteHeader(400)
113+
} else {
114+
hw.WriteHeader(500)
115+
}
111116
}
112117

113118
log.Warnf("rpc error: %s", err)
114119

115-
if req.ID == nil { // notification
116-
return
120+
if req == nil {
121+
req = &request{}
117122
}
118123

119124
resp := response{

0 commit comments

Comments
 (0)