@@ -6,6 +6,7 @@ package middleware
6
6
import (
7
7
"bytes"
8
8
"context"
9
+ "crypto/tls"
9
10
"errors"
10
11
"fmt"
11
12
"io"
@@ -20,6 +21,7 @@ import (
20
21
21
22
"github.com/labstack/echo/v4"
22
23
"github.com/stretchr/testify/assert"
24
+ "golang.org/x/net/websocket"
23
25
)
24
26
25
27
// Assert expected with url.EscapedPath method to obtain the path.
@@ -810,3 +812,231 @@ func TestModifyResponseUseContext(t *testing.T) {
810
812
assert .Equal (t , "OK" , rec .Body .String ())
811
813
assert .Equal (t , "CUSTOM_BALANCER" , rec .Header ().Get ("FROM_BALANCER" ))
812
814
}
815
+
816
+ func createSimpleWebSocketServer (serveTLS bool ) * httptest.Server {
817
+ handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
818
+ wsHandler := func (conn * websocket.Conn ) {
819
+ defer conn .Close ()
820
+ for {
821
+ var msg string
822
+ err := websocket .Message .Receive (conn , & msg )
823
+ if err != nil {
824
+ return
825
+ }
826
+ // message back to the client
827
+ websocket .Message .Send (conn , msg )
828
+ }
829
+ }
830
+ websocket.Server {Handler : wsHandler }.ServeHTTP (w , r )
831
+ })
832
+ if serveTLS {
833
+ return httptest .NewTLSServer (handler )
834
+ }
835
+ return httptest .NewServer (handler )
836
+ }
837
+
838
+ func createSimpleProxyServer (t * testing.T , srv * httptest.Server , serveTLS bool , toTLS bool ) * httptest.Server {
839
+ e := echo .New ()
840
+
841
+ if toTLS {
842
+ // proxy to tls target
843
+ tgtURL , _ := url .Parse (srv .URL )
844
+ tgtURL .Scheme = "wss"
845
+ balancer := NewRandomBalancer ([]* ProxyTarget {{URL : tgtURL }})
846
+
847
+ defaultTransport , ok := http .DefaultTransport .(* http.Transport )
848
+ if ! ok {
849
+ t .Fatal ("Default transport is not of type *http.Transport" )
850
+ }
851
+ transport := defaultTransport .Clone ()
852
+ transport .TLSClientConfig = & tls.Config {
853
+ InsecureSkipVerify : true ,
854
+ }
855
+ e .Use (ProxyWithConfig (ProxyConfig {Balancer : balancer , Transport : transport }))
856
+ } else {
857
+ // proxy to non-TLS target
858
+ tgtURL , _ := url .Parse (srv .URL )
859
+ balancer := NewRandomBalancer ([]* ProxyTarget {{URL : tgtURL }})
860
+ e .Use (ProxyWithConfig (ProxyConfig {Balancer : balancer }))
861
+ }
862
+
863
+ if serveTLS {
864
+ // serve proxy server with TLS
865
+ ts := httptest .NewTLSServer (e )
866
+ return ts
867
+ }
868
+ // serve proxy server without TLS
869
+ ts := httptest .NewServer (e )
870
+ return ts
871
+ }
872
+
873
+ // TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
874
+ func TestProxyWithConfigWebSocketNonTLS2NonTLS (t * testing.T ) {
875
+ /*
876
+ Arrange
877
+ */
878
+ // Create a WebSocket test server (non-TLS)
879
+ srv := createSimpleWebSocketServer (false )
880
+ defer srv .Close ()
881
+
882
+ // create proxy server (non-TLS to non-TLS)
883
+ ts := createSimpleProxyServer (t , srv , false , false )
884
+ defer ts .Close ()
885
+
886
+ tsURL , _ := url .Parse (ts .URL )
887
+ tsURL .Scheme = "ws"
888
+ tsURL .Path = "/"
889
+
890
+ /*
891
+ Act
892
+ */
893
+
894
+ // Connect to the proxy WebSocket
895
+ wsConn , err := websocket .Dial (tsURL .String (), "" , "http://localhost/" )
896
+ assert .NoError (t , err )
897
+ defer wsConn .Close ()
898
+
899
+ // Send message
900
+ sendMsg := "Hello, Non TLS WebSocket!"
901
+ err = websocket .Message .Send (wsConn , sendMsg )
902
+ assert .NoError (t , err )
903
+
904
+ /*
905
+ Assert
906
+ */
907
+ // Read response
908
+ var recvMsg string
909
+ err = websocket .Message .Receive (wsConn , & recvMsg )
910
+ assert .NoError (t , err )
911
+ assert .Equal (t , sendMsg , recvMsg )
912
+ }
913
+
914
+ // TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
915
+ func TestProxyWithConfigWebSocketTLS2TLS (t * testing.T ) {
916
+ /*
917
+ Arrange
918
+ */
919
+ // Create a WebSocket test server (TLS)
920
+ srv := createSimpleWebSocketServer (true )
921
+ defer srv .Close ()
922
+
923
+ // create proxy server (TLS to TLS)
924
+ ts := createSimpleProxyServer (t , srv , true , true )
925
+ defer ts .Close ()
926
+
927
+ tsURL , _ := url .Parse (ts .URL )
928
+ tsURL .Scheme = "wss"
929
+ tsURL .Path = "/"
930
+
931
+ /*
932
+ Act
933
+ */
934
+ origin , err := url .Parse (ts .URL )
935
+ assert .NoError (t , err )
936
+ config := & websocket.Config {
937
+ Location : tsURL ,
938
+ Origin : origin ,
939
+ TlsConfig : & tls.Config {InsecureSkipVerify : true }, // skip verify for testing
940
+ Version : websocket .ProtocolVersionHybi13 ,
941
+ }
942
+ wsConn , err := websocket .DialConfig (config )
943
+ assert .NoError (t , err )
944
+ defer wsConn .Close ()
945
+
946
+ // Send message
947
+ sendMsg := "Hello, TLS to TLS WebSocket!"
948
+ err = websocket .Message .Send (wsConn , sendMsg )
949
+ assert .NoError (t , err )
950
+
951
+ // Read response
952
+ var recvMsg string
953
+ err = websocket .Message .Receive (wsConn , & recvMsg )
954
+ assert .NoError (t , err )
955
+ assert .Equal (t , sendMsg , recvMsg )
956
+ }
957
+
958
+ // TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
959
+ func TestProxyWithConfigWebSocketNonTLS2TLS (t * testing.T ) {
960
+ /*
961
+ Arrange
962
+ */
963
+
964
+ // Create a WebSocket test server (TLS)
965
+ srv := createSimpleWebSocketServer (true )
966
+ defer srv .Close ()
967
+
968
+ // create proxy server (Non-TLS to TLS)
969
+ ts := createSimpleProxyServer (t , srv , false , true )
970
+ defer ts .Close ()
971
+
972
+ tsURL , _ := url .Parse (ts .URL )
973
+ tsURL .Scheme = "ws"
974
+ tsURL .Path = "/"
975
+
976
+ /*
977
+ Act
978
+ */
979
+ // Connect to the proxy WebSocket
980
+ wsConn , err := websocket .Dial (tsURL .String (), "" , "http://localhost/" )
981
+ assert .NoError (t , err )
982
+ defer wsConn .Close ()
983
+
984
+ // Send message
985
+ sendMsg := "Hello, Non TLS to TLS WebSocket!"
986
+ err = websocket .Message .Send (wsConn , sendMsg )
987
+ assert .NoError (t , err )
988
+
989
+ /*
990
+ Assert
991
+ */
992
+ // Read response
993
+ var recvMsg string
994
+ err = websocket .Message .Receive (wsConn , & recvMsg )
995
+ assert .NoError (t , err )
996
+ assert .Equal (t , sendMsg , recvMsg )
997
+ }
998
+
999
+ // TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
1000
+ func TestProxyWithConfigWebSocketTLS2NonTLS (t * testing.T ) {
1001
+ /*
1002
+ Arrange
1003
+ */
1004
+
1005
+ // Create a WebSocket test server (non-TLS)
1006
+ srv := createSimpleWebSocketServer (false )
1007
+ defer srv .Close ()
1008
+
1009
+ // create proxy server (TLS to non-TLS)
1010
+ ts := createSimpleProxyServer (t , srv , true , false )
1011
+ defer ts .Close ()
1012
+
1013
+ tsURL , _ := url .Parse (ts .URL )
1014
+ tsURL .Scheme = "wss"
1015
+ tsURL .Path = "/"
1016
+
1017
+ /*
1018
+ Act
1019
+ */
1020
+ origin , err := url .Parse (ts .URL )
1021
+ assert .NoError (t , err )
1022
+ config := & websocket.Config {
1023
+ Location : tsURL ,
1024
+ Origin : origin ,
1025
+ TlsConfig : & tls.Config {InsecureSkipVerify : true }, // skip verify for testing
1026
+ Version : websocket .ProtocolVersionHybi13 ,
1027
+ }
1028
+ wsConn , err := websocket .DialConfig (config )
1029
+ assert .NoError (t , err )
1030
+ defer wsConn .Close ()
1031
+
1032
+ // Send message
1033
+ sendMsg := "Hello, TLS to NoneTLS WebSocket!"
1034
+ err = websocket .Message .Send (wsConn , sendMsg )
1035
+ assert .NoError (t , err )
1036
+
1037
+ // Read response
1038
+ var recvMsg string
1039
+ err = websocket .Message .Receive (wsConn , & recvMsg )
1040
+ assert .NoError (t , err )
1041
+ assert .Equal (t , sendMsg , recvMsg )
1042
+ }
0 commit comments