@@ -816,8 +816,11 @@ mod tests {
816
816
use http_body:: Body ;
817
817
use http_body_util:: { BodyExt , Empty , Full } ;
818
818
use hyper:: { body, body:: Bytes , client, service:: service_fn} ;
819
- use std:: { convert:: Infallible , error:: Error as StdError , net:: SocketAddr } ;
820
- use tokio:: net:: { TcpListener , TcpStream } ;
819
+ use std:: { convert:: Infallible , error:: Error as StdError , net:: SocketAddr , time:: Duration } ;
820
+ use tokio:: {
821
+ net:: { TcpListener , TcpStream } ,
822
+ pin,
823
+ } ;
821
824
822
825
const BODY : & [ u8 ] = b"Hello, world!" ;
823
826
@@ -871,6 +874,40 @@ mod tests {
871
874
assert_eq ! ( body, BODY ) ;
872
875
}
873
876
877
+ #[ cfg( not( miri) ) ]
878
+ #[ tokio:: test]
879
+ async fn graceful_shutdown ( ) {
880
+ let listener = TcpListener :: bind ( SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , 0 ) ) )
881
+ . await
882
+ . unwrap ( ) ;
883
+
884
+ let listener_addr = listener. local_addr ( ) . unwrap ( ) ;
885
+
886
+ // Spawn the task in background so that we can connect there
887
+ let listen_task = tokio:: spawn ( async move { listener. accept ( ) . await . unwrap ( ) } ) ;
888
+ // Only connect a stream, do not send headers or anything
889
+ let _stream = TcpStream :: connect ( listener_addr) . await . unwrap ( ) ;
890
+
891
+ let ( stream, _) = listen_task. await . unwrap ( ) ;
892
+ let stream = TokioIo :: new ( stream) ;
893
+ let builder = auto:: Builder :: new ( TokioExecutor :: new ( ) ) ;
894
+ let connection = builder. serve_connection ( stream, service_fn ( hello) ) ;
895
+
896
+ pin ! ( connection) ;
897
+
898
+ connection. as_mut ( ) . graceful_shutdown ( ) ;
899
+
900
+ let connection_error = tokio:: time:: timeout ( Duration :: from_millis ( 200 ) , connection)
901
+ . await
902
+ . expect ( "Connection should have finished in a timely manner after graceful shutdown." )
903
+ . expect_err ( "Connection should have been interrupted." ) ;
904
+
905
+ let connection_error = connection_error
906
+ . downcast_ref :: < std:: io:: Error > ( )
907
+ . expect ( "The error should have been `std::io::Error`." ) ;
908
+ assert_eq ! ( connection_error. kind( ) , std:: io:: ErrorKind :: Interrupted ) ;
909
+ }
910
+
874
911
async fn connect_h1 < B > ( addr : SocketAddr ) -> client:: conn:: http1:: SendRequest < B >
875
912
where
876
913
B : Body + Send + ' static ,
0 commit comments