2121import com .facebook .airlift .http .client .StaticBodyGenerator ;
2222import com .facebook .airlift .log .Logger ;
2323import com .facebook .airlift .units .Duration ;
24+ import com .github .luben .zstd .ZstdInputStream ;
25+ import com .github .luben .zstd .ZstdOutputStreamNoFinalizer ;
2426import com .google .common .base .Splitter ;
2527import com .google .common .collect .ArrayListMultimap ;
2628import com .google .common .collect .ListMultimap ;
2729import com .google .common .util .concurrent .SettableFuture ;
2830import com .google .inject .Inject ;
2931import io .netty .channel .ChannelOption ;
32+ import io .netty .channel .WriteBufferWaterMark ;
3033import io .netty .channel .epoll .Epoll ;
3134import io .netty .handler .codec .http .HttpHeaders ;
3235import io .netty .handler .ssl .ApplicationProtocolConfig ;
4447import reactor .netty .resources .ConnectionProvider ;
4548import reactor .netty .resources .LoopResources ;
4649
50+ import java .io .ByteArrayOutputStream ;
4751import java .io .Closeable ;
4852import java .io .File ;
4953import java .io .IOException ;
6266import java .util .concurrent .TimeUnit ;
6367import java .util .concurrent .TimeoutException ;
6468import java .util .function .Function ;
69+ import java .util .zip .GZIPInputStream ;
6570
6671import static com .facebook .airlift .security .pem .PemReader .loadPrivateKey ;
6772import static com .facebook .airlift .security .pem .PemReader .readCertificateChain ;
@@ -84,17 +89,25 @@ public class ReactorNettyHttpClient
8489 private static final Logger log = Logger .get (ReactorNettyHttpClient .class );
8590 private static final HeaderName CONTENT_TYPE_HEADER_NAME = HeaderName .of ("Content-Type" );
8691 private static final HeaderName CONTENT_LENGTH_HEADER_NAME = HeaderName .of ("Content-Length" );
92+ private static final HeaderName CONTENT_ENCODING_HEADER_NAME = HeaderName .of ("Content-Encoding" );
93+ private static final HeaderName ACCEPT_ENCODING_HEADER_NAME = HeaderName .of ("Accept-Encoding" );
8794
8895 private final Duration requestTimeout ;
8996 private HttpClient httpClient ;
9097 private final HttpClientConnectionPoolStats connectionPoolStats ;
9198 private final HttpClientStats httpClientStats ;
99+ private final boolean isHttp2CompressionEnabled ;
100+ private final int payloadSizeThreshold ;
101+ private final double compressionSavingThreshold ;
92102
93103 @ Inject
94104 public ReactorNettyHttpClient (ReactorNettyHttpClientConfig config , HttpClientConnectionPoolStats connectionPoolStats , HttpClientStats httpClientStats )
95105 {
96106 this .connectionPoolStats = connectionPoolStats ;
97107 this .httpClientStats = httpClientStats ;
108+ this .isHttp2CompressionEnabled = config .isHttp2CompressionEnabled ();
109+ this .payloadSizeThreshold = config .getPayloadSizeThreshold ();
110+ this .compressionSavingThreshold = config .getCompressionSavingThreshold ();
98111 SslContext sslContext = null ;
99112 if (config .isHttpsEnabled ()) {
100113 try {
@@ -114,11 +127,11 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
114127 if (os .toLowerCase (Locale .ENGLISH ).contains ("linux" )) {
115128 // Make sure Open ssl is available for linux deployments
116129 if (!OpenSsl .isAvailable ()) {
117- throw new UnsupportedOperationException (format ("OpenSsl is not unavailable . Stacktrace: %s" , Arrays .toString (OpenSsl .unavailabilityCause ().getStackTrace ()).replace (',' , '\n' )));
130+ throw new UnsupportedOperationException (format ("OpenSsl is not available . Stacktrace: %s" , Arrays .toString (OpenSsl .unavailabilityCause ().getStackTrace ()).replace (',' , '\n' )));
118131 }
119132 // Make sure epoll threads are used for linux deployments
120133 if (!Epoll .isAvailable ()) {
121- throw new UnsupportedOperationException (format ("Epoll is not unavailable . Stacktrace: %s" , Arrays .toString (Epoll .unavailabilityCause ().getStackTrace ()).replace (',' , '\n' )));
134+ throw new UnsupportedOperationException (format ("Epoll is not available . Stacktrace: %s" , Arrays .toString (Epoll .unavailabilityCause ().getStackTrace ()).replace (',' , '\n' )));
122135 }
123136 }
124137
@@ -166,9 +179,10 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
166179
167180 // Create HTTP/2 client
168181 SslContext finalSslContext = sslContext ;
182+
169183 this .httpClient = HttpClient
170- // The custom pool is wrapped with a HttpConnectionProvider over here
171- .create ( pool )
184+ . create ( pool ) // The custom pool is wrapped with a HttpConnectionProvider over here
185+ .compress ( false ) // we will enable response compression manually
172186 .protocol (HttpProtocol .H2 , HttpProtocol .HTTP11 )
173187 .runOn (loopResources , true )
174188 .http2Settings (settings -> {
@@ -179,6 +193,9 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
179193 .option (ChannelOption .CONNECT_TIMEOUT_MILLIS , (int ) config .getConnectTimeout ().getValue ())
180194 .option (ChannelOption .SO_KEEPALIVE , true )
181195 .option (ChannelOption .TCP_NODELAY , true )
196+ .option (ChannelOption .SO_SNDBUF , config .getTcpBufferSize ())
197+ .option (ChannelOption .SO_RCVBUF , config .getTcpBufferSize ())
198+ .option (ChannelOption .WRITE_BUFFER_WATER_MARK , new WriteBufferWaterMark (config .getWriteBufferWaterMarkLow (), config .getWriteBufferWaterMarkHigh ()))
182199 // Track HTTP client metrics
183200 .metrics (true , () -> httpClientStats , Function .identity ());
184201
@@ -208,6 +225,10 @@ public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request airli
208225 for (Map .Entry <String , String > entry : airliftRequest .getHeaders ().entries ()) {
209226 hdr .set (entry .getKey (), entry .getValue ());
210227 }
228+
229+ if (isHttp2CompressionEnabled ) {
230+ hdr .set (ACCEPT_ENCODING_HEADER_NAME .toString (), "zstd, gzip" );
231+ }
211232 });
212233
213234 URI uri = airliftRequest .getUri ();
@@ -223,9 +244,34 @@ public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request airli
223244 break ;
224245 case "POST" :
225246 byte [] postBytes = ((StaticBodyGenerator ) airliftRequest .getBodyGenerator ()).getBody ();
226- disposable = client .post ()
247+ byte [] bodyToSend = postBytes ;
248+ HttpClient postClient = client ;
249+ // We manually do compression for request, use zstd
250+ if (isHttp2CompressionEnabled && postBytes .length >= payloadSizeThreshold ) {
251+ try {
252+ ByteArrayOutputStream baos = new ByteArrayOutputStream (postBytes .length / 2 );
253+ try (ZstdOutputStreamNoFinalizer zstdOutput = new ZstdOutputStreamNoFinalizer (baos )) {
254+ zstdOutput .write (postBytes );
255+ }
256+
257+ byte [] compressedBytes = baos .toByteArray ();
258+ double compressionRatio = (double ) (postBytes .length - compressedBytes .length ) / postBytes .length ;
259+ if (compressionRatio >= compressionSavingThreshold ) {
260+ bodyToSend = compressedBytes ;
261+ postClient = client .headers (h -> h .set (CONTENT_ENCODING_HEADER_NAME .toString (), "zstd" ));
262+ }
263+ }
264+ catch (IOException e ) {
265+ log .error (e , "Fail to compress POST request body" );
266+ onError (listenableFuture , e );
267+ disposable = () -> {};
268+ break ;
269+ }
270+ }
271+
272+ disposable = postClient .post ()
227273 .uri (uri )
228- .send (ByteBufFlux .fromInbound (Mono .just (postBytes )))
274+ .send (ByteBufFlux .fromInbound (Mono .just (bodyToSend )))
229275 .responseSingle ((response , bytes ) -> bytes .asInputStream ().zipWith (Mono .just (response )))
230276 // Request timeout
231277 .timeout (java .time .Duration .of (requestTimeout .toMillis (), MILLIS ))
@@ -303,6 +349,7 @@ public void onSuccess(ResponseHandler responseHandler, InputStream inputStream,
303349 }
304350
305351 long contentLength = 0 ;
352+ String contentEncoding = null ;
306353 // Iterate over the headers
307354 for (String name : headers .names ()) {
308355 if (name .equalsIgnoreCase (CONTENT_LENGTH_HEADER_NAME .toString ())) {
@@ -313,6 +360,9 @@ public void onSuccess(ResponseHandler responseHandler, InputStream inputStream,
313360 else if (name .equalsIgnoreCase (CONTENT_TYPE_HEADER_NAME .toString ())) {
314361 responseHeaders .put (CONTENT_TYPE_HEADER_NAME , headers .get (name ));
315362 }
363+ else if (name .equalsIgnoreCase (CONTENT_ENCODING_HEADER_NAME .toString ())) {
364+ contentEncoding = headers .get (name );
365+ }
316366 else {
317367 responseHeaders .put (HeaderName .of (name ), headers .get (name ));
318368 }
@@ -323,7 +373,21 @@ else if (name.equalsIgnoreCase(CONTENT_TYPE_HEADER_NAME.toString())) {
323373 return ;
324374 }
325375
376+ final InputStream [] streamHolder = new InputStream [1 ];
377+ streamHolder [0 ] = inputStream ;
326378 try {
379+ if (contentEncoding != null && !contentEncoding .equalsIgnoreCase ("identity" )) {
380+ if (contentEncoding .equalsIgnoreCase ("zstd" )) {
381+ streamHolder [0 ] = new ZstdInputStream (inputStream );
382+ }
383+ else if (contentEncoding .equalsIgnoreCase ("gzip" )) {
384+ streamHolder [0 ] = new GZIPInputStream (inputStream );
385+ }
386+ else {
387+ throw new RuntimeException (format ("Unsupported Content-Encoding: %s. Supported: zstd, gzip." , contentEncoding ));
388+ }
389+ }
390+
327391 long finalContentLength = contentLength ;
328392 Object a = responseHandler .handle (null , new Response ()
329393 {
@@ -349,19 +413,19 @@ public long getBytesRead()
349413 public InputStream getInputStream ()
350414 throws IOException
351415 {
352- return inputStream ;
416+ return streamHolder [ 0 ] ;
353417 }
354418 });
355419 // closing it here to prevent memory leak of bytebuf
356- inputStream .close ();
420+ streamHolder [ 0 ] .close ();
357421 listenableFuture .set (a );
358422 }
359423 catch (Exception e ) {
360424 listenableFuture .setException (e );
361425 }
362426 finally {
363427 try {
364- inputStream .close ();
428+ streamHolder [ 0 ] .close ();
365429 }
366430 catch (IOException e ) {
367431 log .warn (e , "Failed to close input stream" );
0 commit comments