@@ -2,22 +2,25 @@ package routes
2
2
3
3
import (
4
4
"encoding/json"
5
+ "errors"
5
6
"fmt"
6
7
"io"
7
8
"net/http"
8
9
"net/http/httptest"
9
10
"strings"
10
11
"testing"
11
12
13
+ "github.com/gorilla/websocket"
14
+ "github.com/labstack/echo/v4"
15
+ "github.com/shellhub-io/shellhub/api/pkg/gateway"
12
16
svc "github.com/shellhub-io/shellhub/api/services"
13
-
14
- "github.com/shellhub-io/shellhub/api/store"
15
-
16
17
"github.com/shellhub-io/shellhub/api/services/mocks"
18
+ "github.com/shellhub-io/shellhub/api/store"
17
19
"github.com/shellhub-io/shellhub/pkg/api/authorizer"
18
20
"github.com/shellhub-io/shellhub/pkg/api/query"
19
21
"github.com/shellhub-io/shellhub/pkg/api/requests"
20
22
"github.com/shellhub-io/shellhub/pkg/models"
23
+ websocketmocks "github.com/shellhub-io/shellhub/pkg/websocket/mocks"
21
24
"github.com/stretchr/testify/assert"
22
25
gomock "github.com/stretchr/testify/mock"
23
26
)
@@ -312,3 +315,150 @@ func TestFinishSession(t *testing.T) {
312
315
313
316
mock .AssertExpectations (t )
314
317
}
318
+
319
+ func TestEventSession (t * testing.T ) {
320
+ mock := new (mocks.Service )
321
+ webSocketUpgraderMock := new (websocketmocks.Upgrader )
322
+
323
+ cases := []struct {
324
+ description string
325
+ uid string
326
+ seat int
327
+ requiredMocks func (uid string )
328
+ expected int
329
+ }{
330
+ {
331
+ description : "fails when upgrade cannot be done" ,
332
+ uid : "123" ,
333
+ seat : 0 ,
334
+ requiredMocks : func (_ string ) {
335
+ webSocketUpgraderMock .On ("Upgrade" , gomock .Anything , gomock .Anything ).Return (nil , errors .New ("" )).Once ()
336
+ },
337
+ expected : http .StatusBadRequest ,
338
+ },
339
+ {
340
+ description : "fails when cannot read from websocket due error" ,
341
+ uid : "123" ,
342
+ seat : 0 ,
343
+ requiredMocks : func (_ string ) {
344
+ conn := new (websocketmocks.Conn )
345
+ conn .On ("Close" ).Return (nil ).Once ()
346
+ conn .On ("ReadJSON" , gomock .Anything ).Return (io .EOF ).Once ()
347
+
348
+ webSocketUpgraderMock .On ("Upgrade" , gomock .Anything , gomock .Anything ).Return (conn , nil ).Once ()
349
+ },
350
+ expected : http .StatusInternalServerError ,
351
+ },
352
+ {
353
+ description : "fails when cannot read from websocket due generic error" ,
354
+ uid : "123" ,
355
+ seat : 0 ,
356
+ requiredMocks : func (_ string ) {
357
+ conn := new (websocketmocks.Conn )
358
+ conn .On ("Close" ).Return (nil ).Once ()
359
+ conn .On ("ReadJSON" , gomock .Anything ).Return (errors .New ("" )).Once ()
360
+
361
+ webSocketUpgraderMock .On ("Upgrade" , gomock .Anything , gomock .Anything ).Return (conn , nil ).Once ()
362
+ },
363
+ expected : http .StatusInternalServerError ,
364
+ },
365
+ {
366
+ description : "fails when record frame is invalid" ,
367
+ uid : "123" ,
368
+ seat : 0 ,
369
+ requiredMocks : func (_ string ) {
370
+ conn := new (websocketmocks.Conn )
371
+ conn .On ("Close" ).Return (nil ).Once ()
372
+ conn .On ("ReadJSON" , gomock .Anything ).Return (nil ).Once ()
373
+
374
+ webSocketUpgraderMock .On ("Upgrade" , gomock .Anything , gomock .Anything ).Return (conn , nil ).Once ()
375
+ },
376
+ expected : http .StatusBadRequest ,
377
+ },
378
+ {
379
+ description : "fails to write the frame on the database" ,
380
+ uid : "123" ,
381
+ seat : 0 ,
382
+ requiredMocks : func (uid string ) {
383
+ conn := new (websocketmocks.Conn )
384
+ conn .On ("Close" ).Return (nil ).Once ()
385
+ conn .On ("NextReader" ).Return ().Once ()
386
+ conn .On ("ReadJSON" , gomock .Anything ).Return (nil ).Once ().Run (func (args gomock.Arguments ) {
387
+ req := args .Get (0 ).(* requests.SessionEvent ) //nolint:forcetypeassert
388
+
389
+ json .
390
+ NewDecoder (strings .NewReader (`{"type":"pty-output","timestamp":"2025-02-03T14:11:32.405Z","data": { "output":"test" },"seat": 0}` )).
391
+ Decode (req ) //nolint:errcheck
392
+ })
393
+
394
+ webSocketUpgraderMock .On ("Upgrade" , gomock .Anything , gomock .Anything ).Return (conn , nil ).Once ()
395
+
396
+ mock .On ("EventSession" , gomock .Anything , models .UID (uid ), gomock .Anything ).
397
+ Return (errors .New ("not able record" )).Once ()
398
+ },
399
+ expected : http .StatusInternalServerError ,
400
+ },
401
+ {
402
+ description : "success to write one frame on database" ,
403
+ uid : "123" ,
404
+ seat : 0 ,
405
+ requiredMocks : func (uid string ) {
406
+ conn := new (websocketmocks.Conn )
407
+ conn .On ("Close" ).Return (nil ).Once ()
408
+ conn .On ("NextReader" ).Return ().Once ()
409
+ conn .On ("ReadJSON" , gomock .Anything ).Return (nil ).Once ().Run (func (args gomock.Arguments ) {
410
+ req := args .Get (0 ).(* requests.SessionEvent ) //nolint:forcetypeassert
411
+
412
+ json .
413
+ NewDecoder (strings .NewReader (`{"type":"pty-output","timestamp":"2025-02-03T14:11:32.405Z","data": { "output":"test" },"seat": 0}` )).
414
+ Decode (req ) //nolint:errcheck
415
+ })
416
+
417
+ webSocketUpgraderMock .On ("Upgrade" , gomock .Anything , gomock .Anything ).Return (conn , nil ).Once ()
418
+
419
+ mock .On ("EventSession" , gomock .Anything , models .UID (uid ),
420
+ gomock .Anything ).Return (nil ).Once ()
421
+
422
+ conn .On ("ReadJSON" , gomock .Anything ).Return (& websocket.CloseError {
423
+ Code : 1000 ,
424
+ Text : "test" ,
425
+ }).Once ()
426
+ },
427
+ expected : http .StatusOK ,
428
+ },
429
+ }
430
+
431
+ for _ , tc := range cases {
432
+ t .Run (tc .description , func (t * testing.T ) {
433
+ tc .requiredMocks (tc .uid )
434
+
435
+ req := httptest .NewRequest (http .MethodGet , fmt .Sprintf ("ws:///internal/sessions/%s/events" , tc .uid ), nil )
436
+ req .Header .Set ("Content-Type" , echo .MIMEApplicationJSON )
437
+ req .Header .Set ("X-Role" , authorizer .RoleOwner .String ())
438
+ req .Header .Set ("Upgrade" , "websocket" )
439
+ req .Header .Set ("Connection" , "Upgrade" )
440
+ req .Header .Set ("Sec-WebSocket-Version" , "13" )
441
+ req .Header .Set ("Sec-WebSocket-Key" , "test" )
442
+
443
+ e := NewRouter (mock , func (_ * echo.Echo , handler * Handler ) error {
444
+ handler .WebSocketUpgrader = webSocketUpgraderMock
445
+
446
+ return nil
447
+ })
448
+
449
+ e .Use (func (next echo.HandlerFunc ) echo.HandlerFunc {
450
+ return func (c echo.Context ) error {
451
+ ctx := gateway .NewContext (mock , c )
452
+
453
+ return next (ctx )
454
+ }
455
+ })
456
+
457
+ rec := httptest .NewRecorder ()
458
+ e .ServeHTTP (rec , req )
459
+
460
+ assert .Equal (t , tc .expected , rec .Result ().StatusCode )
461
+ mock .AssertExpectations (t )
462
+ })
463
+ }
464
+ }
0 commit comments