Skip to content

Commit e57b2d6

Browse files
committed
feat(compress): Create compress writer only if the content is compressible
1 parent 19694e2 commit e57b2d6

File tree

2 files changed

+102
-15
lines changed

2 files changed

+102
-15
lines changed

middleware/compress.go

+50-13
Original file line numberDiff line numberDiff line change
@@ -193,29 +193,26 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
193193
// current Compressor.
194194
func (c *Compressor) Handler(next http.Handler) http.Handler {
195195
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
196-
encoder, encoding, cleanup := c.selectEncoder(r.Header, w)
196+
encoder, encoding := c.selectEncoder(r.Header, w)
197197

198198
cw := &compressResponseWriter{
199199
ResponseWriter: w,
200-
w: w,
201200
contentTypes: c.allowedTypes,
202201
contentWildcards: c.allowedWildcards,
203202
encoding: encoding,
204203
compressible: false, // determined in post-handler
205204
}
206205
if encoder != nil {
207-
cw.w = encoder
206+
cw.encoder = encoder
208207
}
209-
// Re-add the encoder to the pool if applicable.
210-
defer cleanup()
211208
defer cw.Close()
212209

213210
next.ServeHTTP(cw, r)
214211
})
215212
}
216213

217214
// selectEncoder returns the encoder, the name of the encoder, and a closer function.
218-
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
215+
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (func() io.Writer, string) {
219216
header := h.Get("Accept-Encoding")
220217

221218
// Parse the names of all accepted algorithms from the header.
@@ -225,23 +222,31 @@ func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, strin
225222
for _, name := range c.encodingPrecedence {
226223
if matchAcceptEncoding(accepted, name) {
227224
if pool, ok := c.pooledEncoders[name]; ok {
228-
encoder := pool.Get().(ioResetterWriter)
229-
cleanup := func() {
230-
pool.Put(encoder)
225+
fn := func() io.Writer {
226+
enc := pool.Get().(ioResetterWriter)
227+
enc.Reset(w)
228+
return &pooledEncoder{
229+
Writer: enc,
230+
pool: pool,
231+
}
231232
}
232-
encoder.Reset(w)
233-
return encoder, name, cleanup
233+
return fn, name
234234

235235
}
236236
if fn, ok := c.encoders[name]; ok {
237-
return fn(w, c.level), name, func() {}
237+
fn := func() io.Writer {
238+
return &encoder{
239+
Writer: fn(w, c.level),
240+
}
241+
}
242+
return fn, name
238243
}
239244
}
240245

241246
}
242247

243248
// No encoder found to match the accepted encoding
244-
return nil, "", func() {}
249+
return nil, ""
245250
}
246251

247252
func matchAcceptEncoding(accepted []string, encoding string) bool {
@@ -276,6 +281,8 @@ type compressResponseWriter struct {
276281
encoding string
277282
wroteHeader bool
278283
compressible bool
284+
285+
encoder func() io.Writer
279286
}
280287

281288
func (cw *compressResponseWriter) isCompressible() bool {
@@ -335,6 +342,9 @@ func (cw *compressResponseWriter) Write(p []byte) (int, error) {
335342

336343
func (cw *compressResponseWriter) writer() io.Writer {
337344
if cw.compressible {
345+
if cw.w == nil {
346+
cw.w = cw.encoder()
347+
}
338348
return cw.w
339349
}
340350
return cw.ResponseWriter
@@ -385,6 +395,33 @@ func (cw *compressResponseWriter) Unwrap() http.ResponseWriter {
385395
return cw.ResponseWriter
386396
}
387397

398+
type (
399+
encoder struct {
400+
io.Writer
401+
}
402+
403+
pooledEncoder struct {
404+
io.Writer
405+
pool *sync.Pool
406+
}
407+
)
408+
409+
func (e *encoder) Close() error {
410+
if c, ok := e.Writer.(io.WriteCloser); ok {
411+
return c.Close()
412+
}
413+
return nil
414+
}
415+
416+
func (e *pooledEncoder) Close() error {
417+
var err error
418+
if w, ok := e.Writer.(io.WriteCloser); ok {
419+
err = w.Close()
420+
}
421+
e.pool.Put(e.Writer)
422+
return err
423+
}
424+
388425
func encoderGzip(w io.Writer, level int) io.Writer {
389426
gw, err := gzip.NewWriterLevel(w, level)
390427
if err != nil {

middleware/compress_test.go

+52-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ func TestCompressor(t *testing.T) {
2727
return w
2828
})
2929

30-
if len(compressor.encoders) != 1 {
31-
t.Errorf("nop encoder should be stored in the encoders map")
30+
var sideEffect int
31+
compressor.SetEncoder("test", func(w io.Writer, _ int) io.Writer {
32+
return newSideEffectWriter(w, &sideEffect)
33+
})
34+
35+
if len(compressor.encoders) != 2 {
36+
t.Errorf("nop and test encoders should be stored in the encoders map")
3237
}
3338

3439
r.Use(compressor.Handler)
@@ -48,6 +53,11 @@ func TestCompressor(t *testing.T) {
4853
w.Write([]byte("textstring"))
4954
})
5055

56+
r.Get("/getimage", func(w http.ResponseWriter, r *http.Request) {
57+
w.Header().Set("Content-Type", "image/png")
58+
w.Write([]byte("textstring"))
59+
})
60+
5161
ts := httptest.NewServer(r)
5262
defer ts.Close()
5363

@@ -98,6 +108,20 @@ func TestCompressor(t *testing.T) {
98108
expectedEncoding: "nop",
99109
checkRawResponse: true,
100110
},
111+
{
112+
name: "test encoder is used",
113+
path: "/getimage",
114+
acceptedEncodings: []string{"test"},
115+
expectedEncoding: "",
116+
checkRawResponse: true,
117+
},
118+
{
119+
name: "test encoder is used and Close is called",
120+
path: "/gethtml",
121+
acceptedEncodings: []string{"test"},
122+
expectedEncoding: "test",
123+
checkRawResponse: true,
124+
},
101125
}
102126

103127
for _, tc := range tests {
@@ -117,6 +141,9 @@ func TestCompressor(t *testing.T) {
117141
})
118142

119143
}
144+
if sideEffect != 0 {
145+
t.Errorf("side effect should be cleared after close")
146+
}
120147
}
121148

122149
func TestCompressorWildcards(t *testing.T) {
@@ -254,3 +281,26 @@ func decodeResponseBody(t *testing.T, resp *http.Response) string {
254281

255282
return string(respBody)
256283
}
284+
285+
type (
286+
sideEffectWriter struct {
287+
w io.Writer
288+
s *int
289+
}
290+
)
291+
292+
func newSideEffectWriter(w io.Writer, sideEffect *int) io.Writer {
293+
*sideEffect = *sideEffect + 1
294+
295+
return &sideEffectWriter{w: w, s: sideEffect}
296+
}
297+
298+
func (w *sideEffectWriter) Write(p []byte) (n int, err error) {
299+
return w.w.Write(p)
300+
}
301+
302+
func (w *sideEffectWriter) Close() error {
303+
*w.s = *w.s - 1
304+
305+
return nil
306+
}

0 commit comments

Comments
 (0)