2121import com .facebook .airlift .http .client .StaticBodyGenerator ;
2222import com .facebook .airlift .log .Logger ;
2323import com .facebook .airlift .units .Duration ;
24+ import com .github .luben .zstd .ZstdInputStream ;
25+ import com .github .luben .zstd .ZstdOutputStreamNoFinalizer ;
2426import com .google .common .base .Splitter ;
2527import com .google .common .collect .ArrayListMultimap ;
2628import com .google .common .collect .ListMultimap ;
2729import com .google .common .util .concurrent .SettableFuture ;
2830import com .google .inject .Inject ;
2931import io .netty .channel .ChannelOption ;
32+ import io .netty .channel .WriteBufferWaterMark ;
3033import io .netty .channel .epoll .Epoll ;
3134import io .netty .handler .codec .http .HttpHeaders ;
3235import io .netty .handler .ssl .ApplicationProtocolConfig ;
4447import reactor .netty .resources .ConnectionProvider ;
4548import reactor .netty .resources .LoopResources ;
4649
50+ import java .io .ByteArrayOutputStream ;
4751import java .io .Closeable ;
4852import java .io .File ;
4953import java .io .IOException ;
6266import java .util .concurrent .TimeUnit ;
6367import java .util .concurrent .TimeoutException ;
6468import java .util .function .Function ;
69+ import java .util .zip .GZIPInputStream ;
6570
6671import static com .facebook .airlift .security .pem .PemReader .loadPrivateKey ;
6772import static com .facebook .airlift .security .pem .PemReader .readCertificateChain ;
@@ -84,17 +89,25 @@ public class ReactorNettyHttpClient
8489 private static final Logger log = Logger .get (ReactorNettyHttpClient .class );
8590 private static final HeaderName CONTENT_TYPE_HEADER_NAME = HeaderName .of ("Content-Type" );
8691 private static final HeaderName CONTENT_LENGTH_HEADER_NAME = HeaderName .of ("Content-Length" );
92+ private static final HeaderName CONTENT_ENCODING_HEADER_NAME = HeaderName .of ("Content-Encoding" );
93+ private static final HeaderName ACCEPT_ENCODING_HEADER_NAME = HeaderName .of ("Accept-Encoding" );
8794
8895 private final Duration requestTimeout ;
8996 private HttpClient httpClient ;
9097 private final HttpClientConnectionPoolStats connectionPoolStats ;
9198 private final HttpClientStats httpClientStats ;
99+ private final boolean isHttp2CompressionEnabled ;
100+ private final int payloadSizeThreshold ;
101+ private final double compressionSavingThreshold ;
92102
93103 @ Inject
94104 public ReactorNettyHttpClient (ReactorNettyHttpClientConfig config , HttpClientConnectionPoolStats connectionPoolStats , HttpClientStats httpClientStats )
95105 {
96106 this .connectionPoolStats = connectionPoolStats ;
97107 this .httpClientStats = httpClientStats ;
108+ this .isHttp2CompressionEnabled = config .isHttp2CompressionEnabled ();
109+ this .payloadSizeThreshold = config .getPayloadSizeThreshold ();
110+ this .compressionSavingThreshold = config .getCompressionSavingThreshold ();
98111 SslContext sslContext = null ;
99112 if (config .isHttpsEnabled ()) {
100113 try {
@@ -114,11 +127,11 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
114127 if (os .toLowerCase (Locale .ENGLISH ).contains ("linux" )) {
115128 // Make sure Open ssl is available for linux deployments
116129 if (!OpenSsl .isAvailable ()) {
117- throw new UnsupportedOperationException (format ("OpenSsl is not unavailable . Stacktrace: %s" , Arrays .toString (OpenSsl .unavailabilityCause ().getStackTrace ()).replace (',' , '\n' )));
130+ throw new UnsupportedOperationException (format ("OpenSsl is not available . Stacktrace: %s" , Arrays .toString (OpenSsl .unavailabilityCause ().getStackTrace ()).replace (',' , '\n' )));
118131 }
119132 // Make sure epoll threads are used for linux deployments
120133 if (!Epoll .isAvailable ()) {
121- throw new UnsupportedOperationException (format ("Epoll is not unavailable . Stacktrace: %s" , Arrays .toString (Epoll .unavailabilityCause ().getStackTrace ()).replace (',' , '\n' )));
134+ throw new UnsupportedOperationException (format ("Epoll is not available . Stacktrace: %s" , Arrays .toString (Epoll .unavailabilityCause ().getStackTrace ()).replace (',' , '\n' )));
122135 }
123136 }
124137
@@ -166,9 +179,10 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
166179
167180 // Create HTTP/2 client
168181 SslContext finalSslContext = sslContext ;
182+
169183 this .httpClient = HttpClient
170- // The custom pool is wrapped with a HttpConnectionProvider over here
171- .create ( pool )
184+ . create ( pool ) // The custom pool is wrapped with a HttpConnectionProvider over here
185+ .compress ( false ) // we will enable response compression manually
172186 .protocol (HttpProtocol .H2 , HttpProtocol .HTTP11 )
173187 .runOn (loopResources , true )
174188 .http2Settings (settings -> {
@@ -179,6 +193,9 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
179193 .option (ChannelOption .CONNECT_TIMEOUT_MILLIS , (int ) config .getConnectTimeout ().getValue ())
180194 .option (ChannelOption .SO_KEEPALIVE , true )
181195 .option (ChannelOption .TCP_NODELAY , true )
196+ .option (ChannelOption .SO_SNDBUF , config .getTcpBufferSize ())
197+ .option (ChannelOption .SO_RCVBUF , config .getTcpBufferSize ())
198+ .option (ChannelOption .WRITE_BUFFER_WATER_MARK , new WriteBufferWaterMark (config .getWriteBufferWaterMarkLow (), config .getWriteBufferWaterMarkHigh ()))
182199 // Track HTTP client metrics
183200 .metrics (true , () -> httpClientStats , Function .identity ());
184201
@@ -208,6 +225,10 @@ public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request airli
208225 for (Map .Entry <String , String > entry : airliftRequest .getHeaders ().entries ()) {
209226 hdr .set (entry .getKey (), entry .getValue ());
210227 }
228+
229+ if (isHttp2CompressionEnabled ) {
230+ hdr .set (ACCEPT_ENCODING_HEADER_NAME .toString (), "zstd, gzip" );
231+ }
211232 });
212233
213234 URI uri = airliftRequest .getUri ();
@@ -223,9 +244,33 @@ public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request airli
223244 break ;
224245 case "POST" :
225246 byte [] postBytes = ((StaticBodyGenerator ) airliftRequest .getBodyGenerator ()).getBody ();
226- disposable = client .post ()
247+ byte [] bodyToSend = postBytes ;
248+ HttpClient postClient = client ;
249+ // We manually do compression for request, use zstd
250+ if (isHttp2CompressionEnabled && postBytes .length >= payloadSizeThreshold ) {
251+ try {
252+ ByteArrayOutputStream baos = new ByteArrayOutputStream (postBytes .length / 2 );
253+ try (ZstdOutputStreamNoFinalizer zstdOutput = new ZstdOutputStreamNoFinalizer (baos )) {
254+ zstdOutput .write (postBytes );
255+ }
256+
257+ byte [] compressedBytes = baos .toByteArray ();
258+ double compressionRatio = (double ) (postBytes .length - compressedBytes .length ) / postBytes .length ;
259+ if (compressionRatio >= compressionSavingThreshold ) {
260+ bodyToSend = compressedBytes ;
261+ postClient = client .headers (h -> h .set (CONTENT_ENCODING_HEADER_NAME .toString (), "zstd" ));
262+ }
263+ }
264+ catch (IOException e ) {
265+ onError (listenableFuture , e );
266+ disposable = () -> {};
267+ break ;
268+ }
269+ }
270+
271+ disposable = postClient .post ()
227272 .uri (uri )
228- .send (ByteBufFlux .fromInbound (Mono .just (postBytes )))
273+ .send (ByteBufFlux .fromInbound (Mono .just (bodyToSend )))
229274 .responseSingle ((response , bytes ) -> bytes .asInputStream ().zipWith (Mono .just (response )))
230275 // Request timeout
231276 .timeout (java .time .Duration .of (requestTimeout .toMillis (), MILLIS ))
@@ -303,6 +348,7 @@ public void onSuccess(ResponseHandler responseHandler, InputStream inputStream,
303348 }
304349
305350 long contentLength = 0 ;
351+ String contentEncoding = null ;
306352 // Iterate over the headers
307353 for (String name : headers .names ()) {
308354 if (name .equalsIgnoreCase (CONTENT_LENGTH_HEADER_NAME .toString ())) {
@@ -313,6 +359,9 @@ public void onSuccess(ResponseHandler responseHandler, InputStream inputStream,
313359 else if (name .equalsIgnoreCase (CONTENT_TYPE_HEADER_NAME .toString ())) {
314360 responseHeaders .put (CONTENT_TYPE_HEADER_NAME , headers .get (name ));
315361 }
362+ else if (name .equalsIgnoreCase (CONTENT_ENCODING_HEADER_NAME .toString ())) {
363+ contentEncoding = headers .get (name );
364+ }
316365 else {
317366 responseHeaders .put (HeaderName .of (name ), headers .get (name ));
318367 }
@@ -323,7 +372,21 @@ else if (name.equalsIgnoreCase(CONTENT_TYPE_HEADER_NAME.toString())) {
323372 return ;
324373 }
325374
375+ final InputStream [] streamHolder = new InputStream [1 ];
376+ streamHolder [0 ] = inputStream ;
326377 try {
378+ if (contentEncoding != null && !contentEncoding .equalsIgnoreCase ("identity" )) {
379+ if (contentEncoding .equalsIgnoreCase ("zstd" )) {
380+ streamHolder [0 ] = new ZstdInputStream (inputStream );
381+ }
382+ else if (contentEncoding .equalsIgnoreCase ("gzip" )) {
383+ streamHolder [0 ] = new GZIPInputStream (inputStream );
384+ }
385+ else {
386+ throw new RuntimeException (format ("Unsupported Content-Encoding: %s. Supported: zstd, gzip." , contentEncoding ));
387+ }
388+ }
389+
327390 long finalContentLength = contentLength ;
328391 Object a = responseHandler .handle (null , new Response ()
329392 {
@@ -349,19 +412,21 @@ public long getBytesRead()
349412 public InputStream getInputStream ()
350413 throws IOException
351414 {
352- return inputStream ;
415+ return streamHolder [ 0 ] ;
353416 }
354417 });
355418 // closing it here to prevent memory leak of bytebuf
356- inputStream .close ();
419+ if (streamHolder [0 ] != null ) {
420+ streamHolder [0 ].close ();
421+ }
357422 listenableFuture .set (a );
358423 }
359424 catch (Exception e ) {
360425 listenableFuture .setException (e );
361426 }
362427 finally {
363428 try {
364- inputStream .close ();
429+ streamHolder [ 0 ] .close ();
365430 }
366431 catch (IOException e ) {
367432 log .warn (e , "Failed to close input stream" );
0 commit comments