Skip to content

Commit 1a92500

Browse files
committed
Always enable expect 100-continue on cross region
When making cross region streaming calls on S3, we need to have the 'Expect: 100-continue' header set to avoid I/O errors from sending the body unconditionally to S3 resulting in responding with a 3xx and termining the connection while the client is still writing the object.
1 parent 43865ad commit 1a92500

4 files changed

Lines changed: 143 additions & 0 deletions

File tree

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "bugfix",
3+
"category": "Amazon S3",
4+
"contributor": "",
5+
"description": "Always set 'Expect: 100-continue' when usingi PUT operations across operations; this enables the correct redirect behavior when the initial request goes to an incorrect region."
6+
}

services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Configuration.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ public boolean chunkedEncodingEnabled() {
248248
* By default, the SDK sends the {@code Expect: 100-continue} header for these operations, allowing the server to
249249
* reject the request before the client sends the full payload. Setting this to {@code false} disables this behavior.
250250
* <p>
251+
* If enabling cross region access on the client, this setting has no effect (and by extension neither does
252+
* {@link #expectContinueThresholdInBytes()} as the client needs to set this header for correct redirect behavior.
253+
* <p>
251254
* <b>Note:</b> When using the {@code ApacheHttpClient} (Apache 4), the Apache 4 client also independently adds the
252255
* {@code Expect: 100-continue} header by default via its own {@code expectContinueEnabled} setting. To fully
253256
* suppress the header on the wire, you must also disable it on the Apache4 HTTP client builder using

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/StreamingRequestInterceptor.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@
2222
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
2323
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
2424
import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute;
25+
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
2526
import software.amazon.awssdk.http.SdkHttpRequest;
2627
import software.amazon.awssdk.services.s3.S3Configuration;
28+
import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams;
2729
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
2830
import software.amazon.awssdk.services.s3.model.UploadPartRequest;
31+
import software.amazon.awssdk.utils.AttributeMap;
2932

3033
/**
3134
* Interceptor to add an 'Expect: 100-continue' header to the HTTP Request if it represents a PUT Object or Upload Part
@@ -55,6 +58,13 @@ private boolean shouldAddExpectContinueHeader(Context.ModifyHttpRequest context,
5558
return false;
5659
}
5760

61+
// The header is necessary for cross region PUT because sending the body unconditionally to the wrong region where S3
62+
// will respond with a 3xx and close the connection will cause I/O errors rather than resulting in the client retrying
63+
// based on the region given in the 3xx response.
64+
if (isCrossRegionAccessEnabled(executionAttributes)) {
65+
return true;
66+
}
67+
5868
S3Configuration s3Config = getS3Configuration(executionAttributes);
5969

6070
if (s3Config != null && !s3Config.expectContinueEnabled()) {
@@ -88,4 +98,12 @@ private Optional<String> getContentLengthHeader(SdkHttpRequest httpRequest) {
8898
? decodedLength
8999
: httpRequest.firstMatchingHeader(CONTENT_LENGTH_HEADER);
90100
}
101+
102+
private boolean isCrossRegionAccessEnabled(ExecutionAttributes executionAttributes) {
103+
Optional<AttributeMap> ctxParams = executionAttributes.getOptionalAttribute(
104+
SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS);
105+
106+
return ctxParams.map(p -> Boolean.TRUE.equals(p.get(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED)))
107+
.orElse(false);
108+
}
91109
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.s3.internal.crossregion;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
20+
import static org.mockito.ArgumentMatchers.any;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.verify;
23+
import static org.mockito.Mockito.when;
24+
25+
import java.util.concurrent.CompletableFuture;
26+
import org.junit.jupiter.api.BeforeEach;
27+
import org.junit.jupiter.params.ParameterizedTest;
28+
import org.junit.jupiter.params.provider.CsvSource;
29+
import org.mockito.ArgumentCaptor;
30+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
31+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
32+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
33+
import software.amazon.awssdk.core.async.AsyncRequestBody;
34+
import software.amazon.awssdk.core.sync.RequestBody;
35+
import software.amazon.awssdk.http.HttpExecuteRequest;
36+
import software.amazon.awssdk.http.SdkHttpClient;
37+
import software.amazon.awssdk.http.SdkHttpRequest;
38+
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
39+
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
40+
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
41+
import software.amazon.awssdk.regions.Region;
42+
import software.amazon.awssdk.services.s3.S3AsyncClient;
43+
import software.amazon.awssdk.services.s3.S3Client;
44+
45+
public class Expect100ContinueTest {
46+
private static final AwsCredentialsProvider TEST_CREDS = StaticCredentialsProvider.create(
47+
AwsBasicCredentials.create("akid", "skid"));
48+
private SdkHttpClient mockSyncHttp;
49+
private SdkAsyncHttpClient mockAsyncHttp;
50+
51+
@BeforeEach
52+
void setup() {
53+
mockSyncHttp = mock(SdkHttpClient.class);
54+
when(mockSyncHttp.prepareRequest(any(HttpExecuteRequest.class))).thenThrow(new RuntimeException("expect 100 continue"));
55+
56+
mockAsyncHttp = mock(SdkAsyncHttpClient.class);
57+
CompletableFuture cf = new CompletableFuture();
58+
cf.completeExceptionally(new RuntimeException("expect 100 continue"));
59+
when(mockAsyncHttp.execute(any(AsyncExecuteRequest.class))).thenAnswer(i -> {
60+
AsyncExecuteRequest req = i.getArgument(0);
61+
SdkAsyncHttpResponseHandler handler = req.responseHandler();
62+
handler.onError(new RuntimeException("expect 100 continue"));
63+
return CompletableFuture.completedFuture(null);
64+
});
65+
}
66+
67+
@ParameterizedTest(name = "expect 100-continue enabled = {0}")
68+
@CsvSource({"true", "false"})
69+
void sync_alwaysAdds(boolean enabled) {
70+
71+
try (S3Client s3 = S3Client.builder()
72+
.httpClient(mockSyncHttp)
73+
.region(Region.US_WEST_2)
74+
.credentialsProvider(TEST_CREDS)
75+
.crossRegionAccessEnabled(true)
76+
.serviceConfiguration(o -> o.expectContinueEnabled(enabled)
77+
.expectContinueThresholdInBytes(1L))
78+
.build()) {
79+
RequestBody requestBody = RequestBody.fromBytes(new byte[16]);
80+
assertThatThrownBy(() -> s3.putObject(o -> o.bucket("bucket").key("key"), requestBody))
81+
.hasMessage("expect 100 continue");
82+
83+
ArgumentCaptor<HttpExecuteRequest> requestCaptor = ArgumentCaptor.forClass(HttpExecuteRequest.class);
84+
85+
verify(mockSyncHttp).prepareRequest(requestCaptor.capture());
86+
assertHasExpect100Continue(requestCaptor.getValue().httpRequest());
87+
}
88+
}
89+
90+
@ParameterizedTest(name = "expect 100-continue enabled = {0}")
91+
@CsvSource({"true", "false"})
92+
void async_alwaysAdds(boolean enabled) {
93+
try (S3AsyncClient s3 = S3AsyncClient.builder()
94+
.httpClient(mockAsyncHttp)
95+
.region(Region.US_WEST_2)
96+
.credentialsProvider(TEST_CREDS)
97+
.crossRegionAccessEnabled(true)
98+
.serviceConfiguration(o -> o.expectContinueEnabled(enabled)
99+
.expectContinueThresholdInBytes(1L))
100+
.build()) {
101+
AsyncRequestBody requestBody = AsyncRequestBody.fromBytes(new byte[16]);
102+
assertThatThrownBy(s3.putObject(o -> o.bucket("bucket").key("key"), requestBody)::join)
103+
.hasMessageContaining("expect 100 continue");
104+
105+
ArgumentCaptor<AsyncExecuteRequest> requestCaptor = ArgumentCaptor.forClass(AsyncExecuteRequest.class);
106+
107+
verify(mockAsyncHttp).execute(requestCaptor.capture());
108+
assertHasExpect100Continue(requestCaptor.getValue().request());
109+
}
110+
}
111+
112+
private static void assertHasExpect100Continue(SdkHttpRequest httpRequest) {
113+
assertThat(httpRequest.firstMatchingHeader("Expect"))
114+
.hasValueSatisfying(v -> assertThat(v).isEqualToIgnoringCase("100-continue"));
115+
}
116+
}

0 commit comments

Comments
 (0)