Skip to content

Commit 1141372

Browse files
authored
Implement Pluggable Name-resolution (fortanix#148)
This defines a new trait `Resolver`, which turns an address into a Vec<SocketAddr>. It also provides an implementation of Resolver for `Fn(&str)` so it's easy to define simple resolvers with a closure. Fixes fortanix#82 Co-authored-by: Ulrik <[email protected]>
1 parent 8bba07a commit 1141372

File tree

7 files changed

+144
-28
lines changed

7 files changed

+144
-28
lines changed

src/agent.rs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::sync::Mutex;
66
use crate::header::{self, Header};
77
use crate::pool::ConnectionPool;
88
use crate::request::Request;
9+
use crate::resolve::ArcResolver;
910

1011
/// Agents keep state between requests.
1112
///
@@ -53,15 +54,12 @@ pub(crate) struct AgentState {
5354
/// Cookies saved between requests.
5455
#[cfg(feature = "cookie")]
5556
pub(crate) jar: CookieJar,
57+
pub(crate) resolver: ArcResolver,
5658
}
5759

5860
impl AgentState {
5961
fn new() -> Self {
60-
AgentState {
61-
pool: ConnectionPool::new(),
62-
#[cfg(feature = "cookie")]
63-
jar: CookieJar::new(),
64-
}
62+
Self::default()
6563
}
6664
pub fn pool(&mut self) -> &mut ConnectionPool {
6765
&mut self.pool
@@ -194,6 +192,29 @@ impl Agent {
194192
.set_max_idle_connections_per_host(max_connections);
195193
}
196194

195+
/// Configures a custom resolver to be used by this agent. By default,
196+
/// address-resolution is done by std::net::ToSocketAddrs. This allows you
197+
/// to override that resolution with your own alternative. Useful for
198+
/// testing and special-cases like DNS-based load balancing.
199+
///
200+
/// A `Fn(&str) -> io::Result<Vec<SocketAddr>>` is a valid resolver,
201+
/// passing a closure is a simple way to override. Note that you might need
202+
/// explicit type `&str` on the closure argument for type inference to
203+
/// succeed.
204+
/// ```
205+
/// use std::net::ToSocketAddrs;
206+
///
207+
/// let mut agent = ureq::agent();
208+
/// agent.set_resolver(|addr: &str| match addr {
209+
/// "example.com" => Ok(vec![([127,0,0,1], 8096).into()]),
210+
/// addr => addr.to_socket_addrs().map(Iterator::collect),
211+
/// });
212+
/// ```
213+
pub fn set_resolver(&mut self, resolver: impl crate::Resolver + 'static) -> &mut Self {
214+
self.state.lock().unwrap().resolver = resolver.into();
215+
self
216+
}
217+
197218
/// Gets a cookie in this agent by name. Cookies are available
198219
/// either by setting it in the agent, or by making requests
199220
/// that `Set-Cookie` in the agent.

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ mod header;
125125
mod pool;
126126
mod proxy;
127127
mod request;
128+
mod resolve;
128129
mod response;
129130
mod stream;
130131
mod unit;
@@ -140,6 +141,7 @@ pub use crate::error::Error;
140141
pub use crate::header::Header;
141142
pub use crate::proxy::Proxy;
142143
pub use crate::request::Request;
144+
pub use crate::resolve::Resolver;
143145
pub use crate::response::Response;
144146

145147
// re-export

src/pool.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,6 @@ impl Default for ConnectionPool {
7474
}
7575

7676
impl ConnectionPool {
77-
pub fn new() -> Self {
78-
Self::default()
79-
}
80-
8177
pub fn set_max_idle_connections(&mut self, max_connections: usize) {
8278
if self.max_idle_connections == max_connections {
8379
return;
@@ -251,7 +247,7 @@ fn pool_connections_limit() {
251247
// Test inserting connections with different keys into the pool,
252248
// filling and draining it. The pool should evict earlier connections
253249
// when the connection limit is reached.
254-
let mut pool = ConnectionPool::new();
250+
let mut pool = ConnectionPool::default();
255251
let hostnames = (0..DEFAULT_MAX_IDLE_CONNECTIONS * 2).map(|i| format!("{}.example", i));
256252
let poolkeys = hostnames.map(|hostname| PoolKey {
257253
scheme: "https".to_string(),
@@ -276,7 +272,7 @@ fn pool_per_host_connections_limit() {
276272
// Test inserting connections with the same key into the pool,
277273
// filling and draining it. The pool should evict earlier connections
278274
// when the per-host connection limit is reached.
279-
let mut pool = ConnectionPool::new();
275+
let mut pool = ConnectionPool::default();
280276
let poolkey = PoolKey {
281277
scheme: "https".to_string(),
282278
hostname: "example.com".to_string(),
@@ -301,7 +297,7 @@ fn pool_per_host_connections_limit() {
301297

302298
#[test]
303299
fn pool_update_connection_limit() {
304-
let mut pool = ConnectionPool::new();
300+
let mut pool = ConnectionPool::default();
305301
pool.set_max_idle_connections(50);
306302

307303
let hostnames = (0..pool.max_idle_connections).map(|i| format!("{}.example", i));
@@ -321,7 +317,7 @@ fn pool_update_connection_limit() {
321317

322318
#[test]
323319
fn pool_update_per_host_connection_limit() {
324-
let mut pool = ConnectionPool::new();
320+
let mut pool = ConnectionPool::default();
325321
pool.set_max_idle_connections(50);
326322
pool.set_max_idle_connections_per_host(50);
327323

@@ -347,7 +343,7 @@ fn pool_update_per_host_connection_limit() {
347343
fn pool_checks_proxy() {
348344
// Test inserting different poolkeys with same address but different proxies.
349345
// Each insertion should result in an additional entry in the pool.
350-
let mut pool = ConnectionPool::new();
346+
let mut pool = ConnectionPool::default();
351347
let url = Url::parse("zzz:///example.com").unwrap();
352348

353349
pool.add(

src/resolve.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
use std::fmt;
2+
use std::io::Result as IoResult;
3+
use std::net::{SocketAddr, ToSocketAddrs};
4+
use std::sync::Arc;
5+
6+
pub trait Resolver: Send + Sync {
7+
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>>;
8+
}
9+
10+
#[derive(Debug)]
11+
pub(crate) struct StdResolver;
12+
13+
impl Resolver for StdResolver {
14+
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>> {
15+
ToSocketAddrs::to_socket_addrs(netloc).map(|iter| iter.collect())
16+
}
17+
}
18+
19+
impl<F> Resolver for F
20+
where
21+
F: Fn(&str) -> IoResult<Vec<SocketAddr>>,
22+
F: Send + Sync,
23+
{
24+
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>> {
25+
self(netloc)
26+
}
27+
}
28+
29+
#[derive(Clone)]
30+
pub(crate) struct ArcResolver(Arc<dyn Resolver>);
31+
32+
impl<R> From<R> for ArcResolver
33+
where
34+
R: Resolver + 'static,
35+
{
36+
fn from(r: R) -> Self {
37+
Self(Arc::new(r))
38+
}
39+
}
40+
41+
impl fmt::Debug for ArcResolver {
42+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43+
write!(f, "ArcResolver(...)")
44+
}
45+
}
46+
47+
impl std::ops::Deref for ArcResolver {
48+
type Target = dyn Resolver;
49+
50+
fn deref(&self) -> &Self::Target {
51+
self.0.as_ref()
52+
}
53+
}
54+
55+
impl Default for ArcResolver {
56+
fn default() -> Self {
57+
StdResolver.into()
58+
}
59+
}

src/stream.rs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use std::io::{
44
};
55
use std::net::SocketAddr;
66
use std::net::TcpStream;
7-
use std::net::ToSocketAddrs;
87
use std::time::Duration;
98
use std::time::Instant;
109

@@ -386,15 +385,17 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
386385
} else {
387386
unit.deadline
388387
};
389-
390-
// TODO: Find a way to apply deadline to DNS lookup.
391-
let sock_addrs: Vec<SocketAddr> = match unit.req.proxy {
388+
389+
let netloc = match unit.req.proxy {
392390
Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port),
393391
None => format!("{}:{}", hostname, port),
394-
}
395-
.to_socket_addrs()
396-
.map_err(|e| Error::DnsFailed(format!("{}", e)))?
397-
.collect();
392+
};
393+
394+
// TODO: Find a way to apply deadline to DNS lookup.
395+
let sock_addrs = unit
396+
.resolver()
397+
.resolve(&netloc)
398+
.map_err(|e| Error::DnsFailed(format!("{}", e)))?;
398399

399400
if sock_addrs.is_empty() {
400401
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
@@ -419,6 +420,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
419420
// connect with a configured timeout.
420421
let stream = if Some(Proto::SOCKS5) == proto {
421422
connect_socks5(
423+
&unit,
422424
unit.req.proxy.to_owned().unwrap(),
423425
deadline,
424426
sock_addr,
@@ -496,11 +498,15 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
496498
}
497499

498500
#[cfg(feature = "socks-proxy")]
499-
fn socks5_local_nslookup(hostname: &str, port: u16) -> Result<TargetAddr, std::io::Error> {
500-
let addrs: Vec<SocketAddr> = format!("{}:{}", hostname, port)
501-
.to_socket_addrs()
502-
.map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))?
503-
.collect();
501+
fn socks5_local_nslookup(
502+
unit: &Unit,
503+
hostname: &str,
504+
port: u16,
505+
) -> Result<TargetAddr, std::io::Error> {
506+
let addrs: Vec<SocketAddr> = unit
507+
.resolver()
508+
.resolve(&format!("{}:{}", hostname, port))
509+
.map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))?;
504510

505511
if addrs.is_empty() {
506512
return Err(std::io::Error::new(
@@ -522,6 +528,7 @@ fn socks5_local_nslookup(hostname: &str, port: u16) -> Result<TargetAddr, std::i
522528

523529
#[cfg(feature = "socks-proxy")]
524530
fn connect_socks5(
531+
unit: &Unit,
525532
proxy: Proxy,
526533
deadline: Option<Instant>,
527534
proxy_addr: SocketAddr,
@@ -533,7 +540,7 @@ fn connect_socks5(
533540
use std::str::FromStr;
534541

535542
let host_addr = if Ipv4Addr::from_str(host).is_ok() || Ipv6Addr::from_str(host).is_ok() {
536-
match socks5_local_nslookup(host, port) {
543+
match socks5_local_nslookup(unit, host, port) {
537544
Ok(addr) => addr,
538545
Err(err) => return Err(err),
539546
}
@@ -625,6 +632,7 @@ fn get_socks5_stream(
625632

626633
#[cfg(not(feature = "socks-proxy"))]
627634
fn connect_socks5(
635+
_unit: &Unit,
628636
_proxy: Proxy,
629637
_deadline: Option<Instant>,
630638
_proxy_addr: SocketAddr,

src/test/agent_test.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,31 @@ fn connection_reuse() {
101101
assert_eq!(resp.status(), 200);
102102
}
103103

104+
#[test]
105+
fn custom_resolver() {
106+
use std::io::Read;
107+
use std::net::TcpListener;
108+
109+
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
110+
111+
let local_addr = listener.local_addr().unwrap();
112+
113+
let server = std::thread::spawn(move || {
114+
let (mut client, _) = listener.accept().unwrap();
115+
let mut buf = vec![0u8; 16];
116+
let read = client.read(&mut buf).unwrap();
117+
buf.truncate(read);
118+
buf
119+
});
120+
121+
crate::agent()
122+
.set_resolver(move |_: &str| Ok(vec![local_addr]))
123+
.get("http://cool.server/")
124+
.call();
125+
126+
assert_eq!(&server.join().unwrap(), b"GET / HTTP/1.1\r\n");
127+
}
128+
104129
#[cfg(feature = "cookie")]
105130
#[cfg(test)]
106131
fn cookie_and_redirect(mut stream: TcpStream) -> io::Result<()> {

src/unit.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use cookie::{Cookie, CookieJar};
1010
use crate::agent::AgentState;
1111
use crate::body::{self, Payload, SizedReader};
1212
use crate::header;
13+
use crate::resolve::ArcResolver;
1314
use crate::stream::{self, connect_test, Stream};
1415
use crate::{Error, Header, Request, Response};
1516

@@ -95,6 +96,10 @@ impl Unit {
9596
self.req.method.eq_ignore_ascii_case("head")
9697
}
9798

99+
pub fn resolver(&self) -> ArcResolver {
100+
self.req.agent.lock().unwrap().resolver.clone()
101+
}
102+
98103
#[cfg(test)]
99104
pub fn header(&self, name: &str) -> Option<&str> {
100105
header::get_header(&self.headers, name)

0 commit comments

Comments
 (0)