Skip to content

Commit 17a3577

Browse files
authored
Merge pull request #101 from filecoin-project/fix/races
Fix a number of races in websocket channel code paths
2 parents f007863 + 2fec0c7 commit 17a3577

File tree

4 files changed

+69
-18
lines changed

4 files changed

+69
-18
lines changed

.circleci/config.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ jobs:
1212
steps:
1313
- checkout
1414
- run: go test -v -timeout 10m ./...
15+
test-race:
16+
executor: golang
17+
steps:
18+
- checkout
19+
- run: go test -race -v -timeout 10m ./...
1520
mod-tidy-check:
1621
executor: golang
1722
steps:
@@ -42,3 +47,4 @@ workflows:
4247
- lint-check
4348
- gofmt-check
4449
- test
50+
- test-race

handler.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
239239
return
240240
}
241241

242-
w.Write([]byte("["))
242+
_, _ = w.Write([]byte("[")) // todo consider handling this error
243243
for idx, req := range reqs {
244244
if req.ID, err = normalizeID(req.ID); err != nil {
245245
rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err))
@@ -249,10 +249,10 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp
249249
s.handle(ctx, req, wf, rpcError, func(bool) {}, nil)
250250

251251
if idx != len(reqs)-1 {
252-
w.Write([]byte(","))
252+
_, _ = w.Write([]byte(",")) // todo consider handling this error
253253
}
254254
}
255-
w.Write([]byte("]"))
255+
_, _ = w.Write([]byte("]")) // todo consider handling this error
256256
} else {
257257
var req request
258258
if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil {

rpc_test.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,16 @@ func TestReconnection(t *testing.T) {
160160
timer := time.NewTimer(captureDuration)
161161

162162
// record the number of connection attempts during this test
163-
connectionAttempts := 1
163+
connectionAttempts := int64(1)
164164

165165
closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", []interface{}{&rpcClient}, nil, func(c *Config) {
166166
c.proxyConnFactory = func(f func() (*websocket.Conn, error)) func() (*websocket.Conn, error) {
167167
return func() (*websocket.Conn, error) {
168168
defer func() {
169-
connectionAttempts++
169+
atomic.AddInt64(&connectionAttempts, 1)
170170
}()
171171

172-
if connectionAttempts > 1 {
172+
if atomic.LoadInt64(&connectionAttempts) > 1 {
173173
return nil, errors.New("simulates a failed reconnect attempt")
174174
}
175175

@@ -192,7 +192,7 @@ func TestReconnection(t *testing.T) {
192192
<-timer.C
193193

194194
// compute the average number of connection attempts per second over the capture window
195-
attemptsPerSecond := int64(connectionAttempts) / int64(captureDuration/time.Second)
195+
attemptsPerSecond := atomic.LoadInt64(&connectionAttempts) / int64(captureDuration/time.Second)
196196

197197
assert.Less(t, attemptsPerSecond, int64(50))
198198
}
@@ -677,7 +677,7 @@ func (h *ChanHandler) Sub(ctx context.Context, i int, eq int) (<-chan int, error
677677
fmt.Println("ctxdone1", i, eq)
678678
return
679679
case <-wait:
680-
fmt.Println("CONSUMED WAIT: ", i)
680+
//fmt.Println("CONSUMED WAIT: ", i)
681681
}
682682

683683
n += i
@@ -835,10 +835,11 @@ func TestChanServerClose(t *testing.T) {
835835

836836
tctx, tcancel := context.WithCancel(context.Background())
837837

838-
testServ := httptest.NewServer(rpcServer)
838+
testServ := httptest.NewUnstartedServer(rpcServer)
839839
testServ.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
840840
return tctx
841841
}
842+
testServ.Start()
842843

843844
closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ChanHandler", &client, nil)
844845
require.NoError(t, err)
@@ -966,10 +967,11 @@ func TestChanClientReceiveAll(t *testing.T) {
966967

967968
tctx, tcancel := context.WithCancel(context.Background())
968969

969-
testServ := httptest.NewServer(rpcServer)
970+
testServ := httptest.NewUnstartedServer(rpcServer)
970971
testServ.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
971972
return tctx
972973
}
974+
testServ.Start()
973975

974976
closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ChanHandler", &client, nil)
975977
require.NoError(t, err)
@@ -1005,6 +1007,11 @@ func TestChanClientReceiveAll(t *testing.T) {
10051007
}
10061008

10071009
func TestControlChanDeadlock(t *testing.T) {
1010+
_ = logging.SetLogLevel("rpc", "error")
1011+
defer func() {
1012+
_ = logging.SetLogLevel("rpc", "debug")
1013+
}()
1014+
10081015
for r := 0; r < 20; r++ {
10091016
testControlChanDeadlock(t)
10101017
}

websocket.go

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ type wsConn struct {
8282
inflightLk sync.Mutex
8383

8484
// chanHandlers is a map of client-side channel handlers
85-
chanHandlers map[uint64]func(m []byte, ok bool)
85+
chanHandlersLk sync.Mutex
86+
chanHandlers map[uint64]*chanHandler
8687

8788
// ////
8889
// Server related
@@ -99,6 +100,13 @@ type wsConn struct {
99100
registerCh chan outChanReg
100101
}
101102

103+
type chanHandler struct {
104+
// lk must be acquired while chanHandlersLk is held; it stays held while cb runs
105+
lk sync.Mutex
106+
107+
cb func(m []byte, ok bool)
108+
}
109+
102110
// //
103111
// WebSocket Message utils //
104112
// //
@@ -367,13 +375,20 @@ func (c *wsConn) handleChanMessage(frame frame) {
367375
return
368376
}
369377

378+
c.chanHandlersLk.Lock()
370379
hnd, ok := c.chanHandlers[chid]
371380
if !ok {
381+
c.chanHandlersLk.Unlock()
372382
log.Errorf("xrpc.ch.val: handler %d not found", chid)
373383
return
374384
}
375385

376-
hnd(params[1].data, true)
386+
hnd.lk.Lock()
387+
defer hnd.lk.Unlock()
388+
389+
c.chanHandlersLk.Unlock()
390+
391+
hnd.cb(params[1].data, true)
377392
}
378393

379394
func (c *wsConn) handleChanClose(frame frame) {
@@ -389,15 +404,22 @@ func (c *wsConn) handleChanClose(frame frame) {
389404
return
390405
}
391406

407+
c.chanHandlersLk.Lock()
392408
hnd, ok := c.chanHandlers[chid]
393409
if !ok {
410+
c.chanHandlersLk.Unlock()
394411
log.Errorf("xrpc.ch.val: handler %d not found", chid)
395412
return
396413
}
397414

415+
hnd.lk.Lock()
416+
defer hnd.lk.Unlock()
417+
398418
delete(c.chanHandlers, chid)
399419

400-
hnd(nil, false)
420+
c.chanHandlersLk.Unlock()
421+
422+
hnd.cb(nil, false)
401423
}
402424

403425
func (c *wsConn) handleResponse(frame frame) {
@@ -417,8 +439,12 @@ func (c *wsConn) handleResponse(frame frame) {
417439
return
418440
}
419441

420-
var chanCtx context.Context
421-
chanCtx, c.chanHandlers[chid] = req.retCh()
442+
chanCtx, chHnd := req.retCh()
443+
444+
c.chanHandlersLk.Lock()
445+
c.chanHandlers[chid] = &chanHandler{cb: chHnd}
446+
c.chanHandlersLk.Unlock()
447+
422448
go c.handleCtxAsync(chanCtx, frame.ID)
423449
}
424450

@@ -517,16 +543,28 @@ func (c *wsConn) closeInFlight() {
517543
for _, cancel := range c.handling {
518544
cancel()
519545
}
546+
c.handling = map[interface{}]context.CancelFunc{}
520547
c.handlingLk.Unlock()
521548

522-
c.handling = map[interface{}]context.CancelFunc{}
523549
}
524550

525551
func (c *wsConn) closeChans() {
552+
c.chanHandlersLk.Lock()
553+
defer c.chanHandlersLk.Unlock()
554+
526555
for chid := range c.chanHandlers {
527556
hnd := c.chanHandlers[chid]
557+
558+
hnd.lk.Lock()
559+
528560
delete(c.chanHandlers, chid)
529-
hnd(nil, false)
561+
562+
c.chanHandlersLk.Unlock()
563+
564+
hnd.cb(nil, false)
565+
566+
hnd.lk.Unlock()
567+
c.chanHandlersLk.Lock()
530568
}
531569
}
532570

@@ -679,7 +717,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
679717
c.frameExecQueue = make(chan []byte, maxQueuedFrames)
680718
c.inflight = map[interface{}]clientRequest{}
681719
c.handling = map[interface{}]context.CancelFunc{}
682-
c.chanHandlers = map[uint64]func(m []byte, ok bool){}
720+
c.chanHandlers = map[uint64]*chanHandler{}
683721
c.pongs = make(chan struct{}, 1)
684722

685723
c.registerCh = make(chan outChanReg)

0 commit comments

Comments
 (0)