diff --git a/Cargo.toml b/Cargo.toml index 355b778..6536da2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ tokio-openssl = { version = "0.6.3", optional = true } openssl_impl = { package = "openssl", version = "0.10.32", optional = true } [dev-dependencies] +axum = "0.8.1" hyper = { version = "1.0", features = ["http1", "server"] } hyper-util = { version = "0.1.1", features = ["tokio"] } tokio = { version = "1.0", features = [ diff --git a/examples/axum.rs b/examples/axum.rs new file mode 100644 index 0000000..e00ad8e --- /dev/null +++ b/examples/axum.rs @@ -0,0 +1,68 @@ +use axum::{routing::get, Router}; +use std::{io, net::SocketAddr, time::Duration}; +use tls_listener::TlsListener; +use tokio::net::{TcpListener, TcpStream}; + +mod tls_config; +use tls_config::tls_acceptor; + +/// An example of running an axum server with `TlsListener`. +/// +/// One can also bypass `axum::serve` and use the `Router` with Hyper's `serve_connection` API +/// directly. The main advantages of using `axum::serve` are that +/// - graceful shutdown is made easy with axum's `.with_graceful_shutdown` API, and +/// - the Hyper server is configured by axum itself, allowing options specific to axum to be set +/// (for example, axum currently enables the `CONNECT` protocol in order to support HTTP/2 +/// websockets). +#[tokio::main(flavor = "current_thread")] +async fn main() { + let app = Router::new().route("/", get(|| async { "Hello, World!" })); + + let local_addr = "0.0.0.0:3000".parse::().unwrap(); + let tcp_listener = tokio::net::TcpListener::bind(local_addr).await.unwrap(); + let listener = Listener { + inner: TlsListener::new(tls_acceptor(), tcp_listener), + local_addr, + }; + + axum::serve(listener, app).await.unwrap(); +} + +// We use a wrapper type to bridge axum's `Listener` trait to our `TlsListener` type. +struct Listener { + inner: TlsListener, + local_addr: SocketAddr, +} + +impl axum::serve::Listener for Listener { + type Io = tls_config::Stream; + type Addr = SocketAddr; + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + loop { + // To change the TLS certificate dynamically, you could `select!` on this call with a + // channel receiver, and call `self.inner.replace_acceptor` in the other branch. + match self.inner.accept().await { + Ok(tuple) => break tuple, + Err(tls_listener::Error::ListenerError(e)) if !is_connection_error(&e) => { + // See https://github.com/tokio-rs/axum/blob/da3539cb0e5eed381361b2e688a776da77c52cd6/axum/src/serve/listener.rs#L145-L157 + // for the rationale. + tokio::time::sleep(Duration::from_secs(1)).await + } + Err(_) => continue, + } + } + } + fn local_addr(&self) -> io::Result { + Ok(self.local_addr) + } +} + +// Taken from https://github.com/tokio-rs/axum/blob/da3539cb0e5eed381361b2e688a776da77c52cd6/axum/src/serve/listener.rs#L160-L167 +fn is_connection_error(e: &io::Error) -> bool { + matches!( + e.kind(), + io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::ConnectionReset + ) +} diff --git a/examples/test_examples.py b/examples/test_examples.py index 897cfee..1f9772b 100755 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -133,3 +133,7 @@ def test_http_stream(self): def test_http_plain(self): with run_example("http"): self.http_test() + + def test_axum(self): + with run_example("axum"): + self.http_test() diff --git a/examples/tls_config/mod.rs b/examples/tls_config/mod.rs index 88ea557..3516613 100644 --- a/examples/tls_config/mod.rs +++ b/examples/tls_config/mod.rs @@ -15,6 +15,9 @@ mod config { pub type Acceptor = tokio_rustls::TlsAcceptor; + #[allow(dead_code)] + pub type Stream = tokio_rustls::server::TlsStream; + fn tls_acceptor_impl(key_der: &[u8], cert_der: &[u8]) -> Acceptor { let key = PrivateKeyDer::Pkcs1(key_der.to_owned().into()); let cert = CertificateDer::from(cert_der).into_owned(); @@ -49,6 +52,9 @@ mod config { pub type Acceptor = tokio_native_tls::TlsAcceptor; + #[allow(dead_code)] + pub type Stream = tokio_native_tls::TlsStream; + fn tls_acceptor_impl(pfx: &[u8]) -> Acceptor { let identity = Identity::from_pkcs12(pfx, "").unwrap(); TlsAcceptor::builder(identity).build().unwrap().into() @@ -73,6 +79,9 @@ mod config { pub type Acceptor = openssl_impl::ssl::SslContext; + #[allow(dead_code)] + pub type Stream = tokio_openssl::SslStream; + fn tls_acceptor_impl>(cert_file: P, key_file: P) -> Acceptor { let mut builder = SslContext::builder(SslMethod::tls_server()).unwrap(); builder