Skip to content

Commit a89fc3e

Browse files
authored
Default headers work around (#263)
1 parent 1a959a7 commit a89fc3e

File tree

1 file changed

+32
-13
lines changed

1 file changed

+32
-13
lines changed

crates/twirp/src/client.rs

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ use std::sync::Arc;
33
use std::vec;
44

55
use async_trait::async_trait;
6+
use http::header::Entry;
7+
use http::header::IntoHeaderName;
8+
use http::HeaderMap;
9+
use http::HeaderValue;
610
use reqwest::header::CONTENT_TYPE;
711
use url::Host;
812
use url::Url;
@@ -43,21 +47,21 @@ impl ClientBuilder {
4347
}
4448
}
4549

50+
/// Set the HTTP client. Without this a default HTTP client is used.
51+
pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
52+
self.http_client = Some(http_client);
53+
self
54+
}
55+
4656
/// Add middleware to the client that will be called on each request.
4757
/// Middlewares are invoked in the order they are added as part of the
4858
/// request cycle.
49-
pub fn with_middleware<M>(self, middleware: M) -> Self
59+
pub fn with_middleware<M>(mut self, middleware: M) -> Self
5060
where
5161
M: Middleware,
5262
{
53-
let mut mw = self.middleware;
54-
mw.push(Box::new(middleware));
55-
Self {
56-
base_url: self.base_url,
57-
http_client: self.http_client,
58-
handlers: self.handlers,
59-
middleware: mw,
60-
}
63+
self.middleware.push(Box::new(middleware));
64+
self
6165
}
6266

6367
/// Add a handler for a service using the default host.
@@ -83,9 +87,16 @@ impl ClientBuilder {
8387
self
8488
}
8589

86-
/// Set the HTTP client. Without this a default HTTP client is used.
87-
pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
88-
self.http_client = Some(http_client);
90+
/// Set a default header for use in direct mode.
91+
pub fn with_default_header<K>(mut self, key: K, value: HeaderValue) -> Self
92+
where
93+
K: IntoHeaderName,
94+
{
95+
if let Some(handlers) = &mut self.handlers {
96+
handlers.default_headers.insert(key, value);
97+
} else {
98+
panic!("you must use `ClientBuilder::direct()` to register handler default headers");
99+
}
89100
self
90101
}
91102

@@ -315,9 +326,15 @@ impl<'a> Next<'a> {
315326
}
316327

317328
async fn execute_handlers(
318-
req: reqwest::Request,
329+
mut req: reqwest::Request,
319330
request_handlers: &RequestHandlers,
320331
) -> Result<reqwest::Response> {
332+
let req_headers = req.headers_mut();
333+
for (key, value) in &request_handlers.default_headers {
334+
if let Entry::Vacant(entry) = req_headers.entry(key) {
335+
entry.insert(value.clone());
336+
}
337+
}
321338
let url = req.url().clone();
322339
let Some(mut segments) = url.path_segments() else {
323340
return Err(crate::bad_route(format!(
@@ -344,13 +361,15 @@ async fn execute_handlers(
344361

345362
#[derive(Clone, Default)]
346363
pub struct RequestHandlers {
364+
default_headers: HeaderMap,
347365
/// A map of host/service names to handlers.
348366
handlers: HashMap<String, Arc<dyn DirectHandler>>,
349367
}
350368

351369
impl RequestHandlers {
352370
pub fn new() -> Self {
353371
Self {
372+
default_headers: HeaderMap::new(),
354373
handlers: HashMap::new(),
355374
}
356375
}

0 commit comments

Comments
 (0)