Skip to content

Commit efad362

Browse files
authored
feat: implemented cancel frame handling (#49)
Client will send a CANCEL frame for a dropped stream when the next payload is received for this stream.
1 parent 0e4530b commit efad362

File tree

5 files changed

+399
-5
lines changed

5 files changed

+399
-5
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ members = [
99
"examples",
1010
"rsocket-test",
1111
]
12+
13+
[replace]
14+
"rsocket_rust:0.7.1" = { path = "../rsocket-rust/rsocket" }
15+
"rsocket_rust_transport_tcp:0.7.1" = { path = "../rsocket-rust/rsocket-transport-tcp" }

rsocket-test/Cargo.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,19 @@ version = "0.7.1"
3636
version = "1.0.3"
3737
default-features = false
3838
features = ["full"]
39+
40+
[dev-dependencies.tokio-stream]
41+
version = "0.1.7"
42+
features = ["sync"]
43+
44+
[dev-dependencies.anyhow]
45+
version = "1.0.40"
46+
47+
[dev-dependencies.async-trait]
48+
version = "0.1.50"
49+
50+
[dev-dependencies.serial_test]
51+
version = "0.5.1"
52+
53+
[dev-dependencies.async-stream]
54+
version = "0.3.1"
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
#[macro_use]
2+
extern crate log;
3+
4+
use std::sync::Arc;
5+
use std::sync::Mutex;
6+
use std::time::Duration;
7+
8+
use anyhow::Result;
9+
use async_trait::async_trait;
10+
use futures::StreamExt;
11+
use tokio_stream::wrappers::ReceiverStream;
12+
13+
use rsocket_rust::prelude::{Flux, Payload, RSocket};
14+
15+
#[cfg(test)]
16+
mod tests {
17+
use std::time::Duration;
18+
19+
use futures::Future;
20+
use rsocket_rust_transport_websocket::{WebsocketClientTransport, WebsocketServerTransport};
21+
use serial_test::serial;
22+
use tokio::runtime::Runtime;
23+
use async_stream::stream;
24+
use rsocket_rust::Client;
25+
use rsocket_rust::prelude::*;
26+
use rsocket_rust::utils::EchoRSocket;
27+
use rsocket_rust_transport_tcp::{TcpClientTransport, TcpServerTransport, UnixClientTransport, UnixServerTransport};
28+
29+
use crate::TestSocket;
30+
31+
#[serial]
32+
#[test]
33+
fn request_stream_can_be_cancelled_by_client_uds() {
34+
init_logger();
35+
with_uds_test_socket_run(request_stream_can_be_cancelled_by_client);
36+
}
37+
38+
#[serial]
39+
#[test]
40+
fn request_stream_can_be_cancelled_by_client_tcp() {
41+
init_logger();
42+
with_tcp_test_socket_run(request_stream_can_be_cancelled_by_client);
43+
}
44+
45+
#[serial]
46+
#[test]
47+
fn request_stream_can_be_cancelled_by_client_ws() {
48+
init_logger();
49+
with_ws_test_socket_run(request_stream_can_be_cancelled_by_client);
50+
}
51+
52+
///
53+
/// Client requests a channel, consumes an item and drops the stream handle.
54+
///
55+
/// Amount of active streams is verified before and after requesting and after dropping.
56+
///
57+
/// Before request_stream: 0 subscribers
58+
/// When request_stream is called: 1 subscriber
59+
/// When request_stream handle is dropped: 0 subscribers
60+
async fn request_stream_can_be_cancelled_by_client(client: Client) {
61+
assert_eq!(
62+
client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(),
63+
Some("0")
64+
);
65+
66+
let mut results = client.request_stream(Payload::from(""));
67+
let payload = results.next().await.expect("valid payload").unwrap();
68+
assert_eq!(payload.metadata_utf8(), Some("subscribers: 1"));
69+
assert_eq!(payload.data_utf8(), Some("0"));
70+
71+
assert_eq!(
72+
client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(),
73+
Some("1")
74+
);
75+
76+
debug!("when the Flux is dropped");
77+
drop(results);
78+
// Give the server enough time to receive the CANCEL frame
79+
tokio::time::sleep(Duration::from_millis(250)).await;
80+
81+
assert_eq!(
82+
client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(),
83+
Some("0")
84+
);
85+
}
86+
87+
#[serial]
88+
#[test]
89+
fn request_channel_can_be_cancelled_by_client_uds() {
90+
init_logger();
91+
with_uds_test_socket_run(request_channel_can_be_cancelled_by_client);
92+
}
93+
94+
#[serial]
95+
#[test]
96+
fn request_channel_can_be_cancelled_by_client_tcp() {
97+
init_logger();
98+
with_tcp_test_socket_run(request_channel_can_be_cancelled_by_client);
99+
}
100+
101+
#[serial]
102+
#[test]
103+
fn request_channel_can_be_cancelled_by_client_ws() {
104+
init_logger();
105+
with_ws_test_socket_run(request_channel_can_be_cancelled_by_client);
106+
}
107+
108+
///
109+
/// Client requests a stream, consumes an item and drops the stream handle.
110+
///
111+
/// Amount of active streams is verified before and after requesting and after dropping.
112+
///
113+
/// Before request_channel: 0 subscribers
114+
/// When request_channel is called: 1 subscriber
115+
/// When request_channel handle is dropped: 0 subscribers
116+
async fn request_channel_can_be_cancelled_by_client(client: Client) {
117+
assert_eq!(
118+
client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(),
119+
Some("0")
120+
);
121+
122+
let mut results = client.request_channel(
123+
stream!{ yield Ok(Payload::from("")) }.boxed()
124+
);
125+
let payload = results.next().await.expect("valid payload").unwrap();
126+
assert_eq!(payload.metadata_utf8(), Some("subscribers: 1"));
127+
assert_eq!(payload.data_utf8(), Some("0"));
128+
129+
assert_eq!(
130+
client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(),
131+
Some("1")
132+
);
133+
134+
debug!("when the Flux is dropped");
135+
drop(results);
136+
// Give the server enough time to receive the CANCEL frame
137+
tokio::time::sleep(Duration::from_millis(250)).await;
138+
139+
assert_eq!(
140+
client.request_response(Payload::from("subscribers")).await.unwrap().unwrap().data_utf8(),
141+
Some("0")
142+
);
143+
}
144+
145+
fn init_logger() {
146+
let _ = env_logger::builder()
147+
.format_timestamp_millis()
148+
.filter_level(log::LevelFilter::Debug)
149+
// .is_test(true)
150+
.try_init();
151+
}
152+
153+
/// Executes the [run_test] scenario using a client which is connected over a UDS transport to
154+
/// a TestSocket
155+
fn with_uds_test_socket_run<F, Fut>(run_test: F)
156+
where
157+
F: (FnOnce(Client) -> Fut) + Send + 'static,
158+
Fut: Future<Output=()> + Send + 'static,
159+
{
160+
info!("=====> begin uds");
161+
let server_runtime = Runtime::new().unwrap();
162+
163+
server_runtime.spawn(async move {
164+
RSocketFactory::receive()
165+
.transport(UnixServerTransport::from("/tmp/rsocket-uds.sock".to_owned()))
166+
.acceptor(Box::new(|_setup, _socket| { Ok(Box::new(TestSocket::new())) }))
167+
.serve()
168+
.await
169+
});
170+
171+
std::thread::sleep(Duration::from_millis(500));
172+
173+
let client_runtime = Runtime::new().unwrap();
174+
175+
client_runtime.block_on(async {
176+
let client = RSocketFactory::connect()
177+
.acceptor(Box::new(|| Box::new(EchoRSocket)))
178+
.transport(UnixClientTransport::from("/tmp/rsocket-uds.sock".to_owned()))
179+
.setup(Payload::from("READY!"))
180+
.mime_type("text/plain", "text/plain")
181+
.start()
182+
.await
183+
.unwrap();
184+
run_test(client).await;
185+
});
186+
info!("<===== uds done!");
187+
}
188+
189+
/// Executes the [run_test] scenario using a client which is connected over a UDS transport to
190+
/// a TestSocket
191+
fn with_ws_test_socket_run<F, Fut>(run_test: F)
192+
where
193+
F: (FnOnce(Client) -> Fut) + Send + 'static,
194+
Fut: Future<Output=()> + Send + 'static,
195+
{
196+
info!("=====> begin ws");
197+
let server_runtime = Runtime::new().unwrap();
198+
server_runtime.spawn(async move {
199+
RSocketFactory::receive()
200+
.transport(WebsocketServerTransport::from("127.0.0.1:8080".to_owned()))
201+
.acceptor(Box::new(|_setup, _socket| { Ok(Box::new(TestSocket::new())) }))
202+
.serve()
203+
.await
204+
});
205+
206+
std::thread::sleep(Duration::from_millis(500));
207+
208+
let client_runtime = Runtime::new().unwrap();
209+
210+
client_runtime.block_on(async {
211+
let client = RSocketFactory::connect()
212+
.acceptor(Box::new(|| Box::new(EchoRSocket)))
213+
.transport(WebsocketClientTransport::from("127.0.0.1:8080"))
214+
.setup(Payload::from("READY!"))
215+
.mime_type("text/plain", "text/plain")
216+
.start()
217+
.await
218+
.unwrap();
219+
220+
221+
run_test(client).await;
222+
});
223+
info!("<===== ws done!");
224+
}
225+
226+
/// Executes the [run_test] scenario using a client which is connected over a TCP transport to
227+
/// a TestSocket
228+
fn with_tcp_test_socket_run<F, Fut>(run_test: F)
229+
where
230+
F: (FnOnce(Client) -> Fut) + Send + 'static,
231+
Fut: Future<Output=()> + Send + 'static,
232+
{
233+
info!("=====> begin tcp");
234+
let server_runtime = Runtime::new().unwrap();
235+
server_runtime.spawn(async move {
236+
RSocketFactory::receive()
237+
.transport(TcpServerTransport::from("127.0.0.1:7878".to_owned()))
238+
.acceptor(Box::new(|_setup, _socket| { Ok(Box::new(TestSocket::new())) }))
239+
.serve()
240+
.await
241+
});
242+
243+
std::thread::sleep(Duration::from_millis(500));
244+
245+
let client_runtime = Runtime::new().unwrap();
246+
247+
client_runtime.block_on(async {
248+
let client = RSocketFactory::connect()
249+
.acceptor(Box::new(|| Box::new(EchoRSocket)))
250+
.transport(TcpClientTransport::from("127.0.0.1:7878".to_owned()))
251+
.setup(Payload::from("READY!"))
252+
.mime_type("text/plain", "text/plain")
253+
.start()
254+
.await
255+
.unwrap();
256+
run_test(client).await;
257+
});
258+
info!("<===== tpc done!");
259+
}
260+
}
261+
262+
/// Stateful socket for tests, can be used to count active subscribers.
263+
struct TestSocket {
264+
subscribers: Arc<Mutex<u32>>,
265+
}
266+
267+
impl TestSocket {
268+
fn new() -> Self {
269+
TestSocket {
270+
subscribers: Arc::new(Mutex::new(0)),
271+
}
272+
}
273+
274+
fn inc_subscriber_count(subscribers: &Arc<Mutex<u32>>) {
275+
let mut guard = subscribers.lock().unwrap();
276+
*guard = *guard + 1;
277+
info!(target: "TestSocket", "subscribers:({})", guard);
278+
}
279+
280+
fn dec_subscriber_count(subscribers: &Arc<Mutex<u32>>) {
281+
let mut guard = subscribers.lock().unwrap();
282+
*guard = *guard - 1;
283+
info!(target: "TestSocket", "subscribers:({})", guard);
284+
}
285+
}
286+
287+
#[async_trait]
288+
impl RSocket for TestSocket {
289+
async fn metadata_push(&self, _req: Payload) -> Result<()> {
290+
unimplemented!();
291+
}
292+
293+
async fn fire_and_forget(&self, _req: Payload) -> Result<()> {
294+
unimplemented!();
295+
}
296+
297+
async fn request_response(&self, req: Payload) -> Result<Option<Payload>> {
298+
let subscribers = *self.subscribers.lock().unwrap();
299+
let response = match req.data_utf8() {
300+
Some("subscribers") => format!("{}", subscribers),
301+
_ => "Request payload did not contain a known key!".to_owned(),
302+
};
303+
Ok(Some(Payload::builder().set_data_utf8(&response).build()))
304+
}
305+
306+
fn request_stream(&self, _req: Payload) -> Flux<Result<Payload>> {
307+
let (tx, rx) = tokio::sync::mpsc::channel(32);
308+
let subscribers = self.subscribers.clone();
309+
tokio::spawn(async move {
310+
TestSocket::inc_subscriber_count(&subscribers);
311+
312+
for i in 0 as u32..100 {
313+
if tx.is_closed() {
314+
debug!(target: "TestSocket", "tx is closed, break!");
315+
break;
316+
}
317+
let payload = Payload::builder()
318+
.set_data_utf8(format!("{}", i).as_str())
319+
.set_metadata_utf8(format!("subscribers: {}", *subscribers.lock().unwrap()).as_str())
320+
.build();
321+
tx.send(Ok(payload)).await.unwrap();
322+
tokio::time::sleep(Duration::from_millis(50)).await;
323+
}
324+
325+
TestSocket::dec_subscriber_count(&subscribers);
326+
});
327+
328+
ReceiverStream::new(rx).boxed()
329+
}
330+
331+
fn request_channel(&self, _reqs: Flux<Result<Payload>>) -> Flux<Result<Payload>> {
332+
self.request_stream(Payload::from(""))
333+
}
334+
}

rsocket/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ version = "1.0.3"
2929
default-features = false
3030
features = [ "macros", "rt", "rt-multi-thread", "sync", "time" ]
3131

32+
[dependencies.tokio-stream]
33+
version = "0.1.7"
34+
features = ["sync"]
35+
3236
[features]
3337
default = []
3438
frame = []

0 commit comments

Comments
 (0)