Skip to content

Commit 75f9d3e

Browse files
committed
Support body health detection
1 parent 9ed175d commit 75f9d3e

File tree

4 files changed

+202
-38
lines changed

4 files changed

+202
-38
lines changed

Cargo.toml

+23-23
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ repository = "https://github.com/hyperium/hyper"
99
license = "MIT"
1010
authors = ["Sean McArthur <[email protected]>"]
1111
keywords = ["http", "hyper", "hyperium"]
12-
categories = ["network-programming", "web-programming::http-client", "web-programming::http-server"]
12+
categories = [
13+
"network-programming",
14+
"web-programming::http-client",
15+
"web-programming::http-server",
16+
]
1317
edition = "2018"
1418
rust-version = "1.56"
1519

16-
include = [
17-
"Cargo.toml",
18-
"LICENSE",
19-
"src/**/*",
20-
]
20+
include = ["Cargo.toml", "LICENSE", "src/**/*"]
2121

2222
[dependencies]
2323
bytes = "1"
@@ -42,7 +42,9 @@ libc = { version = "0.2", optional = true }
4242
socket2 = { version = "0.4", optional = true }
4343

4444
[dev-dependencies]
45-
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
45+
futures-util = { version = "0.3", default-features = false, features = [
46+
"alloc",
47+
] }
4648
http-body-util = "=0.1.0-rc.2"
4749
matches = "0.1"
4850
num_cpus = "1.0"
@@ -51,16 +53,16 @@ spmc = "0.3"
5153
serde = { version = "1.0", features = ["derive"] }
5254
serde_json = "1.0"
5355
tokio = { version = "1", features = [
54-
"fs",
55-
"macros",
56-
"net",
57-
"io-std",
58-
"io-util",
59-
"rt",
60-
"rt-multi-thread", # so examples can use #[tokio::main]
61-
"sync",
62-
"time",
63-
"test-util",
56+
"fs",
57+
"macros",
58+
"net",
59+
"io-std",
60+
"io-util",
61+
"rt",
62+
"rt-multi-thread", # so examples can use #[tokio::main]
63+
"sync",
64+
"time",
65+
"test-util",
6466
] }
6567
tokio-test = "0.4"
6668
tokio-util = { version = "0.7", features = ["codec"] }
@@ -71,12 +73,7 @@ url = "2.2"
7173
default = []
7274

7375
# Easily turn it all on
74-
full = [
75-
"client",
76-
"http1",
77-
"http2",
78-
"server",
79-
]
76+
full = ["client", "http1", "http2", "server"]
8077

8178
# HTTP versions
8279
http1 = []
@@ -219,3 +216,6 @@ required-features = ["full"]
219216
name = "server"
220217
path = "tests/server.rs"
221218
required-features = ["full"]
219+
220+
[patch.crates-io]
221+
http-body = { git = "https://github.com/sfackler/http-body", branch = "body-poll-alive" }

src/proto/h1/dispatch.rs

+17-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ pub(crate) trait Dispatch {
2828
self: Pin<&mut Self>,
2929
cx: &mut task::Context<'_>,
3030
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>;
31-
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>) -> crate::Result<()>;
31+
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>)
32+
-> crate::Result<()>;
3233
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>>;
3334
fn should_poll(&self) -> bool;
3435
}
@@ -249,7 +250,8 @@ where
249250
let body = match body_len {
250251
DecodedLength::ZERO => IncomingBody::empty(),
251252
other => {
252-
let (tx, rx) = IncomingBody::new_channel(other, wants.contains(Wants::EXPECT));
253+
let (tx, rx) =
254+
IncomingBody::new_channel(other, wants.contains(Wants::EXPECT));
253255
self.body_tx = Some(tx);
254256
rx
255257
}
@@ -317,7 +319,19 @@ where
317319
return Poll::Ready(Ok(()));
318320
}
319321
} else if !self.conn.can_buffer_body() {
320-
ready!(self.poll_flush(cx))?;
322+
if self.poll_flush(cx)?.is_pending() {
323+
// If we're not able to make progress, check the body health
324+
if let (Some(body), clear_body) =
325+
OptGuard::new(self.body_rx.as_mut()).guard_mut()
326+
{
327+
body.poll_healthy(cx).map_err(|e| {
328+
*clear_body = true;
329+
crate::Error::new_user_body(e)
330+
})?;
331+
}
332+
333+
return Poll::Pending;
334+
}
321335
} else {
322336
// A new scope is needed :(
323337
if let (Some(mut body), clear_body) =

src/proto/h2/mod.rs

+17-12
Original file line numberDiff line numberDiff line change
@@ -126,20 +126,29 @@ where
126126

127127
if me.body_tx.capacity() == 0 {
128128
loop {
129-
match ready!(me.body_tx.poll_capacity(cx)) {
130-
Some(Ok(0)) => {}
131-
Some(Ok(_)) => break,
132-
Some(Err(e)) => {
129+
match me.body_tx.poll_capacity(cx) {
130+
Poll::Ready(Some(Ok(0))) => {}
131+
Poll::Ready(Some(Ok(_))) => break,
132+
Poll::Ready(Some(Err(e))) => {
133133
return Poll::Ready(Err(crate::Error::new_body_write(e)))
134134
}
135-
None => {
135+
Poll::Ready(None) => {
136136
// None means the stream is no longer in a
137137
// streaming state, we either finished it
138138
// somehow, or the remote reset us.
139139
return Poll::Ready(Err(crate::Error::new_body_write(
140140
"send stream capacity unexpectedly closed",
141141
)));
142142
}
143+
Poll::Pending => {
144+
// If we're not able to make progress, check if the body is healthy
145+
me.stream
146+
.as_mut()
147+
.poll_healthy(cx)
148+
.map_err(|e| me.body_tx.on_user_err(e))?;
149+
150+
return Poll::Pending;
151+
}
143152
}
144153
}
145154
} else if let Poll::Ready(reason) = me
@@ -148,9 +157,7 @@ where
148157
.map_err(crate::Error::new_body_write)?
149158
{
150159
debug!("stream received RST_STREAM: {:?}", reason);
151-
return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(
152-
reason,
153-
))));
160+
return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(reason))));
154161
}
155162

156163
match ready!(me.stream.as_mut().poll_frame(cx)) {
@@ -365,14 +372,12 @@ where
365372
cx: &mut Context<'_>,
366373
) -> Poll<Result<(), io::Error>> {
367374
if self.send_stream.write(&[], true).is_ok() {
368-
return Poll::Ready(Ok(()))
375+
return Poll::Ready(Ok(()));
369376
}
370377

371378
Poll::Ready(Err(h2_to_io_error(
372379
match ready!(self.send_stream.poll_reset(cx)) {
373-
Ok(Reason::NO_ERROR) => {
374-
return Poll::Ready(Ok(()))
375-
}
380+
Ok(Reason::NO_ERROR) => return Poll::Ready(Ok(())),
376381
Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
377382
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
378383
}

tests/server.rs

+145
Original file line numberDiff line numberDiff line change
@@ -1737,6 +1737,151 @@ async fn http_connect_new() {
17371737
assert_eq!(s(&vec), "bar=foo");
17381738
}
17391739

1740+
struct UnhealthyBody {
1741+
rx: oneshot::Receiver<()>,
1742+
tx: Option<oneshot::Sender<()>>,
1743+
}
1744+
1745+
impl Body for UnhealthyBody {
1746+
type Data = Bytes;
1747+
1748+
type Error = &'static str;
1749+
1750+
fn poll_frame(
1751+
self: Pin<&mut Self>,
1752+
_cx: &mut Context<'_>,
1753+
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1754+
Poll::Ready(Some(Ok(http_body::Frame::data(Bytes::from_static(
1755+
&[0; 1024],
1756+
)))))
1757+
}
1758+
1759+
fn poll_healthy(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), Self::Error> {
1760+
if Pin::new(&mut self.rx).poll(cx).is_pending() {
1761+
return Ok(());
1762+
}
1763+
1764+
let _ = self.tx.take().unwrap().send(());
1765+
Err("blammo")
1766+
}
1767+
}
1768+
1769+
#[tokio::test]
1770+
async fn h1_unhealthy_body() {
1771+
let (listener, addr) = setup_tcp_listener();
1772+
let (unhealthy_tx, unhealthy_rx) = oneshot::channel();
1773+
let (read_body_tx, read_body_rx) = oneshot::channel();
1774+
1775+
let client = tokio::spawn(async move {
1776+
let mut tcp = connect_async(addr).await;
1777+
tcp.write_all(
1778+
b"\
1779+
GET / HTTP/1.1\r\n\
1780+
\r\n\
1781+
Host: localhost\r\n\
1782+
\r\n
1783+
",
1784+
)
1785+
.await
1786+
.expect("write 1");
1787+
1788+
let mut buf = [0; 1024];
1789+
loop {
1790+
let nread = tcp.read(&mut buf).await.expect("read 1");
1791+
if buf[..nread].contains(&0) {
1792+
break;
1793+
}
1794+
}
1795+
1796+
read_body_tx.send(()).unwrap();
1797+
unhealthy_rx.await.expect("rx");
1798+
1799+
while tcp.read(&mut buf).await.expect("read") > 0 {}
1800+
});
1801+
1802+
let mut read_body_rx = Some(read_body_rx);
1803+
let mut unhealthy_tx = Some(unhealthy_tx);
1804+
let svc = service_fn(move |_: Request<IncomingBody>| {
1805+
future::ok::<_, &'static str>(
1806+
Response::builder()
1807+
.status(200)
1808+
.body(UnhealthyBody {
1809+
rx: read_body_rx.take().unwrap(),
1810+
tx: unhealthy_tx.take(),
1811+
})
1812+
.unwrap(),
1813+
)
1814+
});
1815+
1816+
let (socket, _) = listener.accept().await.unwrap();
1817+
let err = http1::Builder::new()
1818+
.serve_connection(socket, svc)
1819+
.await
1820+
.err()
1821+
.unwrap();
1822+
assert!(err.to_string().contains("blammo"));
1823+
1824+
client.await.unwrap();
1825+
}
1826+
1827+
#[tokio::test]
1828+
async fn h2_unhealthy_body() {
1829+
let (listener, addr) = setup_tcp_listener();
1830+
let (unhealthy_tx, unhealthy_rx) = oneshot::channel();
1831+
let (read_body_tx, read_body_rx) = oneshot::channel();
1832+
1833+
let client = tokio::spawn(async move {
1834+
let tcp = connect_async(addr).await;
1835+
let (h2, connection) = h2::client::handshake(tcp).await.unwrap();
1836+
tokio::spawn(async move {
1837+
connection.await.unwrap();
1838+
});
1839+
let mut h2 = h2.ready().await.unwrap();
1840+
1841+
let request = Request::get("/").body(()).unwrap();
1842+
let (response, _) = h2.send_request(request, true).unwrap();
1843+
1844+
let mut body = response.await.unwrap().into_body();
1845+
1846+
let bytes = body.data().await.unwrap().unwrap();
1847+
let _ = body.flow_control().release_capacity(bytes.len());
1848+
1849+
read_body_tx.send(()).unwrap();
1850+
unhealthy_rx.await.unwrap();
1851+
1852+
loop {
1853+
let bytes = match body.data().await.transpose() {
1854+
Ok(Some(bytes)) => bytes,
1855+
Ok(None) => panic!(),
1856+
Err(_) => break,
1857+
};
1858+
let _ = body.flow_control().release_capacity(bytes.len());
1859+
}
1860+
});
1861+
1862+
let mut read_body_rx = Some(read_body_rx);
1863+
let mut unhealthy_tx = Some(unhealthy_tx);
1864+
let svc = service_fn(move |_: Request<IncomingBody>| {
1865+
future::ok::<_, &'static str>(
1866+
Response::builder()
1867+
.status(200)
1868+
.body(UnhealthyBody {
1869+
rx: read_body_rx.take().unwrap(),
1870+
tx: unhealthy_tx.take(),
1871+
})
1872+
.unwrap(),
1873+
)
1874+
});
1875+
1876+
let (socket, _) = listener.accept().await.unwrap();
1877+
http2::Builder::new(TokioExecutor)
1878+
.serve_connection(socket, svc)
1879+
.await
1880+
.unwrap();
1881+
1882+
client.await.unwrap();
1883+
}
1884+
17401885
#[tokio::test]
17411886
async fn h2_connect() {
17421887
let (listener, addr) = setup_tcp_listener();

0 commit comments

Comments
 (0)