Skip to content

Commit d619379

Browse files
authored
fix(x/websocket): concurrent read/write panics (#497)
1 parent 48f606f commit d619379

File tree

2 files changed

+174
-1
lines changed

2 files changed

+174
-1
lines changed

x/websocket/endpoint.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"net/url"
2727
"runtime"
2828
"strings"
29+
"sync"
2930
"time"
3031

3132
"github.com/Jigsaw-Code/outline-sdk/transport"
@@ -111,7 +112,12 @@ func newGorillaConn(wsConn *websocket.Conn) *gorillaConn {
111112
}
112113

113114
type gorillaConn struct {
114-
wsConn *websocket.Conn
115+
wsConn *websocket.Conn
116+
117+
// websocket.Conn is not safe for concurrent use
118+
// https://github.com/Jigsaw-Code/outline-apps/issues/2573
119+
readMu, writeMu sync.Mutex
120+
115121
writeErr error
116122
readErr error
117123
pendingReader io.Reader
@@ -140,6 +146,9 @@ func (c *gorillaConn) SetWriteDeadline(deadline time.Time) error {
140146
}
141147

142148
func (c *gorillaConn) Read(buf []byte) (int, error) {
149+
c.readMu.Lock()
150+
defer c.readMu.Unlock()
151+
143152
if c.readErr != nil {
144153
return 0, c.readErr
145154
}
@@ -177,6 +186,9 @@ func (c *gorillaConn) Read(buf []byte) (int, error) {
177186
}
178187

179188
func (c *gorillaConn) Write(buf []byte) (int, error) {
189+
c.writeMu.Lock()
190+
defer c.writeMu.Unlock()
191+
180192
err := c.wsConn.WriteMessage(websocket.BinaryMessage, buf)
181193
if err != nil {
182194
if c.writeErr != nil {

x/websocket/endpoint_test.go

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,18 @@
1515
package websocket
1616

1717
import (
18+
"bufio"
1819
"context"
1920
"errors"
21+
"fmt"
2022
"io"
23+
"net"
2124
"net/http"
2225
"net/http/httptest"
26+
"slices"
27+
"strconv"
28+
"strings"
29+
"sync"
2330
"testing"
2431

2532
"github.com/Jigsaw-Code/outline-sdk/transport"
@@ -161,3 +168,157 @@ func Test_NewPacketEndpoint(t *testing.T) {
161168
require.NoError(t, err)
162169
require.Equal(t, []byte("Response"), buf[:n])
163170
}
171+
172+
// Test_ConcurrentWritePacket tests if gorillaConn can concurrently write packets.
173+
func Test_ConcurrentWritePacket(t *testing.T) {
174+
const numWrites = 100
175+
recved := make([]bool, numWrites)
176+
allFalse := make([]bool, numWrites)
177+
allTrue := slices.Repeat([]bool{true}, numWrites)
178+
var reqRecved sync.WaitGroup
179+
reqRecved.Add(numWrites)
180+
181+
ts, conn := setupAndConnectToTestUDPWebSocketServer(t, func(svrConn transport.StreamConn) {
182+
defer svrConn.Close()
183+
scan := bufio.NewScanner(svrConn)
184+
for scan.Scan() {
185+
nStr, ok := strings.CutPrefix(scan.Text(), "write-")
186+
require.True(t, ok)
187+
n, err := strconv.Atoi(nStr)
188+
require.NoError(t, err)
189+
if n > numWrites {
190+
break
191+
}
192+
recved[n] = true
193+
reqRecved.Done()
194+
}
195+
require.NoError(t, scan.Err())
196+
})
197+
defer ts.Close()
198+
199+
// Concurrenly writes "write-xxx\n" messages
200+
require.Equal(t, allFalse, recved)
201+
for i := range numWrites {
202+
go func() {
203+
_, err := fmt.Fprintf(conn, "write-%d\n", i)
204+
require.NoError(t, err)
205+
}()
206+
}
207+
reqRecved.Wait()
208+
require.NoError(t, conn.Close())
209+
require.Equal(t, allTrue, recved)
210+
}
211+
212+
// Test_ConcurrentCloseWritePacket tests if gorillaConn can concurrently be closed while writing.
213+
func Test_ConcurrentCloseWritePacket(t *testing.T) {
214+
t.Skip("TODO: figure out a good way to synchronize CloseWrite and Writes")
215+
216+
const numWrites = 100
217+
var writesDone sync.WaitGroup
218+
writesDone.Add(numWrites)
219+
220+
ts, conn := setupAndConnectToTestUDPWebSocketServer(t, func(svrConn transport.StreamConn) {
221+
writesDone.Wait()
222+
svrConn.Close()
223+
})
224+
defer ts.Close()
225+
226+
// Concurrently Close while writing
227+
for range numWrites {
228+
go func() {
229+
defer writesDone.Done()
230+
fmt.Fprintf(conn, "message\n")
231+
}()
232+
}
233+
require.NoError(t, conn.Close())
234+
writesDone.Wait()
235+
}
236+
237+
// Test_ConcurrentReadPacket tests if gorillaConn can concurrently receive packets.
238+
func Test_ConcurrentReadPacket(t *testing.T) {
239+
const numReads = 100
240+
recved := make([]bool, numReads)
241+
allFalse := make([]bool, numReads)
242+
allTrue := slices.Repeat([]bool{true}, numReads)
243+
var readsDone, testDone sync.WaitGroup
244+
readsDone.Add(numReads)
245+
testDone.Add(1)
246+
defer testDone.Done()
247+
248+
ts, conn := setupAndConnectToTestUDPWebSocketServer(t, func(svrConn transport.StreamConn) {
249+
defer svrConn.Close()
250+
for i := range numReads {
251+
_, err := fmt.Fprintf(svrConn, "read-%d\n", i)
252+
require.NoError(t, err)
253+
}
254+
testDone.Wait()
255+
})
256+
defer ts.Close()
257+
258+
// Concurrently reads "read-xxx\n" messages
259+
require.Equal(t, allFalse, recved)
260+
for range numReads {
261+
go func() {
262+
defer readsDone.Done()
263+
scan := bufio.NewScanner(conn)
264+
for scan.Scan() {
265+
nStr, ok := strings.CutPrefix(scan.Text(), "read-")
266+
require.True(t, ok)
267+
n, err := strconv.Atoi(nStr)
268+
require.NoError(t, err)
269+
recved[n] = true
270+
break
271+
}
272+
}()
273+
}
274+
readsDone.Wait()
275+
require.NoError(t, conn.Close())
276+
require.Equal(t, allTrue, recved)
277+
}
278+
279+
// Test_ConcurrentCloseReadPacket tests if gorillaConn can concurrently be closed while reading.
280+
func Test_ConcurrentCloseReadPacket(t *testing.T) {
281+
t.Skip("TODO: figure out a good way to synchronize CloseRead and Reads")
282+
283+
const numReads = 100
284+
var readsDone sync.WaitGroup
285+
readsDone.Add(numReads)
286+
287+
ts, conn := setupAndConnectToTestUDPWebSocketServer(t, func(svrConn transport.StreamConn) {
288+
readsDone.Wait()
289+
svrConn.Close()
290+
})
291+
defer ts.Close()
292+
293+
// Concurrently Close while reading
294+
for range numReads {
295+
go func() {
296+
defer readsDone.Done()
297+
io.ReadAll(conn)
298+
}()
299+
}
300+
require.NoError(t, conn.Close())
301+
readsDone.Wait()
302+
}
303+
304+
// --- Test Helpers ---
305+
306+
func setupAndConnectToTestUDPWebSocketServer(t *testing.T, server func(transport.StreamConn)) (ts *httptest.Server, conn net.Conn) {
307+
mux := http.NewServeMux()
308+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
309+
svrConn, err := Upgrade(w, r, http.Header{})
310+
require.NoError(t, err)
311+
server(svrConn)
312+
})
313+
mux.Handle("/udp", http.StripPrefix("/udp", handler))
314+
ts = httptest.NewTLSServer(mux)
315+
316+
client := ts.Client()
317+
endpoint := &transport.TCPEndpoint{Address: ts.Listener.Addr().String()}
318+
connect, err := NewPacketEndpoint("wss"+ts.URL[5:]+"/udp", endpoint, WithTLSConfig(client.Transport.(*http.Transport).TLSClientConfig))
319+
require.NoError(t, err)
320+
conn, err = connect(context.Background())
321+
require.NoError(t, err)
322+
323+
return
324+
}

0 commit comments

Comments
 (0)