Skip to content

Commit 5de79e6

Browse files
committed
Subscriptions keep clients alive until dropped
1 parent ffd5020 commit 5de79e6

File tree

11 files changed

+113
-82
lines changed

11 files changed

+113
-82
lines changed

client/http-client/src/client.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,14 +483,16 @@ where
483483
> + Send
484484
+ Sync,
485485
{
486+
type SubscriptionClient = Self;
487+
486488
/// Send a subscription request to the server. Not implemented for HTTP; will always return
487489
/// [`Error::HttpNotImplemented`].
488490
fn subscribe<'a, N, Params>(
489491
&self,
490492
_subscribe_method: &'a str,
491493
_params: Params,
492494
_unsubscribe_method: &'a str,
493-
) -> impl Future<Output = Result<Subscription<N>, Error>>
495+
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, N>, Error>>
494496
where
495497
Params: ToRpcParams + Send,
496498
N: DeserializeOwned,
@@ -499,7 +501,10 @@ where
499501
}
500502

501503
/// Subscribe to a specific method. Not implemented for HTTP; will always return [`Error::HttpNotImplemented`].
502-
fn subscribe_to_method<N>(&self, _method: &str) -> impl Future<Output = Result<Subscription<N>, Error>>
504+
fn subscribe_to_method<N>(
505+
&self,
506+
_method: &str,
507+
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, N>, Error>>
503508
where
504509
N: DeserializeOwned,
505510
{

client/ws-client/src/tests.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ async fn subscription_works() {
159159
let uri = to_ws_uri_string(server.local_addr());
160160
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
161161
{
162-
let mut sub: Subscription<String> = client
162+
let mut sub: Subscription<_, String> = client
163163
.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello")
164164
.with_default_timeout()
165165
.await
@@ -183,7 +183,7 @@ async fn notification_handler_works() {
183183
let uri = to_ws_uri_string(server.local_addr());
184184
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
185185
{
186-
let mut nh: Subscription<String> =
186+
let mut nh: Subscription<_, String> =
187187
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
188188
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
189189
assert_eq!("server originated notification works".to_owned(), response);
@@ -203,7 +203,7 @@ async fn notification_no_params() {
203203
let uri = to_ws_uri_string(server.local_addr());
204204
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
205205
{
206-
let mut nh: Subscription<serde_json::Value> =
206+
let mut nh: Subscription<_, serde_json::Value> =
207207
client.subscribe_to_method("no_params").with_default_timeout().await.unwrap().unwrap();
208208
let response = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
209209
assert_eq!(response, serde_json::Value::Null);
@@ -244,15 +244,15 @@ async fn batched_notifs_works() {
244244
// Ensure that subscription is returned back to the correct handle
245245
// and is handled separately from ordinary notifications.
246246
{
247-
let mut nh: Subscription<String> =
247+
let mut nh: Subscription<_, String> =
248248
client.subscribe("sub", rpc_params![], "unsub").with_default_timeout().await.unwrap().unwrap();
249249
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
250250
assert_eq!("sub_notif", response);
251251
}
252252

253253
// Ensure that method notif is returned back to the correct handle.
254254
{
255-
let mut nh: Subscription<String> =
255+
let mut nh: Subscription<_, String> =
256256
client.subscribe_to_method("sub").with_default_timeout().await.unwrap().unwrap();
257257
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
258258
assert_eq!("method_notif", response);
@@ -279,7 +279,7 @@ async fn notification_close_on_lagging() {
279279
.await
280280
.unwrap()
281281
.unwrap();
282-
let mut nh: Subscription<String> =
282+
let mut nh: Subscription<_, String> =
283283
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
284284

285285
// Don't poll the notification stream for 2 seconds, should be full now.
@@ -297,7 +297,7 @@ async fn notification_close_on_lagging() {
297297
assert!(nh.next().with_default_timeout().await.unwrap().is_none());
298298

299299
// The same subscription should be possible to register again.
300-
let mut other_nh: Subscription<String> =
300+
let mut other_nh: Subscription<_, String> =
301301
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
302302

303303
// check that the new subscription works.

core/src/client/async_client/mod.rs

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -360,14 +360,16 @@ impl<L> ClientBuilder<L> {
360360

361361
tokio::spawn(wait_for_shutdown(send_receive_task_sync_rx, client_dropped_rx, disconnect_reason.clone()));
362362

363-
Client {
363+
let inner = ClientInner {
364364
to_back: to_back.clone(),
365365
service: self.service_builder.service(RpcService::new(to_back.clone())),
366366
request_timeout: self.request_timeout,
367367
error: ErrorFromBack::new(to_back, disconnect_reason),
368368
id_manager: RequestIdManager::new(self.id_kind),
369369
on_exit: Some(client_dropped_tx),
370-
}
370+
};
371+
372+
Client { inner: Arc::new(inner) }
371373
}
372374

373375
/// Build the client with given transport.
@@ -419,20 +421,21 @@ impl<L> ClientBuilder<L> {
419421
disconnect_reason.clone(),
420422
));
421423

422-
Client {
424+
let inner = ClientInner {
423425
to_back: to_back.clone(),
424426
service: self.service_builder.service(RpcService::new(to_back.clone())),
425427
request_timeout: self.request_timeout,
426428
error: ErrorFromBack::new(to_back, disconnect_reason),
427429
id_manager: RequestIdManager::new(self.id_kind),
428430
on_exit: Some(client_dropped_tx),
429-
}
431+
};
432+
433+
Client { inner: Arc::new(inner) }
430434
}
431435
}
432436

433-
/// Generic asynchronous client.
434437
#[derive(Debug)]
435-
pub struct Client<L = RpcLogger<RpcService>> {
438+
struct ClientInner<L = RpcLogger<RpcService>> {
436439
/// Channel to send requests to the background task.
437440
to_back: mpsc::Sender<FrontToBack>,
438441
error: ErrorFromBack,
@@ -445,6 +448,21 @@ pub struct Client<L = RpcLogger<RpcService>> {
445448
service: L,
446449
}
447450

451+
impl<L> Drop for ClientInner<L> {
452+
fn drop(&mut self) {
453+
if let Some(e) = self.on_exit.take() {
454+
let _ = e.send(());
455+
}
456+
}
457+
}
458+
459+
/// Generic asynchronous client.
460+
#[derive(Debug)]
461+
#[repr(transparent)]
462+
pub struct Client<L = RpcLogger<RpcService>> {
463+
inner: Arc<ClientInner<L>>
464+
}
465+
448466
impl Client<Identity> {
449467
/// Create a builder for the client.
450468
pub fn builder() -> ClientBuilder {
@@ -455,13 +473,13 @@ impl Client<Identity> {
455473
impl<L> Client<L> {
456474
/// Checks if the client is connected to the target.
457475
pub fn is_connected(&self) -> bool {
458-
!self.to_back.is_closed()
476+
!self.inner.to_back.is_closed()
459477
}
460478

461479
async fn run_future_until_timeout<T>(&self, fut: impl Future<Output = Result<T, Error>>) -> Result<T, Error> {
462480
tokio::pin!(fut);
463481

464-
match futures_util::future::select(fut, futures_timer::Delay::new(self.request_timeout)).await {
482+
match futures_util::future::select(fut, futures_timer::Delay::new(self.inner.request_timeout)).await {
465483
Either::Left((Ok(r), _)) => Ok(r),
466484
Either::Left((Err(Error::ServiceDisconnect), _)) => Err(self.on_disconnect().await),
467485
Either::Left((Err(e), _)) => Err(e),
@@ -476,20 +494,18 @@ impl<L> Client<L> {
476494
///
477495
/// This method is cancel safe.
478496
pub async fn on_disconnect(&self) -> Error {
479-
self.error.read_error().await
497+
self.inner.error.read_error().await
480498
}
481499

482500
/// Returns configured request timeout.
483501
pub fn request_timeout(&self) -> Duration {
484-
self.request_timeout
502+
self.inner.request_timeout
485503
}
486504
}
487505

488-
impl<L> Drop for Client<L> {
489-
fn drop(&mut self) {
490-
if let Some(e) = self.on_exit.take() {
491-
let _ = e.send(());
492-
}
506+
impl<L> Clone for Client<L> {
507+
fn clone(&self) -> Self {
508+
Self { inner: self.inner.clone() }
493509
}
494510
}
495511

@@ -508,9 +524,9 @@ where
508524
{
509525
async {
510526
// NOTE: we use this to guard against max number of concurrent requests.
511-
let _req_id = self.id_manager.next_request_id();
527+
let _req_id = self.inner.id_manager.next_request_id();
512528
let params = params.to_rpc_params()?.map(StdCow::Owned);
513-
let fut = self.service.notification(jsonrpsee_types::Notification::new(method.into(), params));
529+
let fut = self.inner.service.notification(jsonrpsee_types::Notification::new(method.into(), params));
514530
self.run_future_until_timeout(fut).await?;
515531
Ok(())
516532
}
@@ -522,9 +538,9 @@ where
522538
Params: ToRpcParams + Send,
523539
{
524540
async {
525-
let id = self.id_manager.next_request_id();
541+
let id = self.inner.id_manager.next_request_id();
526542
let params = params.to_rpc_params()?;
527-
let fut = self.service.call(Request::borrowed(method, params.as_deref(), id.clone()));
543+
let fut = self.inner.service.call(Request::borrowed(method, params.as_deref(), id.clone()));
528544
let rp = self.run_future_until_timeout(fut).await?;
529545
let success = ResponseSuccess::try_from(rp.into_response().into_inner())?;
530546

@@ -541,15 +557,15 @@ where
541557
{
542558
async {
543559
let batch = batch.build()?;
544-
let id = self.id_manager.next_request_id();
560+
let id = self.inner.id_manager.next_request_id();
545561
let id_range = generate_batch_id_range(id, batch.len() as u64)?;
546562

547563
let mut b = Batch::with_capacity(batch.len());
548564

549565
for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
550566
b.push(Request {
551567
jsonrpc: TwoPointZero,
552-
id: self.id_manager.as_id_kind().into_id(id),
568+
id: self.inner.id_manager.as_id_kind().into_id(id),
553569
method: method.into(),
554570
params: params.map(StdCow::Owned),
555571
extensions: Extensions::new(),
@@ -558,7 +574,7 @@ where
558574

559575
b.extensions_mut().insert(IsBatch { id_range });
560576

561-
let fut = self.service.batch(b);
577+
let fut = self.inner.service.batch(b);
562578
let json_values = self.run_future_until_timeout(fut).await?;
563579

564580
let mut responses = Vec::with_capacity(json_values.len());
@@ -592,6 +608,8 @@ where
592608
> + Send
593609
+ Sync,
594610
{
611+
type SubscriptionClient = Self;
612+
595613
/// Send a subscription request to the server.
596614
///
597615
/// The `subscribe_method` and `params` are used to ask for the subscription towards the
@@ -601,7 +619,7 @@ where
601619
subscribe_method: &'a str,
602620
params: Params,
603621
unsubscribe_method: &'a str,
604-
) -> impl Future<Output = Result<Subscription<Notif>, Error>> + Send
622+
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, Notif>, Error>> + Send
605623
where
606624
Params: ToRpcParams + Send,
607625
Notif: DeserializeOwned,
@@ -611,8 +629,8 @@ where
611629
return Err(RegisterMethodError::SubscriptionNameConflict(unsubscribe_method.to_owned()).into());
612630
}
613631

614-
let req_id_sub = self.id_manager.next_request_id();
615-
let req_id_unsub = self.id_manager.next_request_id();
632+
let req_id_sub = self.inner.id_manager.next_request_id();
633+
let req_id_unsub = self.inner.id_manager.next_request_id();
616634
let params = params.to_rpc_params()?;
617635

618636
let mut ext = Extensions::new();
@@ -626,24 +644,25 @@ where
626644
extensions: ext,
627645
};
628646

629-
let fut = self.service.call(req);
647+
let fut = self.inner.service.call(req);
630648
let sub = self
631649
.run_future_until_timeout(fut)
632650
.await?
633651
.into_subscription()
634652
.expect("Extensions set to subscription, must return subscription; qed");
635-
Ok(Subscription::new(self.to_back.clone(), sub.stream, SubscriptionKind::Subscription(sub.sub_id)))
653+
Ok(Subscription::new(self.clone(), self.inner.to_back.clone(), sub.stream, SubscriptionKind::Subscription(sub.sub_id)))
636654
}
637655
}
638656

639657
/// Subscribe to a specific method.
640-
fn subscribe_to_method<N>(&self, method: &str) -> impl Future<Output = Result<Subscription<N>, Error>> + Send
658+
fn subscribe_to_method<N>(&self, method: &str) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, N>, Error>> + Send
641659
where
642660
N: DeserializeOwned,
643661
{
644662
async {
645663
let (send_back_tx, send_back_rx) = oneshot::channel();
646664
if self
665+
.inner
647666
.to_back
648667
.clone()
649668
.send(FrontToBack::RegisterNotification(RegisterNotificationMessage {
@@ -656,15 +675,15 @@ where
656675
return Err(self.on_disconnect().await);
657676
}
658677

659-
let res = call_with_timeout(self.request_timeout, send_back_rx).await;
678+
let res = call_with_timeout(self.inner.request_timeout, send_back_rx).await;
660679

661680
let (rx, method) = match res {
662681
Ok(Ok(val)) => val,
663682
Ok(Err(err)) => return Err(err),
664683
Err(_) => return Err(self.on_disconnect().await),
665684
};
666685

667-
Ok(Subscription::new(self.to_back.clone(), rx, SubscriptionKind::Method(method)))
686+
Ok(Subscription::new(self.clone(), self.inner.to_back.clone(), rx, SubscriptionKind::Method(method)))
668687
}
669688
}
670689
}

0 commit comments

Comments
 (0)