Skip to content

Commit 09c473c

Browse files
authored
fix(bottlecap): fix how Telemetry API stream is being read (#259)
* add `from_stream` to `HttpRequestParser`: read properly from the stream so we can read every byte, retrying 3 times in case there's no data
* add a unit test and some integration tests, moving away from mocking `TcpStream`, since the signature changed
* add constants, and refactor tests
* add another unit test for invalid data
* update algorithm
* typo
1 parent 47617ad commit 09c473c

File tree

1 file changed

+83
-44
lines changed

1 file changed

+83
-44
lines changed

bottlecap/src/telemetry/listener.rs

Lines changed: 83 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,62 @@ pub struct HttpRequestParser {
1818

1919
const CR: u8 = b'\r';
2020
const LR: u8 = b'\n';
21+
/// It is guaranteed that the headers will be less than 256 bytes.
22+
const HEADERS_BUFFER_SIZE: usize = 256;
2123

2224
impl HttpRequestParser {
23-
/// Create a `HttpRequestParser` from passed `buf`
25+
/// Create a `HttpRequestParser` from a `TcpStream`
2426
///
2527
/// # Errors
2628
///
27-
/// Function will fail if parsing of headers or body in `buf` fail.
28-
pub fn from_buf(buf: &[u8]) -> Result<HttpRequestParser, Box<dyn Error>> {
29+
/// Function will error if the stream cannot be read from.
30+
///
31+
/// It will also error if the headers cannot be parsed.
32+
///
33+
/// Or if the body cannot be parsed.
34+
pub fn from_stream(mut stream: &TcpStream) -> Result<HttpRequestParser, Box<dyn Error>> {
35+
stream.set_nonblocking(true)?;
2936
let mut parser = HttpRequestParser {
3037
headers: HashMap::new(),
3138
body: String::new(),
3239
};
3340

34-
let body_start_index = parser.parse_headers(buf)?;
35-
parser.parse_body(buf, body_start_index)?;
41+
let mut headers_buf = [0u8; HEADERS_BUFFER_SIZE];
42+
loop {
43+
match stream.read(&mut headers_buf) {
44+
Ok(_) => {}
45+
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
46+
continue;
47+
}
48+
Err(e) => {
49+
error!("Error reading from stream: {}", e);
50+
}
51+
}
52+
53+
let _ = parser.parse_headers(&headers_buf);
54+
if parser.headers.contains_key("content-length") {
55+
break;
56+
}
57+
}
58+
59+
let body_start_index = parser.parse_headers(&headers_buf)?;
60+
let content_length = parser
61+
.headers
62+
.get("content-length")
63+
.expect("infallible")
64+
.parse::<usize>()?;
65+
let body_bytes_read = headers_buf.len() - body_start_index;
66+
let missing_body_lenght = content_length - body_bytes_read;
67+
let mut body_buf = vec![0u8; missing_body_lenght];
68+
69+
stream.read_exact(&mut body_buf)?;
70+
71+
let total_bytes_read = headers_buf.len() + missing_body_lenght;
72+
let mut buf = vec![0u8; total_bytes_read];
73+
buf[..headers_buf.len()].copy_from_slice(&headers_buf);
74+
buf[headers_buf.len()..].copy_from_slice(&body_buf);
75+
76+
parser.parse_body(&buf, body_start_index)?;
3677

3778
Ok(parser)
3879
}
@@ -133,7 +174,6 @@ impl TelemetryListener {
133174
) -> Result<TelemetryListener, Box<dyn Error>> {
134175
let addr = format!("{}:{}", &config.host, &config.port);
135176
let listener = TcpListener::bind(addr)?;
136-
let buf: [u8; 262_144] = [0; 256 * 1024]; // Using the default limit from AWS
137177

138178
let join_handle = std::thread::spawn(move || {
139179
debug!("Initializing Telemetry Listener");
@@ -143,9 +183,9 @@ impl TelemetryListener {
143183
debug!("Received a Telemetry API connection");
144184

145185
let cloned_event_bus = event_bus.clone();
146-
if let Ok(mut stream) = stream {
186+
if let Ok(stream) = stream {
147187
std::thread::spawn(move || {
148-
let r = Self::handle_stream(&mut stream, buf, cloned_event_bus);
188+
let r = Self::handle_stream(&stream, cloned_event_bus);
149189
if let Err(e) = Self::acknowledge_request(stream, r) {
150190
error!("Error acknowledging Telemetry request: {:?}", e);
151191
}
@@ -161,15 +201,10 @@ impl TelemetryListener {
161201
}
162202

163203
fn handle_stream(
164-
stream: &mut impl Read,
165-
mut buf: [u8; 262_144],
204+
stream: &TcpStream,
166205
event_bus: SyncSender<events::Event>,
167206
) -> Result<(), Box<dyn Error>> {
168-
// Read into buffer
169-
#![allow(clippy::unused_io_amount)]
170-
stream.read(&mut buf)?;
171-
172-
let p = HttpRequestParser::from_buf(&buf)?;
207+
let p = HttpRequestParser::from_stream(stream)?;
173208
let telemetry_events: Vec<TelemetryEvent> = serde_json::from_str(&p.body)?;
174209
for event in telemetry_events {
175210
if let Err(e) = event_bus.send(events::Event::Telemetry(event)) {
@@ -210,6 +245,8 @@ impl TelemetryListener {
210245

211246
#[cfg(test)]
212247
mod tests {
248+
use std::thread;
249+
213250
use chrono::DateTime;
214251

215252
use crate::telemetry::events::{InitPhase, InitType, TelemetryRecord};
@@ -287,27 +324,29 @@ mod tests {
287324
assert_eq!(parser.body, "Hello, World!".to_string());
288325
}
289326

290-
struct MockTcpStream {
291-
data: Vec<u8>,
292-
}
327+
fn get_stream(data: Vec<u8>) -> TcpStream {
328+
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
329+
let addr = listener.local_addr().unwrap();
293330

294-
impl Read for MockTcpStream {
295-
fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
296-
let len = std::cmp::min(buf.len(), self.data.len());
297-
buf.write_all(&self.data[..len])?;
298-
self.data = self.data.split_off(len);
299-
Ok(len)
300-
}
331+
let (tx, rx) = std::sync::mpsc::channel();
332+
thread::spawn(move || {
333+
let (mut stream, _) = listener.accept().unwrap();
334+
stream.write_all(&data).unwrap();
335+
tx.send(()).unwrap(); // Signal that the request has been sent
336+
});
337+
338+
let stream = TcpStream::connect(addr).unwrap();
339+
rx.recv().unwrap(); // Wait for the signal from the spawned thread
340+
341+
stream
301342
}
302343

303344
#[test]
304345
fn test_handle_stream() {
305-
let mut stream = MockTcpStream {
306-
data: "POST /path HTTP/1.1\r\nContent-Length: 335\r\nHeader1: Value1\r\n\r\n[{\"time\":\"2024-04-25T17:35:59.944Z\",\"type\":\"platform.initStart\",\"record\":{\"initializationType\":\"on-demand\",\"phase\":\"init\",\"runtimeVersion\":\"nodejs:20.v22\",\"runtimeVersionArn\":\"arn:aws:lambda:us-east-1::runtime:da57c20c4b965d5b75540f6865a35fc8030358e33ec44ecfed33e90901a27a72\",\"functionName\":\"hello-world\",\"functionVersion\":\"$LATEST\"}}]".to_string().into_bytes(),
307-
};
346+
let stream = get_stream(b"POST /path HTTP/1.1\r\nContent-Length: 335\r\nHeader1: Value1\r\n\r\n[{\"time\":\"2024-04-25T17:35:59.944Z\",\"type\":\"platform.initStart\",\"record\":{\"initializationType\":\"on-demand\",\"phase\":\"init\",\"runtimeVersion\":\"nodejs:20.v22\",\"runtimeVersionArn\":\"arn:aws:lambda:us-east-1::runtime:da57c20c4b965d5b75540f6865a35fc8030358e33ec44ecfed33e90901a27a72\",\"functionName\":\"hello-world\",\"functionVersion\":\"$LATEST\"}}]".to_vec());
347+
308348
let (tx, rx) = std::sync::mpsc::sync_channel(3);
309-
let buf = [0; 262_144];
310-
let result = TelemetryListener::handle_stream(&mut stream, buf, tx);
349+
let result = TelemetryListener::handle_stream(&stream, tx);
311350
let event = rx.recv().expect("No events received");
312351
let telemetry_event = match event {
313352
events::Event::Telemetry(te) => te,
@@ -331,37 +370,37 @@ mod tests {
331370
#[test]
332371
#[should_panic]
333372
fn $name() {
334-
let mut stream = MockTcpStream {
335-
data: $value.to_string().into_bytes(),
336-
};
373+
let stream = get_stream($value.to_vec());
374+
337375
let (tx, _) = std::sync::mpsc::sync_channel(4);
338-
let buf = [0; 262144];
339-
TelemetryListener::handle_stream(&mut stream, buf, tx).unwrap()
376+
TelemetryListener::handle_stream(&stream, tx).unwrap()
340377
}
341378
)*
342379
}
343380
}
344381

345382
test_handle_stream_invalid_body! {
346-
invalid_json: "POST /path HTTP/1.1\r\nContent-Length: 13\r\nHeader1: Value1\r\n\r\nHello, World!",
347-
empty_json: "POST /path HTTP/1.1\r\nContent-Length: 2\r\nHeader1: Value1\r\n\r\n{}",
348-
json_array_with_empty_json: "POST /path HTTP/1.1\r\nContent-Length: 4\r\nHeader1: Value1\r\n\r\n[{}]",
383+
invalid_json: b"POST /path HTTP/1.1\r\nContent-Length: 13\r\nHeader1: Value1\r\n\r\nHello, World!",
384+
empty_json: b"POST /path HTTP/1.1\r\nContent-Length: 2\r\nHeader1: Value1\r\n\r\n{}",
385+
json_array_with_empty_json: b"POST /path HTTP/1.1\r\nContent-Length: 4\r\nHeader1: Value1\r\n\r\n[{}]",
349386

350387
}
351388

352389
#[test]
353-
fn test_from_buf() {
354-
let buf =
355-
b"GET /path HTTP/1.1\r\nContent-Length: 13\r\nHeader1: Value1\r\n\r\nHello, World!";
356-
let result = HttpRequestParser::from_buf(buf);
390+
fn test_from_stream() {
391+
let stream = get_stream(b"POST /path HTTP/1.1\r\nContent-Length: 335\r\nHeader1: Value1\r\n\r\n[{\"time\":\"2024-04-25T17:35:59.944Z\",\"type\":\"platform.initStart\",\"record\":{\"initializationType\":\"on-demand\",\"phase\":\"init\",\"runtimeVersion\":\"nodejs:20.v22\",\"runtimeVersionArn\":\"arn:aws:lambda:us-east-1::runtime:da57c20c4b965d5b75540f6865a35fc8030358e33ec44ecfed33e90901a27a72\",\"functionName\":\"hello-world\",\"functionVersion\":\"$LATEST\"}}]".to_vec());
392+
393+
let result = HttpRequestParser::from_stream(&stream);
394+
357395
assert!(result.is_ok());
358396
let parser = result.unwrap();
359397
assert_eq!(parser.headers.len(), 2);
360398
assert_eq!(
361399
parser.headers.get("content-length"),
362-
Some(&"13".to_string())
400+
Some(&"335".to_string())
363401
);
364402
assert_eq!(parser.headers.get("header1"), Some(&"Value1".to_string()));
365-
assert_eq!(parser.body, "Hello, World!".to_string());
403+
404+
assert_eq!(parser.body, "[{\"time\":\"2024-04-25T17:35:59.944Z\",\"type\":\"platform.initStart\",\"record\":{\"initializationType\":\"on-demand\",\"phase\":\"init\",\"runtimeVersion\":\"nodejs:20.v22\",\"runtimeVersionArn\":\"arn:aws:lambda:us-east-1::runtime:da57c20c4b965d5b75540f6865a35fc8030358e33ec44ecfed33e90901a27a72\",\"functionName\":\"hello-world\",\"functionVersion\":\"$LATEST\"}}]".to_string());
366405
}
367406
}

0 commit comments

Comments
 (0)