Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add atomic.Bool to ensure wsConnection.Disconnect() is only called once #359

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net"
"net/http"
"sync"
"sync/atomic"

"github.com/gorilla/websocket"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -225,12 +226,12 @@ func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) {
}

func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Conn, connectionCallbacks *serverTypes.ConnectionCallbacks) {
agentConn := wsConnection{wsConn: wsConn, connMutex: &sync.Mutex{}}
agentConn := wsConnection{wsConn: wsConn, connMutex: &sync.Mutex{}, closed: &atomic.Bool{}}

defer func() {
// Close the connection when all is done.
defer func() {
err := wsConn.Close()
err := agentConn.Disconnect()
if err != nil {
s.logger.Errorf(context.Background(), "error closing the WebSocket connection: %v", err)
}
Expand Down
99 changes: 94 additions & 5 deletions server/serverimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ import (
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"google.golang.org/protobuf/proto"

clienttypes "github.com/open-telemetry/opamp-go/client/types"
sharedinternal "github.com/open-telemetry/opamp-go/internal"
"github.com/open-telemetry/opamp-go/internal/testhelpers"
"github.com/open-telemetry/opamp-go/protobufs"
Expand All @@ -28,7 +30,23 @@ func startServer(t *testing.T, settings *StartSettings) *server {
srv := New(&sharedinternal.NopLogger{})
require.NotNil(t, srv)
if settings.ListenEndpoint == "" {
// Find an avaiable port to listne on.
// Find an available port to listen on.
settings.ListenEndpoint = testhelpers.GetAvailableLocalAddress()
}
if settings.ListenPath == "" {
settings.ListenPath = "/"
}
err := srv.Start(*settings)
require.NoError(t, err)

return srv
}

func startServerWithLogger(t *testing.T, settings *StartSettings, logger clienttypes.Logger) *server {
srv := New(logger)
require.NotNil(t, srv)
if settings.ListenEndpoint == "" {
// Find an available port to listen on.
settings.ListenEndpoint = testhelpers.GetAvailableLocalAddress()
}
if settings.ListenPath == "" {
Expand Down Expand Up @@ -239,7 +257,7 @@ func TestDisconnectHttpConnection(t *testing.T) {
assert.Equal(t, ErrInvalidHTTPConnection, err)
}

func TestDisconnectWSConnection(t *testing.T) {
func TestDisconnectClientWSConnection(t *testing.T) {
connectionCloseCalled := int32(0)
callback := types.Callbacks{
OnConnecting: func(request *http.Request) types.ConnectionResponse {
Expand All @@ -263,10 +281,77 @@ func TestDisconnectWSConnection(t *testing.T) {
assert.NoError(t, err)
assert.True(t, atomic.LoadInt32(&connectionCloseCalled) == 0)

// Close connection from server side
srvConn := wsConnection{wsConn: conn}
err = srvConn.Disconnect()
// Close connection from client side
clientConn := wsConnection{wsConn: conn, connMutex: &sync.Mutex{}, closed: &atomic.Bool{}}
err = clientConn.Disconnect()
assert.NoError(t, err)

// Verify connection disconnected from server side
eventually(t, func() bool { return atomic.LoadInt32(&connectionCloseCalled) == 1 })
// Waiting for wsConnection to fail ReadMessage() over a Disconnected communication
eventually(t, func() bool {
_, _, err := conn.ReadMessage()
return err != nil
})
}

// testLogger is an in-memory logger for tests. Every message passed to
// Debugf or Errorf is formatted, prefixed with the name of the method that
// received it, terminated with a newline and appended to logs so that a test
// can inspect what was logged.
//
// Appends are serialized with mu because the logger may be called from
// goroutines other than the one running the test (e.g. connection handlers).
// Code that inspects the recorded lines while logging may still be in
// progress should use messages() rather than reading logs directly.
type testLogger struct {
	mu   sync.Mutex
	logs []string
}

// newTestLogger returns a testLogger with no recorded messages. logs starts
// as an empty, non-nil slice.
func newTestLogger() *testLogger {
	return &testLogger{
		logs: []string{},
	}
}

// Debugf records the formatted message as "Debugf: <message>\n". The context
// is ignored.
func (l *testLogger) Debugf(_ context.Context, format string, v ...any) {
	l.record("Debugf", format, v...)
}

// Errorf records the formatted message as "Errorf: <message>\n". The context
// is ignored.
func (l *testLogger) Errorf(_ context.Context, format string, v ...any) {
	l.record("Errorf", format, v...)
}

// record formats the message outside the critical section and then appends
// "<level>: <message>\n" to logs while holding mu.
func (l *testLogger) record(level, format string, v ...any) {
	line := level + ": " + fmt.Sprintf(format, v...) + "\n"

	l.mu.Lock()
	defer l.mu.Unlock()
	l.logs = append(l.logs, line)
}

// messages returns a copy of the recorded log lines taken while holding mu,
// so the result can be examined without racing against concurrent logging.
func (l *testLogger) messages() []string {
	l.mu.Lock()
	defer l.mu.Unlock()
	return append([]string(nil), l.logs...)
}

func TestDisconnectServerWSConnection(t *testing.T) {
connectionCloseCalled := int32(0)
var serverConn types.Connection
mutex := sync.Mutex{}
callback := types.Callbacks{
OnConnecting: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{Accept: true, ConnectionCallbacks: types.ConnectionCallbacks{
OnConnected: func(ctx context.Context, conn types.Connection) {
mutex.Lock()
serverConn = conn
mutex.Unlock()
},
OnConnectionClose: func(conn types.Connection) {
atomic.StoreInt32(&connectionCloseCalled, 1)
},
}}
},
}

// Start a Server.
logger := newTestLogger()
settings := &StartSettings{Settings: Settings{Callbacks: callback}}
srv := startServerWithLogger(t, settings, logger)
defer srv.Stop(context.Background())

// Connect to the Server.
conn, _, err := dialClient(settings)

// Verify that the connection is successful.
assert.NoError(t, err)
assert.True(t, atomic.LoadInt32(&connectionCloseCalled) == 0)

// Close connection from server side
mutex.Lock()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this mutex protect against? Does serverConn value ever change (except from nil to the only connection we get)? Can serverConn be nil here?

serverConn.Disconnect()
mutex.Unlock()

// Verify connection disconnected from server side
eventually(t, func() bool { return atomic.LoadInt32(&connectionCloseCalled) == 1 })
Expand All @@ -275,6 +360,10 @@ func TestDisconnectWSConnection(t *testing.T) {
_, _, err := conn.ReadMessage()
return err != nil
})

require.Equal(t, 1, len(logger.logs))
require.Contains(t, logger.logs[0], "Errorf: Cannot read a message from WebSocket: read tcp")
require.Contains(t, logger.logs[0], "use of closed network connection")
Comment on lines +364 to +366
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we care about the log messages produced? It looks fragile to make the tests dependent on the exact formatting of the messages.

}

var testInstanceUid = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}
Expand Down
6 changes: 6 additions & 0 deletions server/wsconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net"
"sync"
"sync/atomic"

"github.com/gorilla/websocket"

Expand All @@ -19,6 +20,7 @@ type wsConnection struct {
// For more: https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency
connMutex *sync.Mutex
wsConn *websocket.Conn
closed *atomic.Bool
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be a pointer? It seems like an unnecessary extra allocation.

}

var _ types.Connection = (*wsConnection)(nil)
Expand All @@ -35,5 +37,9 @@ func (c wsConnection) Send(_ context.Context, message *protobufs.ServerToAgent)
}

func (c wsConnection) Disconnect() error {
if c.closed.Load() {
return nil
}
c.closed.Store(true)
Comment on lines +40 to +43
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

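To ensure this happens once, I think the correct way is to use the atomic CompareAndSwap instead of separate Load/Store which can race: two goroutines can both observe false from Load() before either one calls Store(true), and then both call Close(). For example: `if !c.closed.CompareAndSwap(false, true) { return nil }`.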

return c.wsConn.Close()
}
Loading