Skip to content

Commit b27c6de

Browse files
committed
feat: Add suppport to compress the task update request header and body and taskinfo response
1 parent 84af5a1 commit b27c6de

File tree

7 files changed

+239
-12
lines changed

7 files changed

+239
-12
lines changed

presto-main/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,11 @@
475475
<artifactId>postgresql</artifactId>
476476
<scope>test</scope>
477477
</dependency>
478+
479+
<dependency>
480+
<groupId>com.github.luben</groupId>
481+
<artifactId>zstd-jni</artifactId>
482+
</dependency>
478483
</dependencies>
479484

480485
<build>

presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
import com.facebook.presto.resourcemanager.ResourceManagerConfig;
153153
import com.facebook.presto.resourcemanager.ResourceManagerInconsistentException;
154154
import com.facebook.presto.resourcemanager.ResourceManagerResourceGroupService;
155+
import com.facebook.presto.server.remotetask.DecompressionFilter;
155156
import com.facebook.presto.server.remotetask.HttpLocationFactory;
156157
import com.facebook.presto.server.remotetask.ReactorNettyHttpClientConfig;
157158
import com.facebook.presto.server.thrift.FixedAddressSelector;
@@ -437,6 +438,10 @@ else if (serverConfig.isCoordinator()) {
437438
// task execution
438439
jaxrsBinder(binder).bind(TaskResource.class);
439440
jaxrsBinder(binder).bind(ThriftTaskUpdateRequestBodyReader.class);
441+
install(installModuleIf(
442+
ReactorNettyHttpClientConfig.class,
443+
ReactorNettyHttpClientConfig::isHttp2CompressionEnabled,
444+
moduleBinder -> jaxrsBinder(moduleBinder).bind(DecompressionFilter.class)));
440445

441446
newExporter(binder).export(TaskResource.class).withGeneratedName();
442447
jaxrsBinder(binder).bind(TaskExecutorResource.class);
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.server.remotetask;
15+
16+
import com.facebook.airlift.log.Logger;
17+
import com.facebook.presto.spi.PrestoException;
18+
import com.github.luben.zstd.ZstdInputStream;
19+
import jakarta.annotation.Priority;
20+
import jakarta.ws.rs.Priorities;
21+
import jakarta.ws.rs.container.ContainerRequestContext;
22+
import jakarta.ws.rs.container.ContainerRequestFilter;
23+
import jakarta.ws.rs.ext.Provider;
24+
25+
import java.io.IOException;
26+
import java.io.InputStream;
27+
28+
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
29+
import static java.lang.String.format;
30+
31+
@Provider
32+
@Priority(Priorities.ENTITY_CODER)
33+
public class DecompressionFilter
34+
implements ContainerRequestFilter
35+
{
36+
private static final Logger log = Logger.get(DecompressionFilter.class);
37+
38+
@Override
39+
public void filter(ContainerRequestContext containerRequestContext)
40+
throws IOException
41+
{
42+
String contentEncoding = containerRequestContext.getHeaderString("Content-Encoding");
43+
44+
if (contentEncoding != null && !contentEncoding.equalsIgnoreCase("identity")) {
45+
InputStream originalStream = containerRequestContext.getEntityStream();
46+
InputStream decompressedStream;
47+
48+
if (contentEncoding.equalsIgnoreCase("zstd")) {
49+
decompressedStream = new ZstdInputStream(originalStream);
50+
}
51+
else {
52+
throw new PrestoException(NOT_SUPPORTED, format("Unsupported Content-Encoding: '%s'. Only zstd compression is supported.", contentEncoding));
53+
}
54+
55+
containerRequestContext.setEntityStream(decompressedStream);
56+
containerRequestContext.getHeaders().remove("Content-Encoding");
57+
}
58+
}
59+
}

presto-main/src/main/java/com/facebook/presto/server/remotetask/ReactorNettyHttpClient.java

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@
2121
import com.facebook.airlift.http.client.StaticBodyGenerator;
2222
import com.facebook.airlift.log.Logger;
2323
import com.facebook.airlift.units.Duration;
24+
import com.github.luben.zstd.ZstdInputStream;
25+
import com.github.luben.zstd.ZstdOutputStreamNoFinalizer;
2426
import com.google.common.base.Splitter;
2527
import com.google.common.collect.ArrayListMultimap;
2628
import com.google.common.collect.ListMultimap;
2729
import com.google.common.util.concurrent.SettableFuture;
2830
import com.google.inject.Inject;
2931
import io.netty.channel.ChannelOption;
32+
import io.netty.channel.WriteBufferWaterMark;
3033
import io.netty.channel.epoll.Epoll;
3134
import io.netty.handler.codec.http.HttpHeaders;
3235
import io.netty.handler.ssl.ApplicationProtocolConfig;
@@ -44,6 +47,7 @@
4447
import reactor.netty.resources.ConnectionProvider;
4548
import reactor.netty.resources.LoopResources;
4649

50+
import java.io.ByteArrayOutputStream;
4751
import java.io.Closeable;
4852
import java.io.File;
4953
import java.io.IOException;
@@ -62,6 +66,7 @@
6266
import java.util.concurrent.TimeUnit;
6367
import java.util.concurrent.TimeoutException;
6468
import java.util.function.Function;
69+
import java.util.zip.GZIPInputStream;
6570

6671
import static com.facebook.airlift.security.pem.PemReader.loadPrivateKey;
6772
import static com.facebook.airlift.security.pem.PemReader.readCertificateChain;
@@ -84,17 +89,25 @@ public class ReactorNettyHttpClient
8489
private static final Logger log = Logger.get(ReactorNettyHttpClient.class);
8590
private static final HeaderName CONTENT_TYPE_HEADER_NAME = HeaderName.of("Content-Type");
8691
private static final HeaderName CONTENT_LENGTH_HEADER_NAME = HeaderName.of("Content-Length");
92+
private static final HeaderName CONTENT_ENCODING_HEADER_NAME = HeaderName.of("Content-Encoding");
93+
private static final HeaderName ACCEPT_ENCODING_HEADER_NAME = HeaderName.of("Accept-Encoding");
8794

8895
private final Duration requestTimeout;
8996
private HttpClient httpClient;
9097
private final HttpClientConnectionPoolStats connectionPoolStats;
9198
private final HttpClientStats httpClientStats;
99+
private final boolean isHttp2CompressionEnabled;
100+
private final int payloadSizeThreshold;
101+
private final double compressionSavingThreshold;
92102

93103
@Inject
94104
public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientConnectionPoolStats connectionPoolStats, HttpClientStats httpClientStats)
95105
{
96106
this.connectionPoolStats = connectionPoolStats;
97107
this.httpClientStats = httpClientStats;
108+
this.isHttp2CompressionEnabled = config.isHttp2CompressionEnabled();
109+
this.payloadSizeThreshold = config.getPayloadSizeThreshold();
110+
this.compressionSavingThreshold = config.getCompressionSavingThreshold();
98111
SslContext sslContext = null;
99112
if (config.isHttpsEnabled()) {
100113
try {
@@ -114,11 +127,11 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
114127
if (os.toLowerCase(Locale.ENGLISH).contains("linux")) {
115128
// Make sure Open ssl is available for linux deployments
116129
if (!OpenSsl.isAvailable()) {
117-
throw new UnsupportedOperationException(format("OpenSsl is not unavailable. Stacktrace: %s", Arrays.toString(OpenSsl.unavailabilityCause().getStackTrace()).replace(',', '\n')));
130+
throw new UnsupportedOperationException(format("OpenSsl is not available. Stacktrace: %s", Arrays.toString(OpenSsl.unavailabilityCause().getStackTrace()).replace(',', '\n')));
118131
}
119132
// Make sure epoll threads are used for linux deployments
120133
if (!Epoll.isAvailable()) {
121-
throw new UnsupportedOperationException(format("Epoll is not unavailable. Stacktrace: %s", Arrays.toString(Epoll.unavailabilityCause().getStackTrace()).replace(',', '\n')));
134+
throw new UnsupportedOperationException(format("Epoll is not available. Stacktrace: %s", Arrays.toString(Epoll.unavailabilityCause().getStackTrace()).replace(',', '\n')));
122135
}
123136
}
124137

@@ -166,9 +179,10 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
166179

167180
// Create HTTP/2 client
168181
SslContext finalSslContext = sslContext;
182+
169183
this.httpClient = HttpClient
170-
// The custom pool is wrapped with a HttpConnectionProvider over here
171-
.create(pool)
184+
.create(pool) // The custom pool is wrapped with a HttpConnectionProvider over here
185+
.compress(false) // we will enable response compression manually
172186
.protocol(HttpProtocol.H2, HttpProtocol.HTTP11)
173187
.runOn(loopResources, true)
174188
.http2Settings(settings -> {
@@ -179,6 +193,9 @@ public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientCon
179193
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) config.getConnectTimeout().getValue())
180194
.option(ChannelOption.SO_KEEPALIVE, true)
181195
.option(ChannelOption.TCP_NODELAY, true)
196+
.option(ChannelOption.SO_SNDBUF, config.getTcpBufferSize())
197+
.option(ChannelOption.SO_RCVBUF, config.getTcpBufferSize())
198+
.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(config.getWriteBufferWaterMarkLow(), config.getWriteBufferWaterMarkHigh()))
182199
// Track HTTP client metrics
183200
.metrics(true, () -> httpClientStats, Function.identity());
184201

@@ -208,6 +225,10 @@ public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request airli
208225
for (Map.Entry<String, String> entry : airliftRequest.getHeaders().entries()) {
209226
hdr.set(entry.getKey(), entry.getValue());
210227
}
228+
229+
if (isHttp2CompressionEnabled) {
230+
hdr.set(ACCEPT_ENCODING_HEADER_NAME.toString(), "zstd, gzip");
231+
}
211232
});
212233

213234
URI uri = airliftRequest.getUri();
@@ -223,9 +244,34 @@ public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request airli
223244
break;
224245
case "POST":
225246
byte[] postBytes = ((StaticBodyGenerator) airliftRequest.getBodyGenerator()).getBody();
226-
disposable = client.post()
247+
byte[] bodyToSend = postBytes;
248+
HttpClient postClient = client;
249+
// We manually do compression for request, use zstd
250+
if (isHttp2CompressionEnabled && postBytes.length >= payloadSizeThreshold) {
251+
try {
252+
ByteArrayOutputStream baos = new ByteArrayOutputStream(postBytes.length / 2);
253+
try (ZstdOutputStreamNoFinalizer zstdOutput = new ZstdOutputStreamNoFinalizer(baos)) {
254+
zstdOutput.write(postBytes);
255+
}
256+
257+
byte[] compressedBytes = baos.toByteArray();
258+
double compressionRatio = (double) (postBytes.length - compressedBytes.length) / postBytes.length;
259+
if (compressionRatio >= compressionSavingThreshold) {
260+
bodyToSend = compressedBytes;
261+
postClient = client.headers(h -> h.set(CONTENT_ENCODING_HEADER_NAME.toString(), "zstd"));
262+
}
263+
}
264+
catch (IOException e) {
265+
log.error(e, "Fail to compress POST request body");
266+
onError(listenableFuture, e);
267+
disposable = () -> {};
268+
break;
269+
}
270+
}
271+
272+
disposable = postClient.post()
227273
.uri(uri)
228-
.send(ByteBufFlux.fromInbound(Mono.just(postBytes)))
274+
.send(ByteBufFlux.fromInbound(Mono.just(bodyToSend)))
229275
.responseSingle((response, bytes) -> bytes.asInputStream().zipWith(Mono.just(response)))
230276
// Request timeout
231277
.timeout(java.time.Duration.of(requestTimeout.toMillis(), MILLIS))
@@ -303,6 +349,7 @@ public void onSuccess(ResponseHandler responseHandler, InputStream inputStream,
303349
}
304350

305351
long contentLength = 0;
352+
String contentEncoding = null;
306353
// Iterate over the headers
307354
for (String name : headers.names()) {
308355
if (name.equalsIgnoreCase(CONTENT_LENGTH_HEADER_NAME.toString())) {
@@ -313,6 +360,9 @@ public void onSuccess(ResponseHandler responseHandler, InputStream inputStream,
313360
else if (name.equalsIgnoreCase(CONTENT_TYPE_HEADER_NAME.toString())) {
314361
responseHeaders.put(CONTENT_TYPE_HEADER_NAME, headers.get(name));
315362
}
363+
else if (name.equalsIgnoreCase(CONTENT_ENCODING_HEADER_NAME.toString())) {
364+
contentEncoding = headers.get(name);
365+
}
316366
else {
317367
responseHeaders.put(HeaderName.of(name), headers.get(name));
318368
}
@@ -323,7 +373,21 @@ else if (name.equalsIgnoreCase(CONTENT_TYPE_HEADER_NAME.toString())) {
323373
return;
324374
}
325375

376+
final InputStream[] streamHolder = new InputStream[1];
377+
streamHolder[0] = inputStream;
326378
try {
379+
if (contentEncoding != null && !contentEncoding.equalsIgnoreCase("identity")) {
380+
if (contentEncoding.equalsIgnoreCase("zstd")) {
381+
streamHolder[0] = new ZstdInputStream(inputStream);
382+
}
383+
else if (contentEncoding.equalsIgnoreCase("gzip")) {
384+
streamHolder[0] = new GZIPInputStream(inputStream);
385+
}
386+
else {
387+
throw new RuntimeException(format("Unsupported Content-Encoding: %s. Supported: zstd, gzip.", contentEncoding));
388+
}
389+
}
390+
327391
long finalContentLength = contentLength;
328392
Object a = responseHandler.handle(null, new Response()
329393
{
@@ -349,19 +413,19 @@ public long getBytesRead()
349413
public InputStream getInputStream()
350414
throws IOException
351415
{
352-
return inputStream;
416+
return streamHolder[0];
353417
}
354418
});
355419
// closing it here to prevent memory leak of bytebuf
356-
inputStream.close();
420+
streamHolder[0].close();
357421
listenableFuture.set(a);
358422
}
359423
catch (Exception e) {
360424
listenableFuture.setException(e);
361425
}
362426
finally {
363427
try {
364-
inputStream.close();
428+
streamHolder[0].close();
365429
}
366430
catch (IOException e) {
367431
log.warn(e, "Failed to close input stream");

0 commit comments

Comments
 (0)