Skip to content

Commit 55f643c

Browse files
mfelscheJason Mobarak
and
Jason Mobarak
authored
feat: port to rustls 0.20
* attempt to port to rustls 0.20 * clippy * format * Fix test certificates (expired) and add a script to regenerate them. * Fix hanging and failing unit-tests UnexpectedEof errors are bubbled up from rustls. Tests needed to changed slightly, but are en par with tokio/tls. * Fix integration tests * Fix client and server examples * Update async-std version used for testing --------- Co-authored-by: Jason Mobarak <[email protected]>
1 parent a8ca3ca commit 55f643c

21 files changed

+442
-318
lines changed

Diff for: Cargo.toml

+11-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
[package]
22
name = "async-tls"
33
version = "0.11.0"
4-
authors = ["The async-rs developers", "Florian Gilcher <[email protected]>", "dignifiedquire <[email protected]>", "quininer kel <[email protected]>"]
4+
authors = [
5+
"The async-rs developers",
6+
"Florian Gilcher <[email protected]>",
7+
"dignifiedquire <[email protected]>",
8+
"quininer kel <[email protected]>",
9+
]
510
license = "MIT/Apache-2.0"
611
repository = "https://github.com/async-std/async-tls"
712
homepage = "https://github.com/async-std/async-tls"
@@ -18,9 +23,10 @@ appveyor = { repository = "async-std/async-tls" }
1823
[dependencies]
1924
futures-io = "0.3.5"
2025
futures-core = "0.3.5"
21-
rustls = "0.19.0"
22-
webpki = { version = "0.21.3", optional = true }
23-
webpki-roots = { version = "0.21.0", optional = true }
26+
rustls = "0.20.6"
27+
rustls-pemfile = "1.0"
28+
webpki = { version = "0.22.0", optional = true }
29+
webpki-roots = { version = "0.22.3", optional = true }
2430

2531
[features]
2632
default = ["client", "server"]
@@ -32,7 +38,7 @@ server = []
3238
lazy_static = "1"
3339
futures-executor = "0.3.5"
3440
futures-util = { version = "0.3.5", features = ["io"] }
35-
async-std = { version = "1.0", features = ["unstable"] }
41+
async-std = { version = "1.11", features = ["unstable"] }
3642

3743
[[test]]
3844
name = "test"

Diff for: examples/client/Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ edition = "2018"
66

77
[dependencies]
88
structopt = "0.3.9"
9-
rustls = "0.19.0"
10-
async-std = "1.5.0"
9+
rustls = "0.20.6"
10+
rustls-pemfile = "1.0"
11+
async-std = "1.11.0"
1112
async-tls = { path = "../.." }

Diff for: examples/client/src/main.rs

+11-8
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ use async_std::task;
55
use async_tls::TlsConnector;
66

77
use rustls::ClientConfig;
8+
use rustls_pemfile::certs;
89

9-
use std::io::Cursor;
10+
use std::io::{BufReader, Cursor};
1011
use std::net::ToSocketAddrs;
1112
use std::path::{Path, PathBuf};
1213
use std::sync::Arc;
@@ -81,12 +82,14 @@ fn main() -> io::Result<()> {
8182
}
8283

8384
async fn connector_for_ca_file(cafile: &Path) -> io::Result<TlsConnector> {
84-
let mut config = ClientConfig::new();
85-
let file = async_std::fs::read(cafile).await?;
86-
let mut pem = Cursor::new(file);
87-
config
88-
.root_store
89-
.add_pem_file(&mut pem)
90-
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?;
85+
let mut root_store = rustls::RootCertStore::empty();
86+
let ca_bytes = async_std::fs::read(cafile).await?;
87+
let cert = certs(&mut BufReader::new(Cursor::new(ca_bytes))).unwrap();
88+
debug_assert_eq!((1, 0), root_store.add_parsable_certificates(&cert));
89+
90+
let config = ClientConfig::builder()
91+
.with_safe_defaults()
92+
.with_root_certificates(root_store)
93+
.with_no_client_auth();
9194
Ok(TlsConnector::from(Arc::new(config)))
9295
}

Diff for: examples/server/Cargo.toml

+5-4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ authors = ["The async-rs developers", "quininer <[email protected]>"]
55
edition = "2018"
66

77
[dependencies]
8-
structopt = "0.3.9"
9-
async-std = "1.5.0"
8+
async-std = "1.11.0"
109
async-tls = { path = "../.." }
11-
rustls = "0.19.0"
12-
webpki = "0.21.3"
10+
futures-lite = "1.12.0"
11+
rustls = "0.20.6"
12+
rustls-pemfile = "1.0"
13+
structopt = "0.3.9"

Diff for: examples/server/src/main.rs

+25-13
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use async_std::io;
22
use async_std::net::{TcpListener, TcpStream};
3-
use async_std::prelude::*;
3+
use async_std::stream::StreamExt;
44
use async_std::task;
55
use async_tls::TlsAcceptor;
6-
use rustls::internal::pemfile::{certs, rsa_private_keys};
7-
use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
6+
use futures_lite::io::AsyncWriteExt;
7+
use rustls::{Certificate, PrivateKey, ServerConfig};
8+
use rustls_pemfile::{certs, read_one, Item};
89

910
use std::fs::File;
1011
use std::io::BufReader;
@@ -28,14 +29,23 @@ struct Options {
2829

2930
/// Load the passed certificates file
3031
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
31-
certs(&mut BufReader::new(File::open(path)?))
32-
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
32+
Ok(certs(&mut BufReader::new(File::open(path)?))
33+
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?
34+
.into_iter()
35+
.map(Certificate)
36+
.collect())
3337
}
3438

3539
/// Load the passed keys file
36-
fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
37-
rsa_private_keys(&mut BufReader::new(File::open(path)?))
38-
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
40+
fn load_key(path: &Path) -> io::Result<PrivateKey> {
41+
match read_one(&mut BufReader::new(File::open(path)?)) {
42+
Ok(Some(Item::RSAKey(data) | Item::PKCS8Key(data))) => Ok(PrivateKey(data)),
43+
Ok(_) => Err(io::Error::new(
44+
io::ErrorKind::InvalidInput,
45+
format!("invalid key in {}", path.display()),
46+
)),
47+
Err(e) => Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
48+
}
3949
}
4050

4151
/// Configure the server using rusttls
@@ -44,13 +54,15 @@ fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
4454
/// A TLS server needs a certificate and a fitting private key
4555
fn load_config(options: &Options) -> io::Result<ServerConfig> {
4656
let certs = load_certs(&options.cert)?;
47-
let mut keys = load_keys(&options.key)?;
57+
debug_assert_eq!(1, certs.len());
58+
let key = load_key(&options.key)?;
4859

4960
// we don't use client authentication
50-
let mut config = ServerConfig::new(NoClientAuth::new());
51-
config
61+
let config = ServerConfig::builder()
62+
.with_safe_defaults()
63+
.with_no_client_auth()
5264
// set this server to use one cert together with the loaded private key
53-
.set_single_cert(certs, keys.remove(0))
65+
.with_single_cert(certs, key)
5466
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
5567

5668
Ok(config)
@@ -78,7 +90,7 @@ async fn handle_connection(acceptor: &TlsAcceptor, tcp_stream: &mut TcpStream) -
7890
)
7991
.await?;
8092

81-
tls_stream.flush().await?;
93+
tls_stream.close().await?;
8294

8395
Ok(())
8496
}

Diff for: src/acceptor.rs

+12-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::common::tls_state::TlsState;
22
use crate::server;
33

44
use futures_io::{AsyncRead, AsyncWrite};
5-
use rustls::{ServerConfig, ServerSession};
5+
use rustls::{ServerConfig, ServerConnection};
66
use std::future::Future;
77
use std::io;
88
use std::pin::Pin;
@@ -39,17 +39,23 @@ impl TlsAcceptor {
3939
self.accept_with(stream, |_| ())
4040
}
4141

42-
// Currently private, as exposing ServerSessions exposes rusttls
42+
// Currently private, as exposing ServerConnections exposes rusttls
4343
fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
4444
where
4545
IO: AsyncRead + AsyncWrite + Unpin,
46-
F: FnOnce(&mut ServerSession),
46+
F: FnOnce(&mut ServerConnection),
4747
{
48-
let mut session = ServerSession::new(&self.inner);
49-
f(&mut session);
48+
let mut conn = match ServerConnection::new(self.inner.clone()) {
49+
Ok(conn) => conn,
50+
Err(_) => {
51+
return Accept(server::MidHandshake::End);
52+
}
53+
};
54+
55+
f(&mut conn);
5056

5157
Accept(server::MidHandshake::Handshaking(server::TlsStream {
52-
session,
58+
conn,
5359
io: stream,
5460
state: TlsState::Stream,
5561
}))

Diff for: src/client.rs

+10-9
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,18 @@ use crate::common::tls_state::TlsState;
44
use crate::rusttls::stream::Stream;
55
use futures_core::ready;
66
use futures_io::{AsyncRead, AsyncWrite};
7-
use rustls::ClientSession;
7+
use rustls::ClientConnection;
88
use std::future::Future;
99
use std::pin::Pin;
1010
use std::task::{Context, Poll};
1111
use std::{io, mem};
1212

13-
use rustls::Session;
14-
1513
/// The client end of a TLS connection. Can be used like any other bidirectional IO stream.
1614
/// Wraps the underlying TCP stream.
1715
#[derive(Debug)]
1816
pub struct TlsStream<IO> {
1917
pub(crate) io: IO,
20-
pub(crate) session: ClientSession,
18+
pub(crate) session: ClientConnection,
2119
pub(crate) state: TlsState,
2220

2321
#[cfg(feature = "early-data")]
@@ -58,11 +56,11 @@ where
5856
let (io, session) = (&mut stream.io, &mut stream.session);
5957
let mut stream = Stream::new(io, session).set_eof(eof);
6058

61-
if stream.session.is_handshaking() {
59+
if stream.conn.is_handshaking() {
6260
ready!(stream.complete_io(cx))?;
6361
}
6462

65-
if stream.session.wants_write() {
63+
if stream.conn.wants_write() {
6664
ready!(stream.complete_io(cx))?;
6765
}
6866
}
@@ -90,17 +88,20 @@ where
9088
TlsState::EarlyData => {
9189
let this = self.get_mut();
9290

91+
let is_handshaking = this.session.is_handshaking();
92+
let is_early_data_accepted = this.session.is_early_data_accepted();
93+
9394
let mut stream =
9495
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
9596
let (pos, data) = &mut this.early_data;
9697

9798
// complete handshake
98-
if stream.session.is_handshaking() {
99+
if is_handshaking {
99100
ready!(stream.complete_io(cx))?;
100101
}
101102

102103
// write early data (fallback)
103-
if !stream.session.is_early_data_accepted() {
104+
if !is_early_data_accepted {
104105
while *pos < data.len() {
105106
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
106107
*pos += len;
@@ -127,7 +128,7 @@ where
127128
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
128129
this.state.shutdown_read();
129130
if this.state.writeable() {
130-
stream.session.send_close_notify();
131+
stream.conn.send_close_notify();
131132
this.state.shutdown_write();
132133
}
133134
Poll::Ready(Ok(0))

Diff for: src/connector.rs

+27-10
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ use crate::common::tls_state::TlsState;
33
use crate::client;
44

55
use futures_io::{AsyncRead, AsyncWrite};
6-
use rustls::{ClientConfig, ClientSession};
6+
use rustls::{ClientConfig, ClientConnection, OwnedTrustAnchor, RootCertStore, ServerName};
7+
use std::convert::TryFrom;
78
use std::future::Future;
89
use std::io;
910
use std::pin::Pin;
1011
use std::sync::Arc;
1112
use std::task::{Context, Poll};
12-
use webpki::DNSNameRef;
1313

1414
/// The TLS connecting part. The acceptor drives
1515
/// the client side of the TLS handshake process. It works
@@ -64,10 +64,18 @@ impl From<ClientConfig> for TlsConnector {
6464

6565
impl Default for TlsConnector {
6666
fn default() -> Self {
67-
let mut config = ClientConfig::new();
68-
config
69-
.root_store
70-
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
67+
let mut root_certs = RootCertStore::empty();
68+
root_certs.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
69+
OwnedTrustAnchor::from_subject_spki_name_constraints(
70+
ta.subject,
71+
ta.spki,
72+
ta.name_constraints,
73+
)
74+
}));
75+
let config = ClientConfig::builder()
76+
.with_safe_defaults()
77+
.with_root_certificates(root_certs)
78+
.with_no_client_auth();
7179
Arc::new(config).into()
7280
}
7381
}
@@ -102,14 +110,14 @@ impl TlsConnector {
102110
self.connect_with(domain, stream, |_| ())
103111
}
104112

105-
// NOTE: Currently private, exposing ClientSession exposes rusttls
113+
// NOTE: Currently private, exposing ClientConnection exposes rusttls
106114
// Early data should be exposed differently
107115
fn connect_with<'a, IO, F>(&self, domain: impl AsRef<str>, stream: IO, f: F) -> Connect<IO>
108116
where
109117
IO: AsyncRead + AsyncWrite + Unpin,
110-
F: FnOnce(&mut ClientSession),
118+
F: FnOnce(&mut ClientConnection),
111119
{
112-
let domain = match DNSNameRef::try_from_ascii_str(domain.as_ref()) {
120+
let domain = match ServerName::try_from(domain.as_ref()) {
113121
Ok(domain) => domain,
114122
Err(_) => {
115123
return Connect(ConnectInner::Error(Some(io::Error::new(
@@ -119,7 +127,16 @@ impl TlsConnector {
119127
}
120128
};
121129

122-
let mut session = ClientSession::new(&self.inner, domain);
130+
let mut session = match ClientConnection::new(self.inner.clone(), domain) {
131+
Ok(session) => session,
132+
Err(_) => {
133+
return Connect(ConnectInner::Error(Some(io::Error::new(
134+
io::ErrorKind::Other,
135+
"invalid connection",
136+
))))
137+
}
138+
};
139+
123140
f(&mut session);
124141

125142
#[cfg(not(feature = "early-data"))]

0 commit comments

Comments
 (0)