@@ -19,6 +19,7 @@ use pin_project_lite::pin_project;
19
19
pub use spawning_handshake:: SpawningHandshakes ;
20
20
use std:: fmt:: Debug ;
21
21
use std:: future:: { poll_fn, Future } ;
22
+ use std:: num:: NonZeroUsize ;
22
23
use std:: pin:: Pin ;
23
24
use std:: task:: { ready, Context , Poll } ;
24
25
use std:: time:: Duration ;
@@ -38,8 +39,8 @@ mod spawning_handshake;
38
39
#[ cfg( feature = "tokio-net" ) ]
39
40
mod net;
40
41
41
- /// Default number of concurrent handshakes
42
- pub const DEFAULT_MAX_HANDSHAKES : usize = 64 ;
42
+ /// Default number of connections to accept in a batch before trying to
43
+ pub const DEFAULT_ACCEPT_BATCH_SIZE : NonZeroUsize = unsafe { NonZeroUsize :: new_unchecked ( 64 ) } ;
43
44
/// Default timeout for the TLS handshake.
44
45
pub const DEFAULT_HANDSHAKE_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
45
46
@@ -112,7 +113,7 @@ pin_project! {
112
113
listener: A ,
113
114
tls: T ,
114
115
waiting: FuturesUnordered <Waiting <A , T >>,
115
- max_handshakes : usize ,
116
+ accept_batch_size : NonZeroUsize ,
116
117
timeout: Duration ,
117
118
}
118
119
}
@@ -121,7 +122,7 @@ pin_project! {
121
122
#[ derive( Clone ) ]
122
123
pub struct Builder < T > {
123
124
tls : T ,
124
- max_handshakes : usize ,
125
+ accept_batch_size : NonZeroUsize ,
125
126
handshake_timeout : Duration ,
126
127
}
127
128
@@ -182,26 +183,36 @@ where
182
183
pub fn poll_accept ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < <Self as Stream >:: Item > {
183
184
let mut this = self . project ( ) ;
184
185
185
- while this. waiting . len ( ) < * this. max_handshakes {
186
- match this. listener . as_mut ( ) . poll_accept ( cx) {
187
- Poll :: Pending => break ,
188
- Poll :: Ready ( Ok ( ( conn, addr) ) ) => {
189
- this. waiting . push ( Waiting {
190
- inner : timeout ( * this. timeout , this. tls . accept ( conn) ) ,
191
- peer_addr : Some ( addr) ,
192
- } ) ;
193
- }
194
- Poll :: Ready ( Err ( e) ) => {
195
- return Poll :: Ready ( Err ( Error :: ListenerError ( e) ) ) ;
186
+ loop {
187
+ let mut empty_listener = false ;
188
+ for _ in 0 ..this. accept_batch_size . get ( ) {
189
+ match this. listener . as_mut ( ) . poll_accept ( cx) {
190
+ Poll :: Pending => {
191
+ empty_listener = true ;
192
+ break ;
193
+ }
194
+ Poll :: Ready ( Ok ( ( conn, addr) ) ) => {
195
+ this. waiting . push ( Waiting {
196
+ inner : timeout ( * this. timeout , this. tls . accept ( conn) ) ,
197
+ peer_addr : Some ( addr) ,
198
+ } ) ;
199
+ }
200
+ Poll :: Ready ( Err ( e) ) => {
201
+ return Poll :: Ready ( Err ( Error :: ListenerError ( e) ) ) ;
202
+ }
196
203
}
197
204
}
198
- }
199
205
200
- match this. waiting . poll_next_unpin ( cx) {
201
- Poll :: Ready ( Some ( result) ) => Poll :: Ready ( result) ,
202
- // If we don't have anything waiting yet,
203
- // then we are still pending,
204
- Poll :: Ready ( None ) | Poll :: Pending => Poll :: Pending ,
206
+ match this. waiting . poll_next_unpin ( cx) {
207
+ Poll :: Ready ( Some ( result) ) => return Poll :: Ready ( result) ,
208
+ // If we don't have anything waiting yet,
209
+ // then we are still pending,
210
+ Poll :: Ready ( None ) | Poll :: Pending => {
211
+ if empty_listener {
212
+ return Poll :: Pending ;
213
+ }
214
+ }
215
+ }
205
216
}
206
217
}
207
218
@@ -318,15 +329,19 @@ where
318
329
}
319
330
320
331
impl < T > Builder < T > {
321
- /// Set the maximum number of concurrent handshakes.
332
+ /// Set the size of batches of incoming connections to accept at once
333
+ ///
334
+ /// When polling for a new connection, the `TlsListener` will first check
335
+ /// for incomming connections on the listener that need to start a TLS handshake.
336
+ /// This specifies the maximum number of connections it will accept before seeing if any
337
+ /// TLS connections are ready.
322
338
///
323
- /// At most `max` handshakes will be concurrently processed. If that limit is
324
- /// reached, the `TlsListener` will stop polling the underlying listener until a
325
- /// handshake completes and the encrypted stream has been returned.
339
+ /// Having a limit for this ensures that ready TLS conections aren't starved if there are a
340
+ /// large number of incoming connections.
326
341
///
327
- /// Defaults to `DEFAULT_MAX_HANDSHAKES `.
328
- pub fn max_handshakes ( & mut self , max : usize ) -> & mut Self {
329
- self . max_handshakes = max ;
342
+ /// Defaults to `DEFAULT_ACCEPT_BATCH_SIZE `.
343
+ pub fn accept_batch_size ( & mut self , size : NonZeroUsize ) -> & mut Self {
344
+ self . accept_batch_size = size ;
330
345
self
331
346
}
332
347
@@ -335,6 +350,10 @@ impl<T> Builder<T> {
335
350
/// If a timeout takes longer than `timeout`, then the handshake will be
336
351
/// aborted and the underlying connection will be dropped.
337
352
///
353
+ /// The default is fairly conservative, to avoid dropping connections. It is
354
+ /// recommended that you adjust this to meet the specific needs of your use case
355
+ /// in production deployments.
356
+ ///
338
357
/// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
339
358
pub fn handshake_timeout ( & mut self , timeout : Duration ) -> & mut Self {
340
359
self . handshake_timeout = timeout;
@@ -354,7 +373,7 @@ impl<T> Builder<T> {
354
373
listener,
355
374
tls : self . tls . clone ( ) ,
356
375
waiting : FuturesUnordered :: new ( ) ,
357
- max_handshakes : self . max_handshakes ,
376
+ accept_batch_size : self . accept_batch_size ,
358
377
timeout : self . handshake_timeout ,
359
378
}
360
379
}
@@ -382,7 +401,7 @@ impl<LE: std::error::Error, TE: std::error::Error, A> Error<LE, TE, A> {
382
401
pub fn builder < T > ( tls : T ) -> Builder < T > {
383
402
Builder {
384
403
tls,
385
- max_handshakes : DEFAULT_MAX_HANDSHAKES ,
404
+ accept_batch_size : DEFAULT_ACCEPT_BATCH_SIZE ,
386
405
handshake_timeout : DEFAULT_HANDSHAKE_TIMEOUT ,
387
406
}
388
407
}
0 commit comments