Skip to content

Commit a421e95

Browse files
committed
Accept all available incoming connections.
If we don't have any available TLS connections, instead of returning Pending, we will try to get new (TCP) incoming connections again, unless there aren't any more to get. This mitigates a potential DoS attach by creating TCP connections without ever completing the TLS handshake, preventing additional requests from being accepted. Fixes: GHSA-2qph-qpvm-2qf7
1 parent 6c57dea commit a421e95

File tree

2 files changed

+51
-31
lines changed

2 files changed

+51
-31
lines changed

examples/http-change-certificate.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use hyper::service::service_fn;
33
use hyper::{body::Body, Request, Response};
44
use hyper_util::rt::tokio::TokioIo;
55
use std::convert::Infallible;
6+
use std::num::NonZeroUsize;
67
use std::sync::atomic::{AtomicU64, Ordering};
78
use std::sync::Arc;
89
use tokio::net::TcpListener;
@@ -22,7 +23,7 @@ async fn main() {
2223
let counter = Arc::new(AtomicU64::new(0));
2324

2425
let mut listener = tls_listener::builder(tls_acceptor())
25-
.max_handshakes(10)
26+
.accept_batch_size(NonZeroUsize::new(10).unwrap())
2627
.listen(TcpListener::bind(addr).await.expect("Failed to bind port"));
2728

2829
let (tx, mut rx) = mpsc::channel::<Acceptor>(1);

src/lib.rs

+49-30
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use pin_project_lite::pin_project;
1919
pub use spawning_handshake::SpawningHandshakes;
2020
use std::fmt::Debug;
2121
use std::future::{poll_fn, Future};
22+
use std::num::NonZeroUsize;
2223
use std::pin::Pin;
2324
use std::task::{ready, Context, Poll};
2425
use std::time::Duration;
@@ -38,8 +39,8 @@ mod spawning_handshake;
3839
#[cfg(feature = "tokio-net")]
3940
mod net;
4041

41-
/// Default number of concurrent handshakes
42-
pub const DEFAULT_MAX_HANDSHAKES: usize = 64;
42+
/// Default number of connections to accept in a batch before trying to
43+
pub const DEFAULT_ACCEPT_BATCH_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(64) };
4344
/// Default timeout for the TLS handshake.
4445
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
4546

@@ -112,7 +113,7 @@ pin_project! {
112113
listener: A,
113114
tls: T,
114115
waiting: FuturesUnordered<Waiting<A, T>>,
115-
max_handshakes: usize,
116+
accept_batch_size: NonZeroUsize,
116117
timeout: Duration,
117118
}
118119
}
@@ -121,7 +122,7 @@ pin_project! {
121122
#[derive(Clone)]
122123
pub struct Builder<T> {
123124
tls: T,
124-
max_handshakes: usize,
125+
accept_batch_size: NonZeroUsize,
125126
handshake_timeout: Duration,
126127
}
127128

@@ -182,26 +183,36 @@ where
182183
pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
183184
let mut this = self.project();
184185

185-
while this.waiting.len() < *this.max_handshakes {
186-
match this.listener.as_mut().poll_accept(cx) {
187-
Poll::Pending => break,
188-
Poll::Ready(Ok((conn, addr))) => {
189-
this.waiting.push(Waiting {
190-
inner: timeout(*this.timeout, this.tls.accept(conn)),
191-
peer_addr: Some(addr),
192-
});
193-
}
194-
Poll::Ready(Err(e)) => {
195-
return Poll::Ready(Err(Error::ListenerError(e)));
186+
loop {
187+
let mut empty_listener = false;
188+
for _ in 0..this.accept_batch_size.get() {
189+
match this.listener.as_mut().poll_accept(cx) {
190+
Poll::Pending => {
191+
empty_listener = true;
192+
break;
193+
}
194+
Poll::Ready(Ok((conn, addr))) => {
195+
this.waiting.push(Waiting {
196+
inner: timeout(*this.timeout, this.tls.accept(conn)),
197+
peer_addr: Some(addr),
198+
});
199+
}
200+
Poll::Ready(Err(e)) => {
201+
return Poll::Ready(Err(Error::ListenerError(e)));
202+
}
196203
}
197204
}
198-
}
199205

200-
match this.waiting.poll_next_unpin(cx) {
201-
Poll::Ready(Some(result)) => Poll::Ready(result),
202-
// If we don't have anything waiting yet,
203-
// then we are still pending,
204-
Poll::Ready(None) | Poll::Pending => Poll::Pending,
206+
match this.waiting.poll_next_unpin(cx) {
207+
Poll::Ready(Some(result)) => return Poll::Ready(result),
208+
// If we don't have anything waiting yet,
209+
// then we are still pending,
210+
Poll::Ready(None) | Poll::Pending => {
211+
if empty_listener {
212+
return Poll::Pending;
213+
}
214+
}
215+
}
205216
}
206217
}
207218

@@ -318,15 +329,19 @@ where
318329
}
319330

320331
impl<T> Builder<T> {
321-
/// Set the maximum number of concurrent handshakes.
332+
/// Set the size of batches of incoming connections to accept at once
333+
///
334+
/// When polling for a new connection, the `TlsListener` will first check
335+
/// for incomming connections on the listener that need to start a TLS handshake.
336+
/// This specifies the maximum number of connections it will accept before seeing if any
337+
/// TLS connections are ready.
322338
///
323-
/// At most `max` handshakes will be concurrently processed. If that limit is
324-
/// reached, the `TlsListener` will stop polling the underlying listener until a
325-
/// handshake completes and the encrypted stream has been returned.
339+
/// Having a limit for this ensures that ready TLS conections aren't starved if there are a
340+
/// large number of incoming connections.
326341
///
327-
/// Defaults to `DEFAULT_MAX_HANDSHAKES`.
328-
pub fn max_handshakes(&mut self, max: usize) -> &mut Self {
329-
self.max_handshakes = max;
342+
/// Defaults to `DEFAULT_ACCEPT_BATCH_SIZE`.
343+
pub fn accept_batch_size(&mut self, size: NonZeroUsize) -> &mut Self {
344+
self.accept_batch_size = size;
330345
self
331346
}
332347

@@ -335,6 +350,10 @@ impl<T> Builder<T> {
335350
/// If a timeout takes longer than `timeout`, then the handshake will be
336351
/// aborted and the underlying connection will be dropped.
337352
///
353+
/// The default is fairly conservative, to avoid dropping connections. It is
354+
/// recommended that you adjust this to meet the specific needs of your use case
355+
/// in production deployments.
356+
///
338357
/// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
339358
pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
340359
self.handshake_timeout = timeout;
@@ -354,7 +373,7 @@ impl<T> Builder<T> {
354373
listener,
355374
tls: self.tls.clone(),
356375
waiting: FuturesUnordered::new(),
357-
max_handshakes: self.max_handshakes,
376+
accept_batch_size: self.accept_batch_size,
358377
timeout: self.handshake_timeout,
359378
}
360379
}
@@ -382,7 +401,7 @@ impl<LE: std::error::Error, TE: std::error::Error, A> Error<LE, TE, A> {
382401
pub fn builder<T>(tls: T) -> Builder<T> {
383402
Builder {
384403
tls,
385-
max_handshakes: DEFAULT_MAX_HANDSHAKES,
404+
accept_batch_size: DEFAULT_ACCEPT_BATCH_SIZE,
386405
handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
387406
}
388407
}

0 commit comments

Comments
 (0)