Skip to content

fix: Add atomic.Bool to ensure wsConnection.Disconnect() is only called once #359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,12 @@ func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) {
}

func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Conn, connectionCallbacks *serverTypes.ConnectionCallbacks) {
agentConn := wsConnection{wsConn: wsConn, connMutex: &sync.Mutex{}}
agentConn := newWSConnection(wsConn)

defer func() {
// Close the connection when all is done.
defer func() {
err := wsConn.Close()
err := agentConn.Disconnect()
if err != nil {
s.logger.Errorf(context.Background(), "error closing the WebSocket connection: %v", err)
}
Expand Down
90 changes: 84 additions & 6 deletions server/serverimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,26 @@ import (
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"google.golang.org/protobuf/proto"

clienttypes "github.com/open-telemetry/opamp-go/client/types"
sharedinternal "github.com/open-telemetry/opamp-go/internal"
"github.com/open-telemetry/opamp-go/internal/testhelpers"
"github.com/open-telemetry/opamp-go/protobufs"
"github.com/open-telemetry/opamp-go/server/types"
)

// startServer starts a test server configured by settings, using a NopLogger
// so that all log output is discarded. It delegates to startServerWithLogger
// and returns the server that function produces.
func startServer(t *testing.T, settings *StartSettings) *server {
	return startServerWithLogger(t, settings, &sharedinternal.NopLogger{})
}

func startServerWithLogger(t *testing.T, settings *StartSettings, logger clienttypes.Logger) *server {
srv := New(logger)
require.NotNil(t, srv)
if settings.ListenEndpoint == "" {
// Find an avaiable port to listne on.
// Find an available port to listen on.
settings.ListenEndpoint = testhelpers.GetAvailableLocalAddress()
}
if settings.ListenPath == "" {
Expand Down Expand Up @@ -239,7 +246,7 @@ func TestDisconnectHttpConnection(t *testing.T) {
assert.Equal(t, ErrInvalidHTTPConnection, err)
}

func TestDisconnectWSConnection(t *testing.T) {
func TestDisconnectClientWSConnection(t *testing.T) {
connectionCloseCalled := int32(0)
callback := types.Callbacks{
OnConnecting: func(request *http.Request) types.ConnectionResponse {
Expand All @@ -263,10 +270,77 @@ func TestDisconnectWSConnection(t *testing.T) {
assert.NoError(t, err)
assert.True(t, atomic.LoadInt32(&connectionCloseCalled) == 0)

// Close connection from server side
srvConn := wsConnection{wsConn: conn}
err = srvConn.Disconnect()
// Close connection from client side
clientConn := newWSConnection(conn)
err = clientConn.Disconnect()
assert.NoError(t, err)

// Verify connection disconnected from server side
eventually(t, func() bool { return atomic.LoadInt32(&connectionCloseCalled) == 1 })
// Wait for ReadMessage() to fail on the disconnected connection
eventually(t, func() bool {
_, _, err := conn.ReadMessage()
return err != nil
})
}

// testLogger is an in-memory logger for tests. Instead of emitting output it
// records each formatted message in logs, prefixed with the name of the method
// that produced it ("Debugf: " or "Errorf: ") and terminated by a newline.
//
// NOTE(review): logs is not guarded by any lock; confirm that appends cannot
// race with the test goroutine reading logs.
type testLogger struct {
	logs []string
}

// newTestLogger returns a testLogger whose logs slice is empty but non-nil.
func newTestLogger() *testLogger {
	l := new(testLogger)
	l.logs = make([]string, 0)
	return l
}

// Debugf formats the message and records it with a "Debugf: " prefix.
// The context argument is ignored.
func (l *testLogger) Debugf(_ context.Context, format string, v ...any) {
	l.logs = append(l.logs, "Debugf: "+fmt.Sprintf(format, v...)+"\n")
}

// Errorf formats the message and records it with an "Errorf: " prefix.
// The context argument is ignored.
func (l *testLogger) Errorf(_ context.Context, format string, v ...any) {
	l.logs = append(l.logs, "Errorf: "+fmt.Sprintf(format, v...)+"\n")
}

func TestDisconnectServerWSConnection(t *testing.T) {
connectionCloseCalled := int32(0)
var serverConn types.Connection
connReady := make(chan struct{}) // Channel to signal when serverConn is assigned
callback := types.Callbacks{
OnConnecting: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{
OnConnected: func(ctx context.Context, conn types.Connection) {
serverConn = conn
close(connReady)
},
OnConnectionClose: func(conn types.Connection) {
atomic.StoreInt32(&connectionCloseCalled, 1)
},
}}
},
}

// Start a Server.
logger := newTestLogger()
settings := &StartSettings{Settings: Settings{Callbacks: callback}}
srv := startServerWithLogger(t, settings, logger)
defer srv.Stop(context.Background())

// Connect to the Server.
conn, _, err := dialClient(settings)

// Verify that the connection is successful.
assert.NoError(t, err)
assert.True(t, atomic.LoadInt32(&connectionCloseCalled) == 0)

// Wait for serverConn to be assigned
<-connReady

// Close connection from server side
serverConn.Disconnect()

// Verify connection disconnected from server side
eventually(t, func() bool { return atomic.LoadInt32(&connectionCloseCalled) == 1 })
Expand All @@ -275,6 +349,10 @@ func TestDisconnectWSConnection(t *testing.T) {
_, _, err := conn.ReadMessage()
return err != nil
})

require.Equal(t, 1, len(logger.logs))
require.Contains(t, logger.logs[0], "Errorf: Cannot read a message from WebSocket: read tcp")
require.Contains(t, logger.logs[0], "use of closed network connection")
Comment on lines +353 to +355
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we care about the log messages produced? It looks fragile to make the tests dependent on the exact formatting of the messages.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, we should avoid fragility. The original motivation for this change was to deal with superfluous error logs produced by the server-side Disconnect() calls. Open to suggestions for better criteria to test for.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps as a compromise: in testLogger, store error logs separately from debug logs, and here assert only that an error log is produced, without checking the error message itself. This should be fairly stable, since we always expect exactly one error log and don't really care about the message itself.

}

var testInstanceUid = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}
Expand Down
17 changes: 13 additions & 4 deletions server/wsconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net"
"sync"
"sync/atomic"

"github.com/gorilla/websocket"

Expand All @@ -17,23 +18,31 @@ type wsConnection struct {
// The websocket library does not allow multiple concurrent write operations,
// so ensure that we only have a single operation in progress at a time.
// For more: https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency
connMutex *sync.Mutex
connMutex sync.Mutex
wsConn *websocket.Conn
closed atomic.Bool
}

var _ types.Connection = (*wsConnection)(nil)

func (c wsConnection) Connection() net.Conn {
func newWSConnection(wsConn *websocket.Conn) types.Connection {
return &wsConnection{wsConn: wsConn, connMutex: sync.Mutex{}, closed: atomic.Bool{}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think zero-initialization for connMutex and closed is fine, e.g. this should work the same way and is less verbose:

Suggested change
return &wsConnection{wsConn: wsConn, connMutex: sync.Mutex{}, closed: atomic.Bool{}}
return &wsConnection{wsConn: wsConn}

}

// Connection returns the value of UnderlyingConn() on the wrapped WebSocket
// connection.
func (c *wsConnection) Connection() net.Conn {
	underlying := c.wsConn.UnderlyingConn()
	return underlying
}

func (c wsConnection) Send(_ context.Context, message *protobufs.ServerToAgent) error {
// Send writes message to the wrapped WebSocket connection via
// internal.WriteWSMessage and returns that call's error. The context argument
// is ignored.
//
// connMutex is held for the whole write so that only one write operation on
// the WebSocket is in progress at a time.
func (c *wsConnection) Send(_ context.Context, message *protobufs.ServerToAgent) error {
c.connMutex.Lock()
// Deferred so the mutex is released even if the write panics.
defer c.connMutex.Unlock()

return internal.WriteWSMessage(c.wsConn, message)
}

func (c wsConnection) Disconnect() error {
// Disconnect closes the wrapped WebSocket connection the first time it is
// called and returns the result of that Close. The closed flag is flipped with
// an atomic compare-and-swap, so every later call finds the flag already set
// and returns nil without touching the connection again.
func (c *wsConnection) Disconnect() error {
	if firstCall := c.closed.CompareAndSwap(false, true); firstCall {
		return c.wsConn.Close()
	}
	return nil
}