@@ -193,29 +193,26 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
193
193
// current Compressor.
194
194
func (c * Compressor ) Handler (next http.Handler ) http.Handler {
195
195
return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
196
- encoder , encoding , cleanup := c .selectEncoder (r .Header , w )
196
+ encoder , encoding := c .selectEncoder (r .Header , w )
197
197
198
198
cw := & compressResponseWriter {
199
199
ResponseWriter : w ,
200
- w : w ,
201
200
contentTypes : c .allowedTypes ,
202
201
contentWildcards : c .allowedWildcards ,
203
202
encoding : encoding ,
204
203
compressible : false , // determined in post-handler
205
204
}
206
205
if encoder != nil {
207
- cw .w = encoder
206
+ cw .encoder = encoder
208
207
}
209
- // Re-add the encoder to the pool if applicable.
210
- defer cleanup ()
211
208
defer cw .Close ()
212
209
213
210
next .ServeHTTP (cw , r )
214
211
})
215
212
}
216
213
217
214
// selectEncoder returns the encoder, the name of the encoder, and a closer function.
218
- func (c * Compressor ) selectEncoder (h http.Header , w io.Writer ) (io.Writer , string , func () ) {
215
+ func (c * Compressor ) selectEncoder (h http.Header , w io.Writer ) (func () io.Writer , string ) {
219
216
header := h .Get ("Accept-Encoding" )
220
217
221
218
// Parse the names of all accepted algorithms from the header.
@@ -225,23 +222,31 @@ func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, strin
225
222
for _ , name := range c .encodingPrecedence {
226
223
if matchAcceptEncoding (accepted , name ) {
227
224
if pool , ok := c .pooledEncoders [name ]; ok {
228
- encoder := pool .Get ().(ioResetterWriter )
229
- cleanup := func () {
230
- pool .Put (encoder )
225
+ fn := func () io.Writer {
226
+ enc := pool .Get ().(ioResetterWriter )
227
+ enc .Reset (w )
228
+ return & pooledEncoder {
229
+ Writer : enc ,
230
+ pool : pool ,
231
+ }
231
232
}
232
- encoder .Reset (w )
233
- return encoder , name , cleanup
233
+ return fn , name
234
234
235
235
}
236
236
if fn , ok := c .encoders [name ]; ok {
237
- return fn (w , c .level ), name , func () {}
237
+ fn := func () io.Writer {
238
+ return & encoder {
239
+ Writer : fn (w , c .level ),
240
+ }
241
+ }
242
+ return fn , name
238
243
}
239
244
}
240
245
241
246
}
242
247
243
248
// No encoder found to match the accepted encoding
244
- return nil , "" , func () {}
249
+ return nil , ""
245
250
}
246
251
247
252
func matchAcceptEncoding (accepted []string , encoding string ) bool {
@@ -276,6 +281,8 @@ type compressResponseWriter struct {
276
281
encoding string
277
282
wroteHeader bool
278
283
compressible bool
284
+
285
+ encoder func () io.Writer
279
286
}
280
287
281
288
func (cw * compressResponseWriter ) isCompressible () bool {
@@ -335,6 +342,9 @@ func (cw *compressResponseWriter) Write(p []byte) (int, error) {
335
342
336
343
func (cw * compressResponseWriter ) writer () io.Writer {
337
344
if cw .compressible {
345
+ if cw .w == nil {
346
+ cw .w = cw .encoder ()
347
+ }
338
348
return cw .w
339
349
}
340
350
return cw .ResponseWriter
@@ -385,6 +395,33 @@ func (cw *compressResponseWriter) Unwrap() http.ResponseWriter {
385
395
return cw .ResponseWriter
386
396
}
387
397
398
+ type (
399
+ encoder struct {
400
+ io.Writer
401
+ }
402
+
403
+ pooledEncoder struct {
404
+ io.Writer
405
+ pool * sync.Pool
406
+ }
407
+ )
408
+
409
+ func (e * encoder ) Close () error {
410
+ if c , ok := e .Writer .(io.WriteCloser ); ok {
411
+ return c .Close ()
412
+ }
413
+ return nil
414
+ }
415
+
416
+ func (e * pooledEncoder ) Close () error {
417
+ var err error
418
+ if w , ok := e .Writer .(io.WriteCloser ); ok {
419
+ err = w .Close ()
420
+ }
421
+ e .pool .Put (e .Writer )
422
+ return err
423
+ }
424
+
388
425
func encoderGzip (w io.Writer , level int ) io.Writer {
389
426
gw , err := gzip .NewWriterLevel (w , level )
390
427
if err != nil {
0 commit comments