Skip to content

Commit dd830ea

Browse files
committed
refactor(api): use websocket connection to receive events from SSH session
1 parent a03f5f5 commit dd830ea

File tree

14 files changed

+1434
-128
lines changed

14 files changed

+1434
-128
lines changed

api/routes/handler.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@ package routes
22

33
import (
44
svc "github.com/shellhub-io/shellhub/api/services"
5+
"github.com/shellhub-io/shellhub/pkg/websocket"
56
)
67

78
type Handler struct {
89
service svc.Service
10+
// WebSocketUpgrader is used to turns a HTTP request into WebSocketUpgrader connection.
11+
WebSocketUpgrader websocket.Upgrader
912
}
1013

11-
func NewHandler(s svc.Service) *Handler {
12-
return &Handler{service: s}
14+
func NewHandler(s svc.Service, w websocket.Upgrader) *Handler {
15+
return &Handler{
16+
service: s,
17+
WebSocketUpgrader: w,
18+
}
1319
}

api/routes/healthcheck_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
func TestEvaluateHealth(t *testing.T) {
1515
e := echo.New()
1616
mock := new(mocks.Service)
17-
h := NewHandler(mock)
17+
h := NewHandler(mock, nil)
1818

1919
cases := []struct {
2020
title string

api/routes/routes.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/shellhub-io/shellhub/pkg/api/authorizer"
1414
"github.com/shellhub-io/shellhub/pkg/envs"
1515
pkgmiddleware "github.com/shellhub-io/shellhub/pkg/middleware"
16+
"github.com/shellhub-io/shellhub/pkg/websocket"
1617
)
1718

1819
type DefaultHTTPHandlerConfig struct {
@@ -66,7 +67,7 @@ func WithReporter(reporter *sentry.Client) Option {
6667
func NewRouter(service services.Service, opts ...Option) *echo.Echo {
6768
router := DefaultHTTPHandler(service, new(DefaultHTTPHandlerConfig)).(*echo.Echo)
6869

69-
handler := NewHandler(service)
70+
handler := NewHandler(service, websocket.NewGorillaWebSocketUpgrader())
7071
for _, opt := range opts {
7172
if err := opt(router, handler); err != nil {
7273
return nil
@@ -90,7 +91,7 @@ func NewRouter(service services.Service, opts ...Option) *echo.Echo {
9091
internalAPI.GET(GetPublicKeyURL, gateway.Handler(handler.GetPublicKey))
9192
internalAPI.POST(CreatePrivateKeyURL, gateway.Handler(handler.CreatePrivateKey))
9293
internalAPI.POST(EvaluateKeyURL, gateway.Handler(handler.EvaluateKey))
93-
internalAPI.POST(EventsSessionsURL, gateway.Handler(handler.EventSession))
94+
internalAPI.GET(EventsSessionsURL, gateway.Handler(handler.EventSession))
9495

9596
// Public routes for external access through API gateway
9697
publicAPI := router.Group("/api")

api/routes/session.go

+43-8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"github.com/shellhub-io/shellhub/pkg/api/query"
99
"github.com/shellhub-io/shellhub/pkg/api/requests"
1010
"github.com/shellhub-io/shellhub/pkg/models"
11+
"github.com/shellhub-io/shellhub/pkg/websocket"
12+
log "github.com/sirupsen/logrus"
1113
)
1214

1315
const (
@@ -123,7 +125,8 @@ func (h *Handler) KeepAliveSession(c gateway.Context) error {
123125
}
124126

125127
func (h *Handler) EventSession(c gateway.Context) error {
126-
var req requests.SessionEvent
128+
var req requests.SessionIDParam
129+
127130
if err := c.Bind(&req); err != nil {
128131
return err
129132
}
@@ -132,11 +135,43 @@ func (h *Handler) EventSession(c gateway.Context) error {
132135
return err
133136
}
134137

135-
return h.service.EventSession(c.Ctx(), models.UID(req.UID), &models.SessionEvent{
136-
Session: req.UID,
137-
Type: models.SessionEventType(req.Type),
138-
Timestamp: req.Timestamp,
139-
Data: req.Data,
140-
Seat: req.Seat,
141-
})
138+
if !c.IsWebSocket() {
139+
return c.NoContent(http.StatusBadRequest)
140+
}
141+
142+
connection, err := h.WebSocketUpgrader.Upgrade(c.Response(), c.Request())
143+
if err != nil {
144+
return c.NoContent(http.StatusBadRequest)
145+
}
146+
147+
defer connection.Close()
148+
149+
var r requests.SessionEvent
150+
for {
151+
if err := connection.ReadJSON(&r); err != nil {
152+
if websocket.IsErrorCloseNormal(err) || websocket.IsUnexpectedCloseError(err) {
153+
log.WithError(err).WithFields(log.Fields{
154+
"uid": req.UID,
155+
}).Debug("events websocket closed with a ignored error")
156+
157+
return nil
158+
}
159+
160+
return err
161+
}
162+
163+
if err := c.Validate(&r); err != nil {
164+
return err
165+
}
166+
167+
if err := h.service.EventSession(c.Ctx(), models.UID(req.UID), &models.SessionEvent{
168+
Session: req.UID,
169+
Type: models.SessionEventType(r.Type),
170+
Timestamp: r.Timestamp,
171+
Data: r.Data,
172+
Seat: r.Seat,
173+
}); err != nil {
174+
return err
175+
}
176+
}
142177
}

api/routes/session_test.go

+153-3
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,25 @@ package routes
22

33
import (
44
"encoding/json"
5+
"errors"
56
"fmt"
67
"io"
78
"net/http"
89
"net/http/httptest"
910
"strings"
1011
"testing"
1112

13+
"github.com/gorilla/websocket"
14+
"github.com/labstack/echo/v4"
15+
"github.com/shellhub-io/shellhub/api/pkg/gateway"
1216
svc "github.com/shellhub-io/shellhub/api/services"
13-
14-
"github.com/shellhub-io/shellhub/api/store"
15-
1617
"github.com/shellhub-io/shellhub/api/services/mocks"
18+
"github.com/shellhub-io/shellhub/api/store"
1719
"github.com/shellhub-io/shellhub/pkg/api/authorizer"
1820
"github.com/shellhub-io/shellhub/pkg/api/query"
1921
"github.com/shellhub-io/shellhub/pkg/api/requests"
2022
"github.com/shellhub-io/shellhub/pkg/models"
23+
websocketmocks "github.com/shellhub-io/shellhub/pkg/websocket/mocks"
2124
"github.com/stretchr/testify/assert"
2225
gomock "github.com/stretchr/testify/mock"
2326
)
@@ -312,3 +315,150 @@ func TestFinishSession(t *testing.T) {
312315

313316
mock.AssertExpectations(t)
314317
}
318+
319+
func TestEventSession(t *testing.T) {
320+
mock := new(mocks.Service)
321+
webSocketUpgraderMock := new(websocketmocks.Upgrader)
322+
323+
cases := []struct {
324+
description string
325+
uid string
326+
seat int
327+
requiredMocks func(uid string)
328+
expected int
329+
}{
330+
{
331+
description: "fails when upgrade cannot be done",
332+
uid: "123",
333+
seat: 0,
334+
requiredMocks: func(_ string) {
335+
webSocketUpgraderMock.On("Upgrade", gomock.Anything, gomock.Anything).Return(nil, errors.New("")).Once()
336+
},
337+
expected: http.StatusBadRequest,
338+
},
339+
{
340+
description: "fails when cannot read from websocket due error",
341+
uid: "123",
342+
seat: 0,
343+
requiredMocks: func(_ string) {
344+
conn := new(websocketmocks.Conn)
345+
conn.On("Close").Return(nil).Once()
346+
conn.On("ReadJSON", gomock.Anything).Return(io.EOF).Once()
347+
348+
webSocketUpgraderMock.On("Upgrade", gomock.Anything, gomock.Anything).Return(conn, nil).Once()
349+
},
350+
expected: http.StatusInternalServerError,
351+
},
352+
{
353+
description: "fails when cannot read from websocket due generic error",
354+
uid: "123",
355+
seat: 0,
356+
requiredMocks: func(_ string) {
357+
conn := new(websocketmocks.Conn)
358+
conn.On("Close").Return(nil).Once()
359+
conn.On("ReadJSON", gomock.Anything).Return(errors.New("")).Once()
360+
361+
webSocketUpgraderMock.On("Upgrade", gomock.Anything, gomock.Anything).Return(conn, nil).Once()
362+
},
363+
expected: http.StatusInternalServerError,
364+
},
365+
{
366+
description: "fails when record frame is invalid",
367+
uid: "123",
368+
seat: 0,
369+
requiredMocks: func(_ string) {
370+
conn := new(websocketmocks.Conn)
371+
conn.On("Close").Return(nil).Once()
372+
conn.On("ReadJSON", gomock.Anything).Return(nil).Once()
373+
374+
webSocketUpgraderMock.On("Upgrade", gomock.Anything, gomock.Anything).Return(conn, nil).Once()
375+
},
376+
expected: http.StatusBadRequest,
377+
},
378+
{
379+
description: "fails to write the frame on the database",
380+
uid: "123",
381+
seat: 0,
382+
requiredMocks: func(uid string) {
383+
conn := new(websocketmocks.Conn)
384+
conn.On("Close").Return(nil).Once()
385+
conn.On("NextReader").Return().Once()
386+
conn.On("ReadJSON", gomock.Anything).Return(nil).Once().Run(func(args gomock.Arguments) {
387+
req := args.Get(0).(*requests.SessionEvent) //nolint:forcetypeassert
388+
389+
json.
390+
NewDecoder(strings.NewReader(`{"type":"pty-output","timestamp":"2025-02-03T14:11:32.405Z","data": { "output":"test" },"seat": 0}`)).
391+
Decode(req) //nolint:errcheck
392+
})
393+
394+
webSocketUpgraderMock.On("Upgrade", gomock.Anything, gomock.Anything).Return(conn, nil).Once()
395+
396+
mock.On("EventSession", gomock.Anything, models.UID(uid), gomock.Anything).
397+
Return(errors.New("not able record")).Once()
398+
},
399+
expected: http.StatusInternalServerError,
400+
},
401+
{
402+
description: "success to write one frame on database",
403+
uid: "123",
404+
seat: 0,
405+
requiredMocks: func(uid string) {
406+
conn := new(websocketmocks.Conn)
407+
conn.On("Close").Return(nil).Once()
408+
conn.On("NextReader").Return().Once()
409+
conn.On("ReadJSON", gomock.Anything).Return(nil).Once().Run(func(args gomock.Arguments) {
410+
req := args.Get(0).(*requests.SessionEvent) //nolint:forcetypeassert
411+
412+
json.
413+
NewDecoder(strings.NewReader(`{"type":"pty-output","timestamp":"2025-02-03T14:11:32.405Z","data": { "output":"test" },"seat": 0}`)).
414+
Decode(req) //nolint:errcheck
415+
})
416+
417+
webSocketUpgraderMock.On("Upgrade", gomock.Anything, gomock.Anything).Return(conn, nil).Once()
418+
419+
mock.On("EventSession", gomock.Anything, models.UID(uid),
420+
gomock.Anything).Return(nil).Once()
421+
422+
conn.On("ReadJSON", gomock.Anything).Return(&websocket.CloseError{
423+
Code: 1000,
424+
Text: "test",
425+
}).Once()
426+
},
427+
expected: http.StatusOK,
428+
},
429+
}
430+
431+
for _, tc := range cases {
432+
t.Run(tc.description, func(t *testing.T) {
433+
tc.requiredMocks(tc.uid)
434+
435+
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("ws:///internal/sessions/%s/events", tc.uid), nil)
436+
req.Header.Set("Content-Type", echo.MIMEApplicationJSON)
437+
req.Header.Set("X-Role", authorizer.RoleOwner.String())
438+
req.Header.Set("Upgrade", "websocket")
439+
req.Header.Set("Connection", "Upgrade")
440+
req.Header.Set("Sec-WebSocket-Version", "13")
441+
req.Header.Set("Sec-WebSocket-Key", "test")
442+
443+
e := NewRouter(mock, func(_ *echo.Echo, handler *Handler) error {
444+
handler.WebSocketUpgrader = webSocketUpgraderMock
445+
446+
return nil
447+
})
448+
449+
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
450+
return func(c echo.Context) error {
451+
ctx := gateway.NewContext(mock, c)
452+
453+
return next(ctx)
454+
}
455+
})
456+
457+
rec := httptest.NewRecorder()
458+
e.ServeHTTP(rec, req)
459+
460+
assert.Equal(t, tc.expected, rec.Result().StatusCode)
461+
mock.AssertExpectations(t)
462+
})
463+
}
464+
}

0 commit comments

Comments
 (0)