Skip to content

Commit 7dc67c8

Browse files
authored
Fix @httpChecksum in the orchestrator implementation (#2785)
This PR fixes codegen for the Smithy `@httpChecksum` trait when generating in orchestrator mode, and also fixes an issue in the unit tests where the wrong body was being tested to be retryable. The request body should be retryable rather than the response body. ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._
1 parent 9630892 commit 7dc67c8

File tree

10 files changed

+770
-104
lines changed

10 files changed

+770
-104
lines changed

aws/rust-runtime/aws-inlineable/src/http_body_checksum.rs renamed to aws/rust-runtime/aws-inlineable/src/http_body_checksum_middleware.rs

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -248,8 +248,7 @@ fn is_part_level_checksum(checksum: &str) -> bool {
248248

249249
#[cfg(test)]
250250
mod tests {
251-
use super::wrap_body_with_checksum_validator;
252-
use crate::http_body_checksum::is_part_level_checksum;
251+
use super::{is_part_level_checksum, wrap_body_with_checksum_validator};
253252
use aws_smithy_checksums::ChecksumAlgorithm;
254253
use aws_smithy_http::body::SdkBody;
255254
use aws_smithy_http::byte_stream::ByteStream;
Lines changed: 303 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,303 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#![allow(dead_code)]
7+
8+
//! Interceptor for handling Smithy `@httpChecksum` request checksumming with AWS SigV4
9+
10+
use aws_http::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions};
11+
use aws_runtime::auth::sigv4::SigV4OperationSigningConfig;
12+
use aws_sigv4::http_request::SignableBody;
13+
use aws_smithy_checksums::ChecksumAlgorithm;
14+
use aws_smithy_checksums::{body::calculate, http::HttpChecksum};
15+
use aws_smithy_http::body::{BoxBody, SdkBody};
16+
use aws_smithy_http::operation::error::BuildError;
17+
use aws_smithy_runtime_api::client::interceptors::context::Input;
18+
use aws_smithy_runtime_api::client::interceptors::{
19+
BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut, BoxError,
20+
Interceptor,
21+
};
22+
use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
23+
use http::HeaderValue;
24+
use http_body::Body;
25+
use std::{fmt, mem};
26+
27+
/// Errors related to constructing checksum-validated HTTP requests
#[derive(Debug)]
pub(crate) enum Error {
    /// Only request bodies with a known size can be checksum validated
    UnsizedRequestBody,
    /// Checksum headers cannot be inserted for a streaming body; such checksums must be
    /// sent as trailers instead
    ChecksumHeadersAreUnsupportedForStreamingBody,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant maps to a fixed message, so select the text first and write it once.
        let message = match self {
            Self::UnsizedRequestBody => {
                "Only request bodies with a known size can be checksum validated."
            }
            Self::ChecksumHeadersAreUnsupportedForStreamingBody => {
                "Checksum header insertion is only supported for non-streaming HTTP bodies. \
                 To checksum validate a streaming body, the checksums must be sent as trailers."
            }
        };
        f.write_str(message)
    }
}

// No underlying cause is wrapped, so the trait's default methods are sufficient.
impl std::error::Error for Error {}
52+
53+
#[derive(Debug)]
54+
struct RequestChecksumInterceptorState {
55+
checksum_algorithm: Option<ChecksumAlgorithm>,
56+
}
57+
impl Storable for RequestChecksumInterceptorState {
58+
type Storer = StoreReplace<Self>;
59+
}
60+
61+
/// Interceptor that holds the `algorithm_provider` callback used to decide which checksum
/// algorithm (if any) applies to a request.
pub(crate) struct RequestChecksumInterceptor<AP> {
    algorithm_provider: AP,
}

impl<AP> fmt::Debug for RequestChecksumInterceptor<AP> {
    /// Prints only the type name; `AP` is not required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("RequestChecksumInterceptor")
    }
}

impl<AP> RequestChecksumInterceptor<AP> {
    /// Creates an interceptor that stores the given `algorithm_provider`.
    pub(crate) fn new(algorithm_provider: AP) -> Self {
        RequestChecksumInterceptor { algorithm_provider }
    }
}
76+
77+
impl<AP> Interceptor for RequestChecksumInterceptor<AP>
78+
where
79+
AP: Fn(&Input) -> Result<Option<ChecksumAlgorithm>, BoxError>,
80+
{
81+
fn read_before_serialization(
82+
&self,
83+
context: &BeforeSerializationInterceptorContextRef<'_>,
84+
cfg: &mut ConfigBag,
85+
) -> Result<(), BoxError> {
86+
let checksum_algorithm = (self.algorithm_provider)(context.input())?;
87+
88+
let mut layer = Layer::new("RequestChecksumInterceptor");
89+
layer.store_put(RequestChecksumInterceptorState { checksum_algorithm });
90+
cfg.push_layer(layer);
91+
92+
Ok(())
93+
}
94+
95+
/// Calculate a checksum and modify the request to include the checksum as a header
96+
/// (for in-memory request bodies) or a trailer (for streaming request bodies).
97+
/// Streaming bodies must be sized or this will return an error.
98+
fn modify_before_retry_loop(
99+
&self,
100+
context: &mut BeforeTransmitInterceptorContextMut<'_>,
101+
cfg: &mut ConfigBag,
102+
) -> Result<(), BoxError> {
103+
let state = cfg
104+
.load::<RequestChecksumInterceptorState>()
105+
.expect("set in `read_before_serialization`");
106+
107+
if let Some(checksum_algorithm) = state.checksum_algorithm {
108+
let request = context.request_mut();
109+
add_checksum_for_request_body(request, checksum_algorithm, cfg)?;
110+
}
111+
112+
Ok(())
113+
}
114+
}
115+
116+
fn add_checksum_for_request_body(
117+
request: &mut http::request::Request<SdkBody>,
118+
checksum_algorithm: ChecksumAlgorithm,
119+
cfg: &mut ConfigBag,
120+
) -> Result<(), BoxError> {
121+
match request.body().bytes() {
122+
// Body is in-memory: read it and insert the checksum as a header.
123+
Some(data) => {
124+
tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
125+
let mut checksum = checksum_algorithm.into_impl();
126+
checksum.update(data);
127+
128+
request
129+
.headers_mut()
130+
.insert(checksum.header_name(), checksum.header_value());
131+
}
132+
// Body is streaming: wrap the body so it will emit a checksum as a trailer.
133+
None => {
134+
tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
135+
if let Some(mut signing_config) = cfg.get::<SigV4OperationSigningConfig>().cloned() {
136+
signing_config.signing_options.payload_override =
137+
Some(SignableBody::StreamingUnsignedPayloadTrailer);
138+
139+
let mut layer = Layer::new("http_body_checksum_sigv4_payload_override");
140+
layer.put(signing_config);
141+
cfg.push_layer(layer);
142+
}
143+
wrap_streaming_request_body_in_checksum_calculating_body(request, checksum_algorithm)?;
144+
}
145+
}
146+
Ok(())
147+
}
148+
149+
fn wrap_streaming_request_body_in_checksum_calculating_body(
150+
request: &mut http::request::Request<SdkBody>,
151+
checksum_algorithm: ChecksumAlgorithm,
152+
) -> Result<(), BuildError> {
153+
let original_body_size = request
154+
.body()
155+
.size_hint()
156+
.exact()
157+
.ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
158+
159+
let mut body = {
160+
let body = mem::replace(request.body_mut(), SdkBody::taken());
161+
162+
body.map(move |body| {
163+
let checksum = checksum_algorithm.into_impl();
164+
let trailer_len = HttpChecksum::size(checksum.as_ref());
165+
let body = calculate::ChecksumBody::new(body, checksum);
166+
let aws_chunked_body_options =
167+
AwsChunkedBodyOptions::new(original_body_size, vec![trailer_len]);
168+
169+
let body = AwsChunkedBody::new(body, aws_chunked_body_options);
170+
171+
SdkBody::from_dyn(BoxBody::new(body))
172+
})
173+
};
174+
175+
let encoded_content_length = body
176+
.size_hint()
177+
.exact()
178+
.ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
179+
180+
let headers = request.headers_mut();
181+
182+
headers.insert(
183+
http::header::HeaderName::from_static("x-amz-trailer"),
184+
// Convert into a `HeaderName` and then into a `HeaderValue`
185+
http::header::HeaderName::from(checksum_algorithm).into(),
186+
);
187+
188+
headers.insert(
189+
http::header::CONTENT_LENGTH,
190+
HeaderValue::from(encoded_content_length),
191+
);
192+
headers.insert(
193+
http::header::HeaderName::from_static("x-amz-decoded-content-length"),
194+
HeaderValue::from(original_body_size),
195+
);
196+
headers.insert(
197+
http::header::CONTENT_ENCODING,
198+
HeaderValue::from_str(aws_http::content_encoding::header_value::AWS_CHUNKED)
199+
.map_err(BuildError::other)
200+
.expect("\"aws-chunked\" will always be a valid HeaderValue"),
201+
);
202+
203+
mem::swap(request.body_mut(), &mut body);
204+
205+
Ok(())
206+
}
207+
208+
#[cfg(test)]
mod tests {
    // Import through `super` so the tests keep compiling if this module is renamed.
    use super::wrap_streaming_request_body_in_checksum_calculating_body;
    use aws_smithy_checksums::ChecksumAlgorithm;
    use aws_smithy_http::body::SdkBody;
    use aws_smithy_http::byte_stream::ByteStream;
    use aws_smithy_types::base64;
    use bytes::BytesMut;
    use http_body::Body;
    use tempfile::NamedTempFile;

    /// Polls `body` until it yields no more data and returns everything it produced.
    /// Panics if the body reports an error.
    async fn read_body_to_end(mut body: SdkBody) -> BytesMut {
        let mut collected = BytesMut::new();
        while let Some(chunk) = body.data().await {
            collected.extend_from_slice(&chunk.unwrap());
        }
        collected
    }

    /// A retryable in-memory request body must still be cloneable after wrapping, and the
    /// clone must produce the chunk-encoded payload followed by the checksum trailer.
    #[tokio::test]
    async fn test_checksum_body_is_retryable() {
        let input_text = "Hello world";
        let chunk_len_hex = format!("{:X}", input_text.len());
        let mut request = http::Request::builder()
            .body(SdkBody::retryable(move || SdkBody::from(input_text)))
            .unwrap();

        // ensure original SdkBody is retryable
        assert!(request.body().try_clone().is_some());

        let checksum_algorithm: ChecksumAlgorithm = "crc32".parse().unwrap();
        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm)
            .unwrap();

        // ensure wrapped SdkBody is retryable
        let body = request.body().try_clone().expect("body is retryable");

        let body_data = read_body_to_end(body).await;
        let body = std::str::from_utf8(&body_data).unwrap();
        assert_eq!(
            format!(
                "{chunk_len_hex}\r\n{input_text}\r\n0\r\nx-amz-checksum-crc32:i9aeUg==\r\n\r\n"
            ),
            body
        );
    }

    /// Same as above, but for a body streamed from a file: the wrapped body must be
    /// cloneable and must end with the last line of the file plus the checksum trailer.
    #[tokio::test]
    async fn test_checksum_body_from_file_is_retryable() {
        use std::io::Write;
        let mut file = NamedTempFile::new().unwrap();
        let checksum_algorithm: ChecksumAlgorithm = "crc32c".parse().unwrap();

        // Compute the expected checksum over exactly the bytes written to the file.
        let mut crc32c_checksum = checksum_algorithm.into_impl();
        for i in 0..10000 {
            let line = format!("This is a large file created for testing purposes {i}");
            file.as_file_mut().write_all(line.as_bytes()).unwrap();
            crc32c_checksum.update(line.as_bytes());
        }
        let crc32c_checksum = crc32c_checksum.finalize();

        let mut request = http::Request::builder()
            .body(
                ByteStream::read_from()
                    .path(&file)
                    .buffer_size(1024)
                    .build()
                    .await
                    .unwrap()
                    .into_inner(),
            )
            .unwrap();

        // ensure original SdkBody is retryable
        assert!(request.body().try_clone().is_some());

        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm)
            .unwrap();

        // ensure wrapped SdkBody is retryable
        let body = request.body().try_clone().expect("body is retryable");

        let body_data = read_body_to_end(body).await;
        let body = std::str::from_utf8(&body_data).unwrap();
        let expected_checksum = base64::encode(&crc32c_checksum);
        let expected = format!("This is a large file created for testing purposes 9999\r\n0\r\nx-amz-checksum-crc32c:{expected_checksum}\r\n\r\n");
        assert!(
            body.ends_with(&expected),
            "expected {body} to end with '{expected}'"
        );
    }
}

0 commit comments

Comments
 (0)