Skip to content

Commit de44c53

Browse files
Add support for TLS WebSocket proxy (#2762)
* Add support for TLS WebSocket proxy * support tls to non-tls and non-tls to tls websocket proxy
1 parent c44f628 commit de44c53

File tree

2 files changed

+248
-5
lines changed

2 files changed

+248
-5
lines changed

middleware/proxy.go

+18-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package middleware
55

66
import (
77
"context"
8+
"crypto/tls"
89
"fmt"
910
"io"
1011
"math/rand"
@@ -130,21 +131,33 @@ var DefaultProxyConfig = ProxyConfig{
130131
ContextKey: "target",
131132
}
132133

133-
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
134+
func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
135+
var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
136+
if transport, ok := config.Transport.(*http.Transport); ok {
137+
if transport.TLSClientConfig != nil {
138+
d := tls.Dialer{
139+
Config: transport.TLSClientConfig,
140+
}
141+
dialFunc = d.DialContext
142+
}
143+
}
144+
if dialFunc == nil {
145+
var d net.Dialer
146+
dialFunc = d.DialContext
147+
}
148+
134149
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
135150
in, _, err := c.Response().Hijack()
136151
if err != nil {
137152
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
138153
return
139154
}
140155
defer in.Close()
141-
142-
out, err := net.Dial("tcp", t.URL.Host)
156+
out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host)
143157
if err != nil {
144158
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
145159
return
146160
}
147-
defer out.Close()
148161

149162
// Write header
150163
err = r.Write(out)
@@ -365,7 +378,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
365378
// Proxy
366379
switch {
367380
case c.IsWebSocket():
368-
proxyRaw(tgt, c).ServeHTTP(res, req)
381+
proxyRaw(tgt, c, config).ServeHTTP(res, req)
369382
default: // even SSE requests
370383
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
371384
}

middleware/proxy_test.go

+230
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package middleware
66
import (
77
"bytes"
88
"context"
9+
"crypto/tls"
910
"errors"
1011
"fmt"
1112
"io"
@@ -20,6 +21,7 @@ import (
2021

2122
"github.com/labstack/echo/v4"
2223
"github.com/stretchr/testify/assert"
24+
"golang.org/x/net/websocket"
2325
)
2426

2527
// Assert expected with url.EscapedPath method to obtain the path.
@@ -810,3 +812,231 @@ func TestModifyResponseUseContext(t *testing.T) {
810812
assert.Equal(t, "OK", rec.Body.String())
811813
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
812814
}
815+
816+
func createSimpleWebSocketServer(serveTLS bool) *httptest.Server {
817+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
818+
wsHandler := func(conn *websocket.Conn) {
819+
defer conn.Close()
820+
for {
821+
var msg string
822+
err := websocket.Message.Receive(conn, &msg)
823+
if err != nil {
824+
return
825+
}
826+
// message back to the client
827+
websocket.Message.Send(conn, msg)
828+
}
829+
}
830+
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
831+
})
832+
if serveTLS {
833+
return httptest.NewTLSServer(handler)
834+
}
835+
return httptest.NewServer(handler)
836+
}
837+
838+
func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server {
839+
e := echo.New()
840+
841+
if toTLS {
842+
// proxy to tls target
843+
tgtURL, _ := url.Parse(srv.URL)
844+
tgtURL.Scheme = "wss"
845+
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
846+
847+
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
848+
if !ok {
849+
t.Fatal("Default transport is not of type *http.Transport")
850+
}
851+
transport := defaultTransport.Clone()
852+
transport.TLSClientConfig = &tls.Config{
853+
InsecureSkipVerify: true,
854+
}
855+
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
856+
} else {
857+
// proxy to non-TLS target
858+
tgtURL, _ := url.Parse(srv.URL)
859+
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
860+
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
861+
}
862+
863+
if serveTLS {
864+
// serve proxy server with TLS
865+
ts := httptest.NewTLSServer(e)
866+
return ts
867+
}
868+
// serve proxy server without TLS
869+
ts := httptest.NewServer(e)
870+
return ts
871+
}
872+
873+
// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
874+
func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) {
875+
/*
876+
Arrange
877+
*/
878+
// Create a WebSocket test server (non-TLS)
879+
srv := createSimpleWebSocketServer(false)
880+
defer srv.Close()
881+
882+
// create proxy server (non-TLS to non-TLS)
883+
ts := createSimpleProxyServer(t, srv, false, false)
884+
defer ts.Close()
885+
886+
tsURL, _ := url.Parse(ts.URL)
887+
tsURL.Scheme = "ws"
888+
tsURL.Path = "/"
889+
890+
/*
891+
Act
892+
*/
893+
894+
// Connect to the proxy WebSocket
895+
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
896+
assert.NoError(t, err)
897+
defer wsConn.Close()
898+
899+
// Send message
900+
sendMsg := "Hello, Non TLS WebSocket!"
901+
err = websocket.Message.Send(wsConn, sendMsg)
902+
assert.NoError(t, err)
903+
904+
/*
905+
Assert
906+
*/
907+
// Read response
908+
var recvMsg string
909+
err = websocket.Message.Receive(wsConn, &recvMsg)
910+
assert.NoError(t, err)
911+
assert.Equal(t, sendMsg, recvMsg)
912+
}
913+
914+
// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
915+
func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) {
916+
/*
917+
Arrange
918+
*/
919+
// Create a WebSocket test server (TLS)
920+
srv := createSimpleWebSocketServer(true)
921+
defer srv.Close()
922+
923+
// create proxy server (TLS to TLS)
924+
ts := createSimpleProxyServer(t, srv, true, true)
925+
defer ts.Close()
926+
927+
tsURL, _ := url.Parse(ts.URL)
928+
tsURL.Scheme = "wss"
929+
tsURL.Path = "/"
930+
931+
/*
932+
Act
933+
*/
934+
origin, err := url.Parse(ts.URL)
935+
assert.NoError(t, err)
936+
config := &websocket.Config{
937+
Location: tsURL,
938+
Origin: origin,
939+
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
940+
Version: websocket.ProtocolVersionHybi13,
941+
}
942+
wsConn, err := websocket.DialConfig(config)
943+
assert.NoError(t, err)
944+
defer wsConn.Close()
945+
946+
// Send message
947+
sendMsg := "Hello, TLS to TLS WebSocket!"
948+
err = websocket.Message.Send(wsConn, sendMsg)
949+
assert.NoError(t, err)
950+
951+
// Read response
952+
var recvMsg string
953+
err = websocket.Message.Receive(wsConn, &recvMsg)
954+
assert.NoError(t, err)
955+
assert.Equal(t, sendMsg, recvMsg)
956+
}
957+
958+
// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
959+
func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) {
960+
/*
961+
Arrange
962+
*/
963+
964+
// Create a WebSocket test server (TLS)
965+
srv := createSimpleWebSocketServer(true)
966+
defer srv.Close()
967+
968+
// create proxy server (Non-TLS to TLS)
969+
ts := createSimpleProxyServer(t, srv, false, true)
970+
defer ts.Close()
971+
972+
tsURL, _ := url.Parse(ts.URL)
973+
tsURL.Scheme = "ws"
974+
tsURL.Path = "/"
975+
976+
/*
977+
Act
978+
*/
979+
// Connect to the proxy WebSocket
980+
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
981+
assert.NoError(t, err)
982+
defer wsConn.Close()
983+
984+
// Send message
985+
sendMsg := "Hello, Non TLS to TLS WebSocket!"
986+
err = websocket.Message.Send(wsConn, sendMsg)
987+
assert.NoError(t, err)
988+
989+
/*
990+
Assert
991+
*/
992+
// Read response
993+
var recvMsg string
994+
err = websocket.Message.Receive(wsConn, &recvMsg)
995+
assert.NoError(t, err)
996+
assert.Equal(t, sendMsg, recvMsg)
997+
}
998+
999+
// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
1000+
func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
1001+
/*
1002+
Arrange
1003+
*/
1004+
1005+
// Create a WebSocket test server (non-TLS)
1006+
srv := createSimpleWebSocketServer(false)
1007+
defer srv.Close()
1008+
1009+
// create proxy server (TLS to non-TLS)
1010+
ts := createSimpleProxyServer(t, srv, true, false)
1011+
defer ts.Close()
1012+
1013+
tsURL, _ := url.Parse(ts.URL)
1014+
tsURL.Scheme = "wss"
1015+
tsURL.Path = "/"
1016+
1017+
/*
1018+
Act
1019+
*/
1020+
origin, err := url.Parse(ts.URL)
1021+
assert.NoError(t, err)
1022+
config := &websocket.Config{
1023+
Location: tsURL,
1024+
Origin: origin,
1025+
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
1026+
Version: websocket.ProtocolVersionHybi13,
1027+
}
1028+
wsConn, err := websocket.DialConfig(config)
1029+
assert.NoError(t, err)
1030+
defer wsConn.Close()
1031+
1032+
// Send message
1033+
sendMsg := "Hello, TLS to NoneTLS WebSocket!"
1034+
err = websocket.Message.Send(wsConn, sendMsg)
1035+
assert.NoError(t, err)
1036+
1037+
// Read response
1038+
var recvMsg string
1039+
err = websocket.Message.Receive(wsConn, &recvMsg)
1040+
assert.NoError(t, err)
1041+
assert.Equal(t, sendMsg, recvMsg)
1042+
}

0 commit comments

Comments
 (0)