Skip to content

Commit 2ffe1ef

Browse files
committed
Add an integration test
1 parent 17a8f15 commit 2ffe1ef

34 files changed

+6498
-0
lines changed

tests/common/mod.rs

Lines changed: 384 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
//! Utility code to help writing triagebot tests.
2+
3+
use std::collections::HashMap;
4+
use std::io::{BufRead, BufReader, Read, Write};
5+
use std::net::TcpStream;
6+
use std::net::{SocketAddr, TcpListener};
7+
use std::sync::{Arc, Mutex};
8+
use url::Url;
9+
10+
/// The callback type for HTTP route handlers.
11+
pub type RequestCallback = Box<dyn Send + Fn(Request) -> Response>;
12+
13+
/// HTTP method.
14+
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
15+
pub enum Method {
16+
GET,
17+
POST,
18+
PUT,
19+
DELETE,
20+
PATCH,
21+
}
22+
23+
impl Method {
24+
fn from_str(s: &str) -> Method {
25+
match s {
26+
"GET" => Method::GET,
27+
"POST" => Method::POST,
28+
"PUT" => Method::PUT,
29+
"DELETE" => Method::DELETE,
30+
"PATCH" => Method::PATCH,
31+
_ => panic!("unexpected HTTP method {s}"),
32+
}
33+
}
34+
}
35+
36+
/// A builder for preparing a test.
37+
#[derive(Default)]
38+
pub struct TestBuilder {
39+
pub config: Option<&'static str>,
40+
pub api_handlers: HashMap<(Method, &'static str), RequestCallback>,
41+
pub raw_handlers: HashMap<(Method, &'static str), RequestCallback>,
42+
}
43+
44+
/// A request received on the HTTP server.
45+
#[derive(Clone, Debug)]
46+
pub struct Request {
47+
/// The path of the request, such as `repos/rust-lang/rust/labels`.
48+
pub path: String,
49+
/// The HTTP method.
50+
pub method: Method,
51+
/// Components in the path that were captured with the `{foo}` syntax.
52+
/// See [`TestBuilder::api_handler`] for details.
53+
pub components: HashMap<String, String>,
54+
/// The query components of the URL (the stuff after `?`).
55+
pub query: Vec<(String, String)>,
56+
/// HTTP headers.
57+
pub headers: HashMap<String, String>,
58+
/// The body of the HTTP request (usually a JSON blob).
59+
pub body: Vec<u8>,
60+
}
61+
62+
impl Request {
63+
pub fn json(&self) -> serde_json::Value {
64+
serde_json::from_slice(&self.body).unwrap()
65+
}
66+
pub fn body_str(&self) -> String {
67+
String::from_utf8(self.body.clone()).unwrap()
68+
}
69+
70+
pub fn query_string(&self) -> String {
71+
let vs: Vec<_> = self.query.iter().map(|(k, v)| format!("{k}={v}")).collect();
72+
vs.join("&")
73+
}
74+
}
75+
76+
/// The response the HTTP server should send to the client.
77+
pub struct Response {
78+
pub code: u32,
79+
pub headers: Vec<String>,
80+
pub body: Vec<u8>,
81+
}
82+
83+
impl Response {
84+
pub fn new() -> Response {
85+
Response {
86+
code: 200,
87+
headers: Vec::new(),
88+
body: Vec::new(),
89+
}
90+
}
91+
92+
pub fn new_from_path(path: &str) -> Response {
93+
Response {
94+
code: 200,
95+
headers: Vec::new(),
96+
body: std::fs::read(path).unwrap(),
97+
}
98+
}
99+
100+
pub fn body(mut self, file: &[u8]) -> Self {
101+
self.body = Vec::from(file);
102+
self
103+
}
104+
}
105+
106+
/// A recording of HTTP requests which can then be validated they are
107+
/// performed in the correct order.
108+
///
109+
/// A copy of this is shared among the different HTTP servers. At the end of
110+
/// the test, the test should call `assert_eq` to validate the correct actions
111+
/// were performed.
112+
#[derive(Clone)]
113+
pub struct Events(Arc<Mutex<Vec<(Method, String)>>>);
114+
115+
impl Events {
116+
pub fn new() -> Events {
117+
Events(Arc::new(Mutex::new(Vec::new())))
118+
}
119+
120+
fn push(&self, method: Method, path: String) {
121+
let mut es = self.0.lock().unwrap();
122+
es.push((method, path));
123+
}
124+
125+
pub fn assert_eq(&self, expected: &[(Method, &str)]) {
126+
let es = self.0.lock().unwrap();
127+
for (actual, expected) in es.iter().zip(expected.iter()) {
128+
if actual.0 != expected.0 || actual.1 != expected.1 {
129+
panic!("expected request to {expected:?}, but next event was {actual:?}");
130+
}
131+
}
132+
if es.len() > expected.len() {
133+
panic!(
134+
"got unexpected extra requests, \
135+
make sure the event assertion lists all events\n\
136+
Extras are: {:?} ",
137+
&es[expected.len()..]
138+
);
139+
} else if es.len() < expected.len() {
140+
panic!(
141+
"expected additional requests that were never made, \
142+
make sure the event assertion lists the correct requests\n\
143+
Extra expected are: {:?}",
144+
&expected[es.len()..]
145+
);
146+
}
147+
}
148+
}
149+
150+
/// A primitive HTTP server.
151+
pub struct HttpServer {
152+
listener: TcpListener,
153+
/// Handlers to call for specific routes.
154+
handlers: HashMap<(Method, &'static str), RequestCallback>,
155+
/// A recording of all API requests.
156+
events: Events,
157+
}
158+
159+
/// A reference on how to connect to the test HTTP server.
160+
pub struct HttpServerHandle {
161+
pub addr: SocketAddr,
162+
}
163+
164+
impl Drop for HttpServerHandle {
165+
fn drop(&mut self) {
166+
if let Ok(mut stream) = TcpStream::connect(self.addr) {
167+
// shut down the server
168+
let _ = stream.write_all(b"STOP");
169+
let _ = stream.flush();
170+
}
171+
}
172+
}
173+
174+
impl TestBuilder {
175+
/// Sets the config for the `triagebot.toml` file for the `rust-lang/rust`
176+
/// repository.
177+
pub fn config(mut self, config: &'static str) -> Self {
178+
self.config = Some(config);
179+
self
180+
}
181+
182+
/// Adds an HTTP handler for https://api.github.com/
183+
///
184+
/// The `path` is the route, like `repos/rust-lang/rust/labels`. A generic
185+
/// route can be configured using curly braces. For example, to get all
186+
/// requests for labels, use a path like `repos/rust-lang/rust/{label}`.
187+
/// The value of the path component can be found in
188+
/// [`Request::components`].
189+
///
190+
/// If the path ends with `{...}`, then this means "the rest of the path".
191+
/// The rest of the path can be obtained from the `...` value in the
192+
/// `Request::components` map.
193+
pub fn api_handler<R: 'static + Send + Fn(Request) -> Response>(
194+
mut self,
195+
method: Method,
196+
path: &'static str,
197+
responder: R,
198+
) -> Self {
199+
self.api_handlers
200+
.insert((method, path), Box::new(responder));
201+
self
202+
}
203+
204+
/// Adds an HTTP handler for https://raw.githubusercontent.com
205+
pub fn raw_handler<R: 'static + Send + Fn(Request) -> Response>(
206+
mut self,
207+
method: Method,
208+
path: &'static str,
209+
responder: R,
210+
) -> Self {
211+
self.raw_handlers
212+
.insert((method, path), Box::new(responder));
213+
self
214+
}
215+
216+
/// Enables logging if `TRIAGEBOT_TEST_LOG` is set. This can help with
217+
/// debugging a test.
218+
pub fn maybe_enable_logging(&self) {
219+
const LOG_VAR: &str = "TRIAGEBOT_TEST_LOG";
220+
use std::sync::Once;
221+
static DO_INIT: Once = Once::new();
222+
if std::env::var_os(LOG_VAR).is_some() {
223+
DO_INIT.call_once(|| {
224+
dotenv::dotenv().ok();
225+
tracing_subscriber::fmt::Subscriber::builder()
226+
.with_env_filter(tracing_subscriber::EnvFilter::from_env(LOG_VAR))
227+
.with_ansi(std::env::var_os("DISABLE_COLOR").is_none())
228+
.try_init()
229+
.unwrap();
230+
});
231+
}
232+
}
233+
}
234+
235+
impl HttpServer {
236+
pub fn new(
237+
handlers: HashMap<(Method, &'static str), RequestCallback>,
238+
events: Events,
239+
) -> HttpServerHandle {
240+
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
241+
let addr = listener.local_addr().unwrap();
242+
let server = HttpServer {
243+
listener,
244+
handlers,
245+
events,
246+
};
247+
std::thread::spawn(move || server.start());
248+
HttpServerHandle { addr }
249+
}
250+
251+
fn start(&self) {
252+
let mut line = String::new();
253+
'server: loop {
254+
let (socket, _) = self.listener.accept().unwrap();
255+
let mut buf = BufReader::new(socket);
256+
line.clear();
257+
if buf.read_line(&mut line).unwrap() == 0 {
258+
// Connection terminated.
259+
eprintln!("unexpected client drop");
260+
continue;
261+
}
262+
// Read the "GET path HTTP/1.1" line.
263+
let mut parts = line.split_ascii_whitespace();
264+
let method = parts.next().unwrap().to_ascii_uppercase();
265+
if method == "STOP" {
266+
// Shutdown the server.
267+
return;
268+
}
269+
let path = parts.next().unwrap();
270+
// The host here doesn't matter, we're just interested in parsing
271+
// the query string.
272+
let url = Url::parse(&format!("https://api.github.com{path}")).unwrap();
273+
274+
let mut headers = HashMap::new();
275+
let mut content_len = None;
276+
loop {
277+
line.clear();
278+
if buf.read_line(&mut line).unwrap() == 0 {
279+
continue 'server;
280+
}
281+
if line == "\r\n" {
282+
// End of headers.
283+
line.clear();
284+
break;
285+
}
286+
let (name, value) = line.split_once(':').unwrap();
287+
let name = name.trim().to_ascii_lowercase();
288+
let value = value.trim().to_string();
289+
match name.as_str() {
290+
"content-length" => content_len = Some(value.parse::<u64>().unwrap()),
291+
_ => {}
292+
}
293+
headers.insert(name, value);
294+
}
295+
let mut body = vec![0u8; content_len.unwrap_or(0) as usize];
296+
buf.read_exact(&mut body).unwrap();
297+
298+
let method = Method::from_str(&method);
299+
self.events.push(method, url.path().to_string());
300+
let response = self.route(method, &url, headers, body);
301+
302+
let buf = buf.get_mut();
303+
write!(buf, "HTTP/1.1 {}\r\n", response.code).unwrap();
304+
write!(buf, "Content-Length: {}\r\n", response.body.len()).unwrap();
305+
write!(buf, "Connection: close\r\n").unwrap();
306+
for header in response.headers {
307+
write!(buf, "{}\r\n", header).unwrap();
308+
}
309+
write!(buf, "\r\n").unwrap();
310+
buf.write_all(&response.body).unwrap();
311+
buf.flush().unwrap();
312+
}
313+
}
314+
315+
/// Route the request
316+
fn route(
317+
&self,
318+
method: Method,
319+
url: &Url,
320+
headers: HashMap<String, String>,
321+
body: Vec<u8>,
322+
) -> Response {
323+
eprintln!("route {method:?} {url}",);
324+
let query = url
325+
.query_pairs()
326+
.map(|(k, v)| (k.to_string(), v.to_string()))
327+
.collect();
328+
let segments: Vec<_> = url.path_segments().unwrap().collect();
329+
let path = url.path().to_string();
330+
for ((route_method, route_pattern), responder) in &self.handlers {
331+
if *route_method != method {
332+
continue;
333+
}
334+
if let Some(components) = match_route(route_pattern, &segments) {
335+
let request = Request {
336+
method,
337+
path,
338+
query,
339+
components,
340+
headers,
341+
body,
342+
};
343+
tracing::debug!("request={request:?}");
344+
return responder(request);
345+
}
346+
}
347+
eprintln!(
348+
"route {method:?} {url} has no handler.\n\
349+
Add a handler to the context for this route."
350+
);
351+
Response {
352+
code: 404,
353+
headers: Vec::new(),
354+
body: b"404 not found".to_vec(),
355+
}
356+
}
357+
}
358+
359+
fn match_route(route_pattern: &str, segments: &[&str]) -> Option<HashMap<String, String>> {
360+
let mut segments = segments.into_iter();
361+
let mut components = HashMap::new();
362+
for part in route_pattern.split('/') {
363+
if part == "{...}" {
364+
let rest: Vec<_> = segments.map(|s| *s).collect();
365+
components.insert("...".to_string(), rest.join("/"));
366+
return Some(components);
367+
}
368+
match segments.next() {
369+
None => return None,
370+
Some(actual) => {
371+
if part.starts_with('{') {
372+
let part = part[1..part.len() - 1].to_string();
373+
components.insert(part, actual.to_string());
374+
} else if *actual != part {
375+
return None;
376+
}
377+
}
378+
}
379+
}
380+
if segments.next().is_some() {
381+
return None;
382+
}
383+
Some(components)
384+
}

0 commit comments

Comments
 (0)