-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: Add atomic.Bool to ensure wsConnection.Disconnect() is only called once #359
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,8 +16,10 @@ import ( | |
"github.com/gorilla/websocket" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
|
||
"google.golang.org/protobuf/proto" | ||
|
||
clienttypes "github.com/open-telemetry/opamp-go/client/types" | ||
sharedinternal "github.com/open-telemetry/opamp-go/internal" | ||
"github.com/open-telemetry/opamp-go/internal/testhelpers" | ||
"github.com/open-telemetry/opamp-go/protobufs" | ||
|
@@ -28,7 +30,23 @@ func startServer(t *testing.T, settings *StartSettings) *server { | |
srv := New(&sharedinternal.NopLogger{}) | ||
require.NotNil(t, srv) | ||
if settings.ListenEndpoint == "" { | ||
// Find an avaiable port to listne on. | ||
// Find an avaiable port to listen on. | ||
settings.ListenEndpoint = testhelpers.GetAvailableLocalAddress() | ||
} | ||
if settings.ListenPath == "" { | ||
settings.ListenPath = "/" | ||
} | ||
err := srv.Start(*settings) | ||
require.NoError(t, err) | ||
|
||
return srv | ||
} | ||
|
||
func startServerWithLogger(t *testing.T, settings *StartSettings, logger clienttypes.Logger) *server { | ||
srv := New(logger) | ||
require.NotNil(t, srv) | ||
if settings.ListenEndpoint == "" { | ||
// Find an avaiable port to listen on. | ||
settings.ListenEndpoint = testhelpers.GetAvailableLocalAddress() | ||
} | ||
if settings.ListenPath == "" { | ||
|
@@ -239,7 +257,7 @@ func TestDisconnectHttpConnection(t *testing.T) { | |
assert.Equal(t, ErrInvalidHTTPConnection, err) | ||
} | ||
|
||
func TestDisconnectWSConnection(t *testing.T) { | ||
func TestDisconnectClientWSConnection(t *testing.T) { | ||
connectionCloseCalled := int32(0) | ||
callback := types.Callbacks{ | ||
OnConnecting: func(request *http.Request) types.ConnectionResponse { | ||
|
@@ -263,10 +281,77 @@ func TestDisconnectWSConnection(t *testing.T) { | |
assert.NoError(t, err) | ||
assert.True(t, atomic.LoadInt32(&connectionCloseCalled) == 0) | ||
|
||
// Close connection from server side | ||
srvConn := wsConnection{wsConn: conn} | ||
err = srvConn.Disconnect() | ||
// Close connection from client side | ||
clientConn := wsConnection{wsConn: conn, connMutex: &sync.Mutex{}, closed: &atomic.Bool{}} | ||
err = clientConn.Disconnect() | ||
assert.NoError(t, err) | ||
|
||
// Verify connection disconnected from server side | ||
eventually(t, func() bool { return atomic.LoadInt32(&connectionCloseCalled) == 1 }) | ||
// Waiting for wsConnection to fail ReadMessage() over a Disconnected communication | ||
eventually(t, func() bool { | ||
_, _, err := conn.ReadMessage() | ||
return err != nil | ||
}) | ||
} | ||
|
||
// testLogger is a struct that adapts a *zap.Logger to opamp-go's Logger interface. | ||
type testLogger struct { | ||
logs []string | ||
} | ||
|
||
func newTestLogger() *testLogger { | ||
return &testLogger{ | ||
logs: []string{}, | ||
} | ||
} | ||
|
||
func (o *testLogger) Debugf(_ context.Context, format string, v ...any) { | ||
log := fmt.Sprintf(format, v...) | ||
o.logs = append(o.logs, fmt.Sprintf("Debugf: %s\n", log)) | ||
} | ||
|
||
func (o *testLogger) Errorf(_ context.Context, format string, v ...any) { | ||
log := fmt.Sprintf(format, v...) | ||
o.logs = append(o.logs, fmt.Sprintf("Errorf: %s\n", log)) | ||
} | ||
|
||
func TestDisconnectServerWSConnection(t *testing.T) { | ||
connectionCloseCalled := int32(0) | ||
var serverConn types.Connection | ||
mutex := sync.Mutex{} | ||
callback := types.Callbacks{ | ||
OnConnecting: func(request *http.Request) types.ConnectionResponse { | ||
return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{ | ||
OnConnected: func(ctx context.Context, conn types.Connection) { | ||
mutex.Lock() | ||
serverConn = conn | ||
mutex.Unlock() | ||
}, | ||
OnConnectionClose: func(conn types.Connection) { | ||
atomic.StoreInt32(&connectionCloseCalled, 1) | ||
}, | ||
}} | ||
}, | ||
} | ||
|
||
// Start a Server. | ||
logger := newTestLogger() | ||
settings := &StartSettings{Settings: Settings{Callbacks: callback}} | ||
srv := startServerWithLogger(t, settings, logger) | ||
defer srv.Stop(context.Background()) | ||
|
||
// Connect to the Server. | ||
conn, _, err := dialClient(settings) | ||
|
||
// Verify that the connection is successful. | ||
assert.NoError(t, err) | ||
assert.True(t, atomic.LoadInt32(&connectionCloseCalled) == 0) | ||
|
||
// Close connection from server side | ||
mutex.Lock() | ||
serverConn.Disconnect() | ||
mutex.Unlock() | ||
|
||
// Verify connection disconnected from server side | ||
eventually(t, func() bool { return atomic.LoadInt32(&connectionCloseCalled) == 1 }) | ||
|
@@ -275,6 +360,10 @@ func TestDisconnectWSConnection(t *testing.T) { | |
_, _, err := conn.ReadMessage() | ||
return err != nil | ||
}) | ||
|
||
require.Equal(t, 1, len(logger.logs)) | ||
require.Contains(t, logger.logs[0], "Errorf: Cannot read a message from WebSocket: read tcp") | ||
require.Contains(t, logger.logs[0], "use of closed network connection") | ||
Comment on lines
+364
to
+366
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we care about the log messages produced? It looks fragile to make the tests dependent on the exact formatting of the messages. |
||
} | ||
|
||
var testInstanceUid = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ import ( | |
"context" | ||
"net" | ||
"sync" | ||
"sync/atomic" | ||
|
||
"github.com/gorilla/websocket" | ||
|
||
|
@@ -19,6 +20,7 @@ type wsConnection struct { | |
// For more: https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency | ||
connMutex *sync.Mutex | ||
wsConn *websocket.Conn | ||
closed *atomic.Bool | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need to be a pointer? Seems to be like unnecessary extra allocation. |
||
} | ||
|
||
var _ types.Connection = (*wsConnection)(nil) | ||
|
@@ -35,5 +37,9 @@ func (c wsConnection) Send(_ context.Context, message *protobufs.ServerToAgent) | |
} | ||
|
||
func (c wsConnection) Disconnect() error { | ||
if c.closed.Load() { | ||
return nil | ||
} | ||
c.closed.Store(true) | ||
Comment on lines
+40
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To ensure this happens once, I think the correct way is to use the atomic |
||
return c.wsConn.Close() | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this mutex protect against? Does
serverConn
value ever change (except fromnil
to the only connection we get)? CanserverConn
be nil here?