Skip to content

Commit bc1e190

Browse files
authored
Allow ResponseWriters to unwrap writers when flushing/hijacking (#2595)
* Allow ResponseWriters to unwrap writers when flushing/hijacking
1 parent 3e04e3e commit bc1e190

11 files changed

+289
-8
lines changed

middleware/body_dump.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package middleware
33
import (
44
"bufio"
55
"bytes"
6+
"errors"
67
"io"
78
"net"
89
"net/http"
@@ -98,9 +99,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
9899
}
99100

100101
func (w *bodyDumpResponseWriter) Flush() {
101-
w.ResponseWriter.(http.Flusher).Flush()
102+
err := responseControllerFlush(w.ResponseWriter)
103+
if err != nil && errors.Is(err, http.ErrNotSupported) {
104+
panic(errors.New("response writer flushing is not supported"))
105+
}
102106
}
103107

104108
func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
105-
return w.ResponseWriter.(http.Hijacker).Hijack()
109+
return responseControllerHijack(w.ResponseWriter)
110+
}
111+
112+
func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
113+
return w.ResponseWriter
106114
}

middleware/body_dump_test.go

+50
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,53 @@ func TestBodyDumpFails(t *testing.T) {
8787
}
8888
})
8989
}
90+
91+
func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
92+
bdrw := bodyDumpResponseWriter{
93+
ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
94+
}
95+
96+
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
97+
bdrw.Flush()
98+
})
99+
}
100+
101+
func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
102+
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
103+
bdrw := bodyDumpResponseWriter{
104+
ResponseWriter: &trwu,
105+
}
106+
107+
bdrw.Flush()
108+
assert.Equal(t, 1, trwu.unwrapCalled)
109+
}
110+
111+
func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
112+
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
113+
bdrw := bodyDumpResponseWriter{
114+
ResponseWriter: trwu,
115+
}
116+
117+
result := bdrw.Unwrap()
118+
assert.Equal(t, trwu, result)
119+
}
120+
121+
func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
122+
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
123+
bdrw := bodyDumpResponseWriter{
124+
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
125+
}
126+
127+
_, _, err := bdrw.Hijack()
128+
assert.EqualError(t, err, "can hijack")
129+
}
130+
131+
func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
132+
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
133+
bdrw := bodyDumpResponseWriter{
134+
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
135+
}
136+
137+
_, _, err := bdrw.Hijack()
138+
assert.EqualError(t, err, "feature not supported")
139+
}

middleware/compress.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,15 @@ func (w *gzipResponseWriter) Flush() {
191191
}
192192

193193
w.Writer.(*gzip.Writer).Flush()
194-
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
195-
flusher.Flush()
196-
}
194+
_ = responseControllerFlush(w.ResponseWriter)
195+
}
196+
197+
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
198+
return w.ResponseWriter
197199
}
198200

199201
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
200-
return w.ResponseWriter.(http.Hijacker).Hijack()
202+
return responseControllerHijack(w.ResponseWriter)
201203
}
202204

203205
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {

middleware/compress_test.go

+30
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,36 @@ func TestGzipWithStatic(t *testing.T) {
311311
}
312312
}
313313

314+
func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
315+
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
316+
bdrw := gzipResponseWriter{
317+
ResponseWriter: trwu,
318+
}
319+
320+
result := bdrw.Unwrap()
321+
assert.Equal(t, trwu, result)
322+
}
323+
324+
func TestGzipResponseWriter_CanHijack(t *testing.T) {
325+
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
326+
bdrw := gzipResponseWriter{
327+
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
328+
}
329+
330+
_, _, err := bdrw.Hijack()
331+
assert.EqualError(t, err, "can hijack")
332+
}
333+
334+
func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
335+
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
336+
bdrw := gzipResponseWriter{
337+
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
338+
}
339+
340+
_, _, err := bdrw.Hijack()
341+
assert.EqualError(t, err, "feature not supported")
342+
}
343+
314344
func BenchmarkGzip(b *testing.B) {
315345
e := echo.New()
316346

middleware/middleware_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package middleware
22

33
import (
4+
"bufio"
5+
"errors"
46
"github.com/stretchr/testify/assert"
7+
"net"
58
"net/http"
69
"net/http/httptest"
710
"regexp"
@@ -90,3 +93,46 @@ func TestRewriteURL(t *testing.T) {
9093
})
9194
}
9295
}
96+
97+
type testResponseWriterNoFlushHijack struct {
98+
}
99+
100+
func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
101+
}
102+
103+
func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
104+
return 0, nil
105+
}
106+
107+
func (w *testResponseWriterNoFlushHijack) Header() http.Header {
108+
return nil
109+
}
110+
111+
type testResponseWriterUnwrapper struct {
112+
unwrapCalled int
113+
rw http.ResponseWriter
114+
}
115+
116+
func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
117+
}
118+
119+
func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
120+
return 0, nil
121+
}
122+
123+
func (w *testResponseWriterUnwrapper) Header() http.Header {
124+
return nil
125+
}
126+
127+
func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
128+
w.unwrapCalled++
129+
return w.rw
130+
}
131+
132+
type testResponseWriterUnwrapperHijack struct {
133+
testResponseWriterUnwrapper
134+
}
135+
136+
func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
137+
return nil, nil, errors.New("can hijack")
138+
}

middleware/responsecontroller_1.19.go

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//go:build !go1.20
2+
3+
package middleware
4+
5+
import (
6+
"bufio"
7+
"fmt"
8+
"net"
9+
"net/http"
10+
)
11+
12+
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
13+
func responseControllerFlush(rw http.ResponseWriter) error {
14+
for {
15+
switch t := rw.(type) {
16+
case interface{ FlushError() error }:
17+
return t.FlushError()
18+
case http.Flusher:
19+
t.Flush()
20+
return nil
21+
case interface{ Unwrap() http.ResponseWriter }:
22+
rw = t.Unwrap()
23+
default:
24+
return fmt.Errorf("%w", http.ErrNotSupported)
25+
}
26+
}
27+
}
28+
29+
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
30+
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
31+
for {
32+
switch t := rw.(type) {
33+
case http.Hijacker:
34+
return t.Hijack()
35+
case interface{ Unwrap() http.ResponseWriter }:
36+
rw = t.Unwrap()
37+
default:
38+
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
39+
}
40+
}
41+
}

middleware/responsecontroller_1.20.go

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//go:build go1.20
2+
3+
package middleware
4+
5+
import (
6+
"bufio"
7+
"net"
8+
"net/http"
9+
)
10+
11+
func responseControllerFlush(rw http.ResponseWriter) error {
12+
return http.NewResponseController(rw).Flush()
13+
}
14+
15+
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
16+
return http.NewResponseController(rw).Hijack()
17+
}

response.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package echo
22

33
import (
44
"bufio"
5+
"errors"
56
"net"
67
"net/http"
78
)
@@ -84,14 +85,17 @@ func (r *Response) Write(b []byte) (n int, err error) {
8485
// buffered data to the client.
8586
// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher)
8687
func (r *Response) Flush() {
87-
r.Writer.(http.Flusher).Flush()
88+
err := responseControllerFlush(r.Writer)
89+
if err != nil && errors.Is(err, http.ErrNotSupported) {
90+
panic(errors.New("response writer flushing is not supported"))
91+
}
8892
}
8993

9094
// Hijack implements the http.Hijacker interface to allow an HTTP handler to
9195
// take over the connection.
9296
// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker)
9397
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
94-
return r.Writer.(http.Hijacker).Hijack()
98+
return responseControllerHijack(r.Writer)
9599
}
96100

97101
// Unwrap returns the original http.ResponseWriter.

response_test.go

+25
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,31 @@ func TestResponse_Flush(t *testing.T) {
5757
assert.True(t, rec.Flushed)
5858
}
5959

60+
type testResponseWriter struct {
61+
}
62+
63+
func (w *testResponseWriter) WriteHeader(statusCode int) {
64+
}
65+
66+
func (w *testResponseWriter) Write([]byte) (int, error) {
67+
return 0, nil
68+
}
69+
70+
func (w *testResponseWriter) Header() http.Header {
71+
return nil
72+
}
73+
74+
func TestResponse_FlushPanics(t *testing.T) {
75+
e := New()
76+
rw := new(testResponseWriter)
77+
res := &Response{echo: e, Writer: rw}
78+
79+
// we test that we behave as before unwrapping flushers - flushing writer that does not support it causes panic
80+
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
81+
res.Flush()
82+
})
83+
}
84+
6085
func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
6186
e := New()
6287
rec := httptest.NewRecorder()

responsecontroller_1.19.go

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//go:build !go1.20
2+
3+
package echo
4+
5+
import (
6+
"bufio"
7+
"fmt"
8+
"net"
9+
"net/http"
10+
)
11+
12+
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
13+
func responseControllerFlush(rw http.ResponseWriter) error {
14+
for {
15+
switch t := rw.(type) {
16+
case interface{ FlushError() error }:
17+
return t.FlushError()
18+
case http.Flusher:
19+
t.Flush()
20+
return nil
21+
case interface{ Unwrap() http.ResponseWriter }:
22+
rw = t.Unwrap()
23+
default:
24+
return fmt.Errorf("%w", http.ErrNotSupported)
25+
}
26+
}
27+
}
28+
29+
// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
30+
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
31+
for {
32+
switch t := rw.(type) {
33+
case http.Hijacker:
34+
return t.Hijack()
35+
case interface{ Unwrap() http.ResponseWriter }:
36+
rw = t.Unwrap()
37+
default:
38+
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
39+
}
40+
}
41+
}

responsecontroller_1.20.go

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//go:build go1.20
2+
3+
package echo
4+
5+
import (
6+
"bufio"
7+
"net"
8+
"net/http"
9+
)
10+
11+
func responseControllerFlush(rw http.ResponseWriter) error {
12+
return http.NewResponseController(rw).Flush()
13+
}
14+
15+
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
16+
return http.NewResponseController(rw).Hijack()
17+
}

0 commit comments

Comments
 (0)