Skip to content

Commit d5a7655

Browse files
authored
Merge pull request from GHSA-2qph-qpvm-2qf7
Continue accepting incoming connections if no TLS connection is ready.
2 parents 6c57dea + d1769ec commit d5a7655

File tree

3 files changed

+52
-32
lines changed

3 files changed

+52
-32
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "tls-listener"
33
description = "wrap incoming Stream of connections in TLS"
4-
version = "0.9.1"
4+
version = "0.10.0"
55
authors = ["Thayne McCombs <[email protected]>"]
66
repository = "https://github.com/tmccombs/tls-listener"
77
edition = "2018"

examples/http-change-certificate.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use hyper::service::service_fn;
33
use hyper::{body::Body, Request, Response};
44
use hyper_util::rt::tokio::TokioIo;
55
use std::convert::Infallible;
6+
use std::num::NonZeroUsize;
67
use std::sync::atomic::{AtomicU64, Ordering};
78
use std::sync::Arc;
89
use tokio::net::TcpListener;
@@ -22,7 +23,7 @@ async fn main() {
2223
let counter = Arc::new(AtomicU64::new(0));
2324

2425
let mut listener = tls_listener::builder(tls_acceptor())
25-
.max_handshakes(10)
26+
.accept_batch_size(NonZeroUsize::new(10).unwrap())
2627
.listen(TcpListener::bind(addr).await.expect("Failed to bind port"));
2728

2829
let (tx, mut rx) = mpsc::channel::<Acceptor>(1);

src/lib.rs

+49-30
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use pin_project_lite::pin_project;
1919
pub use spawning_handshake::SpawningHandshakes;
2020
use std::fmt::Debug;
2121
use std::future::{poll_fn, Future};
22+
use std::num::NonZeroUsize;
2223
use std::pin::Pin;
2324
use std::task::{ready, Context, Poll};
2425
use std::time::Duration;
@@ -38,8 +39,8 @@ mod spawning_handshake;
3839
#[cfg(feature = "tokio-net")]
3940
mod net;
4041

41-
/// Default number of concurrent handshakes
42-
pub const DEFAULT_MAX_HANDSHAKES: usize = 64;
42+
/// Default number of connections to accept in a batch before trying to
43+
pub const DEFAULT_ACCEPT_BATCH_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(64) };
4344
/// Default timeout for the TLS handshake.
4445
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
4546

@@ -112,7 +113,7 @@ pin_project! {
112113
listener: A,
113114
tls: T,
114115
waiting: FuturesUnordered<Waiting<A, T>>,
115-
max_handshakes: usize,
116+
accept_batch_size: NonZeroUsize,
116117
timeout: Duration,
117118
}
118119
}
@@ -121,7 +122,7 @@ pin_project! {
121122
#[derive(Clone)]
122123
pub struct Builder<T> {
123124
tls: T,
124-
max_handshakes: usize,
125+
accept_batch_size: NonZeroUsize,
125126
handshake_timeout: Duration,
126127
}
127128

@@ -182,26 +183,36 @@ where
182183
pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
183184
let mut this = self.project();
184185

185-
while this.waiting.len() < *this.max_handshakes {
186-
match this.listener.as_mut().poll_accept(cx) {
187-
Poll::Pending => break,
188-
Poll::Ready(Ok((conn, addr))) => {
189-
this.waiting.push(Waiting {
190-
inner: timeout(*this.timeout, this.tls.accept(conn)),
191-
peer_addr: Some(addr),
192-
});
193-
}
194-
Poll::Ready(Err(e)) => {
195-
return Poll::Ready(Err(Error::ListenerError(e)));
186+
loop {
187+
let mut empty_listener = false;
188+
for _ in 0..this.accept_batch_size.get() {
189+
match this.listener.as_mut().poll_accept(cx) {
190+
Poll::Pending => {
191+
empty_listener = true;
192+
break;
193+
}
194+
Poll::Ready(Ok((conn, addr))) => {
195+
this.waiting.push(Waiting {
196+
inner: timeout(*this.timeout, this.tls.accept(conn)),
197+
peer_addr: Some(addr),
198+
});
199+
}
200+
Poll::Ready(Err(e)) => {
201+
return Poll::Ready(Err(Error::ListenerError(e)));
202+
}
196203
}
197204
}
198-
}
199205

200-
match this.waiting.poll_next_unpin(cx) {
201-
Poll::Ready(Some(result)) => Poll::Ready(result),
202-
// If we don't have anything waiting yet,
203-
// then we are still pending,
204-
Poll::Ready(None) | Poll::Pending => Poll::Pending,
206+
match this.waiting.poll_next_unpin(cx) {
207+
Poll::Ready(Some(result)) => return Poll::Ready(result),
208+
// If we don't have anything waiting yet,
209+
// then we are still pending,
210+
Poll::Ready(None) | Poll::Pending => {
211+
if empty_listener {
212+
return Poll::Pending;
213+
}
214+
}
215+
}
205216
}
206217
}
207218

@@ -318,15 +329,19 @@ where
318329
}
319330

320331
impl<T> Builder<T> {
321-
/// Set the maximum number of concurrent handshakes.
332+
/// Set the size of batches of incoming connections to accept at once
333+
///
334+
/// When polling for a new connection, the `TlsListener` will first check
335+
/// for incomming connections on the listener that need to start a TLS handshake.
336+
/// This specifies the maximum number of connections it will accept before seeing if any
337+
/// TLS connections are ready.
322338
///
323-
/// At most `max` handshakes will be concurrently processed. If that limit is
324-
/// reached, the `TlsListener` will stop polling the underlying listener until a
325-
/// handshake completes and the encrypted stream has been returned.
339+
/// Having a limit for this ensures that ready TLS conections aren't starved if there are a
340+
/// large number of incoming connections.
326341
///
327-
/// Defaults to `DEFAULT_MAX_HANDSHAKES`.
328-
pub fn max_handshakes(&mut self, max: usize) -> &mut Self {
329-
self.max_handshakes = max;
342+
/// Defaults to `DEFAULT_ACCEPT_BATCH_SIZE`.
343+
pub fn accept_batch_size(&mut self, size: NonZeroUsize) -> &mut Self {
344+
self.accept_batch_size = size;
330345
self
331346
}
332347

@@ -335,6 +350,10 @@ impl<T> Builder<T> {
335350
/// If a timeout takes longer than `timeout`, then the handshake will be
336351
/// aborted and the underlying connection will be dropped.
337352
///
353+
/// The default is fairly conservative, to avoid dropping connections. It is
354+
/// recommended that you adjust this to meet the specific needs of your use case
355+
/// in production deployments.
356+
///
338357
/// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
339358
pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
340359
self.handshake_timeout = timeout;
@@ -354,7 +373,7 @@ impl<T> Builder<T> {
354373
listener,
355374
tls: self.tls.clone(),
356375
waiting: FuturesUnordered::new(),
357-
max_handshakes: self.max_handshakes,
376+
accept_batch_size: self.accept_batch_size,
358377
timeout: self.handshake_timeout,
359378
}
360379
}
@@ -382,7 +401,7 @@ impl<LE: std::error::Error, TE: std::error::Error, A> Error<LE, TE, A> {
382401
pub fn builder<T>(tls: T) -> Builder<T> {
383402
Builder {
384403
tls,
385-
max_handshakes: DEFAULT_MAX_HANDSHAKES,
404+
accept_batch_size: DEFAULT_ACCEPT_BATCH_SIZE,
386405
handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
387406
}
388407
}

0 commit comments

Comments
 (0)