Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions juniper_graphql_ws/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@ All user visible changes to `juniper_graphql_ws` crate will be documented in thi

- `Schema::Context` now requires `Clone` bound for ability to have a "fresh" context value each time a new [GraphQL] operation is started in a [WebSocket] connection. ([#1369])
> **COMPATIBILITY**: Previously, it was `Arc`ed inside, sharing the same context value across all [GraphQL] operations of a [WebSocket] connection. To preserve the previous behavior, the `Schema::Context` type should be either wrapped into `Arc` or made `Arc`-based internally.
- Replaced `ConnectionConfig::keep_alive_interval` option with `ConnectionConfig::keep_alive` one as `KeepAliveConfig`. ([#1367])
- Made [WebSocket] connection closed once `ConnectionConfig::keep_alive::timeout` is reached in [`graphql-transport-ws` GraphQL over WebSocket Protocol][proto-6.0.7]. ([#1367])
> **COMPATIBILITY**: Previously, a [WebSocket] connection was kept alive, even when clients do not respond to server's `Pong` messages at all. To preserve the previous behavior, the `ConnectionConfig::keep_alive::timeout` should be set to `Duration:::ZERO`.

### Fixed

- Inability to re-subscribe with the same operation `id` after subscription was completed by server. ([#1368])

[#1367]: /../../pull/1367
[#1368]: /../../pull/1368
[#1369]: /../../pull/1369
[proto-6.0.7]: https://github.com/enisdenjo/graphql-ws/blob/v6.0.7/PROTOCOL.md



Expand Down
2 changes: 1 addition & 1 deletion juniper_graphql_ws/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ derive_more = { version = "2.0", features = ["debug", "from"] }
juniper = { version = "0.17", path = "../juniper", default-features = false }
juniper_subscriptions = { version = "0.18", path = "../juniper_subscriptions" }
serde = { version = "1.0.122", features = ["derive"], default-features = false }
tokio = { version = "1.0", features = ["macros", "rt", "time"], default-features = false }
tokio = { version = "1.0", features = ["macros", "rt", "sync", "time"], default-features = false }

[dev-dependencies]
serde_json = "1.0.18"
Expand Down
98 changes: 81 additions & 17 deletions juniper_graphql_ws/src/graphql_transport_ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use juniper::{
task::{Context, Poll, Waker},
},
};
use tokio::{sync::Notify, time};

use super::{ConnectionConfig, Init, Schema};

Expand Down Expand Up @@ -83,6 +84,7 @@ enum ConnectionState<S: Schema, I: Init<S::ScalarValue, S::Context>> {
Active {
config: ConnectionConfig<S::Context>,
stoppers: HashMap<String, oneshot::Sender<()>>,
ping: Arc<Notify>,
schema: S,
},
/// Terminated is the state after a ConnectionInit message has been rejected.
Expand All @@ -100,60 +102,89 @@ impl<S: Schema, I: Init<S::ScalarValue, S::Context>> ConnectionState<S, I> {
Self::PreInit { init, schema } => match msg {
ClientMessage::ConnectionInit { payload } => match init.init(payload).await {
Ok(config) => {
let keep_alive_interval = config.keep_alive_interval;
let keep_alive_interval = config.keep_alive.interval;
let keep_alive_timeout = config.keep_alive.timeout;

let mut s =
stream::iter(vec![Output::Message(ServerMessage::ConnectionAck)])
.boxed();
let ping = Arc::new(Notify::new());

let mut s = Output::Message(ServerMessage::ConnectionAck)
.into_stream()
.boxed();

#[expect(closure_returning_async_block, reason = "not possible")]
if keep_alive_interval > Duration::from_secs(0) {
s = s
.chain(Output::Message(ServerMessage::Pong).into_stream())
.boxed();
s = s
.chain(stream::unfold((), move |_| async move {
tokio::time::sleep(keep_alive_interval).await;
Some((Output::Message(ServerMessage::Pong), ()))
.chain(stream::repeat(()).then(move |()| {
tokio::time::sleep(keep_alive_interval)
.map(|()| Output::Message(ServerMessage::Pong))
}))
.boxed();
}

if keep_alive_timeout > Duration::from_secs(0) {
let ping_rx = ping.clone();
s = stream::select_all([
s,
stream::repeat(())
.then(move |()| {
let ping_rx = ping_rx.clone();
async move {
time::timeout(keep_alive_timeout, ping_rx.notified())
.await
.is_err()
.then(|| Output::Close {
code: 1000,
message: "Connection lost unexpectedly".into(),
})
}
})
.filter_map(future::ready)
.boxed(),
])
.boxed();
}

(
Self::Active {
config,
stoppers: HashMap::new(),
ping,
schema,
},
s,
)
}
Err(e) => (
Self::Terminated,
stream::iter(vec![Output::Close {
Output::Close {
code: 4403,
message: e.to_string(),
}])
}
.into_stream()
.boxed(),
),
},
ClientMessage::Ping { .. } => (
Self::PreInit { init, schema },
stream::iter(vec![Output::Message(ServerMessage::Pong)]).boxed(),
Output::Message(ServerMessage::Pong).into_stream().boxed(),
),
ClientMessage::Subscribe { .. } => (
Self::PreInit { init, schema },
stream::iter(vec![Output::Close {
Output::Close {
code: 4401,
message: "Unauthorized".to_string(),
}])
}
.into_stream()
.boxed(),
),
_ => (Self::PreInit { init, schema }, stream::empty().boxed()),
},
Self::Active {
config,
mut stoppers,
ping,
schema,
} => {
let reactions = match msg {
Expand Down Expand Up @@ -225,14 +256,16 @@ impl<S: Schema, I: Init<S::ScalarValue, S::Context>> ConnectionState<S, I> {
stream::empty().boxed()
}
ClientMessage::Ping { .. } => {
stream::iter(vec![Output::Message(ServerMessage::Pong)]).boxed()
ping.notify_waiters();
Output::Message(ServerMessage::Pong).into_stream().boxed()
}
_ => stream::empty().boxed(),
};
(
Self::Active {
config,
stoppers,
ping,
schema,
},
reactions,
Expand Down Expand Up @@ -956,10 +989,12 @@ mod test {
}

#[tokio::test]
async fn test_keep_alives() {
async fn test_keep_alive_interval() {
let mut conn = Connection::new(
new_test_schema(),
ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_millis(20)),
ConnectionConfig::new(Context(1))
.with_keep_alive_interval(Duration::from_millis(20))
.with_keep_alive_timeout(Duration::from_secs(0)),
);

conn.send(ClientMessage::ConnectionInit {
Expand All @@ -981,6 +1016,35 @@ mod test {
}
}

#[tokio::test]
async fn test_keep_alive_timeout() {
let mut conn = Connection::new(
new_test_schema(),
ConnectionConfig::new(Context(1))
.with_keep_alive_interval(Duration::from_millis(0))
.with_keep_alive_timeout(Duration::from_millis(20)),
);

conn.send(ClientMessage::ConnectionInit {
payload: graphql_vars! {},
})
.await
.unwrap();

assert_eq!(
Output::Message(ServerMessage::ConnectionAck),
conn.next().await.unwrap()
);

assert_eq!(
Output::Close {
code: 1000,
message: "Connection lost unexpectedly".into(),
},
conn.next().await.unwrap(),
);
}

#[tokio::test]
async fn test_slow_init() {
let mut conn = Connection::new(
Expand Down Expand Up @@ -1009,7 +1073,7 @@ mod test {

assert_eq!(
Output::Message(ServerMessage::ConnectionAck),
conn.next().await.unwrap()
conn.next().await.unwrap(),
);

assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion juniper_graphql_ws/src/graphql_ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<S: Schema, I: Init<S::ScalarValue, S::Context>> ConnectionState<S, I> {
Self::PreInit { init, schema } => match msg {
ClientMessage::ConnectionInit { payload } => match init.init(payload).await {
Ok(config) => {
let keep_alive_interval = config.keep_alive_interval;
let keep_alive_interval = config.keep_alive.interval;

let mut s = stream::iter(vec![Reaction::ServerMessage(
ServerMessage::ConnectionAck,
Expand Down
79 changes: 71 additions & 8 deletions juniper_graphql_ws/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,8 @@ pub struct ConnectionConfig<CtxT> {
/// By default, there is no limit to in-flight operations.
pub max_in_flight_operations: usize,

/// Interval at which to send keep-alives.
///
/// Specifying a [`Duration::ZERO`] will disable keep-alives.
///
/// By default, keep-alives are sent every 15 seconds.
pub keep_alive_interval: Duration,
/// Keep-alive configuration.
pub keep_alive: KeepAliveConfig,
}

impl<CtxT> ConnectionConfig<CtxT> {
Expand All @@ -48,7 +44,7 @@ impl<CtxT> ConnectionConfig<CtxT> {
Self {
context,
max_in_flight_operations: 0,
keep_alive_interval: Duration::from_secs(15),
keep_alive: KeepAliveConfig::default(),
}
}

Expand All @@ -66,10 +62,37 @@ impl<CtxT> ConnectionConfig<CtxT> {
///
/// Specifying a [`Duration::ZERO`] will disable keep-alives.
///
/// Also, sets a keep-alive timeout to the provided [`Duration`].
///
/// By default, keep-alives are sent every 15 seconds.
#[must_use]
pub fn with_keep_alive_interval(mut self, interval: Duration) -> Self {
self.keep_alive_interval = interval;
self.keep_alive.interval = interval;
#[cfg(feature = "graphql-transport-ws")]
{
self.keep_alive.timeout = interval;
}
self
}

#[cfg(feature = "graphql-transport-ws")]
/// Specifies the timeout for waiting a keep-alive response from clients after sending them a
/// keep-alive message.
///
/// Once the timeout is hit, the connection is closed by the server.
///
/// Specifying a [`Duration::ZERO`] disables timeout checking.
///
/// Applicable only for the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new],
/// and does nothing for the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old].
///
/// By default, timeout equals to the [`KeepAliveConfig::interval`].
///
/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md
/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md
#[must_use]
pub fn with_keep_alive_timeout(mut self, timeout: Duration) -> Self {
self.keep_alive.timeout = timeout;
self
}
}
Expand All @@ -83,6 +106,46 @@ impl<S: ScalarValue, CtxT: Unpin + Send + 'static> Init<S, CtxT> for ConnectionC
}
}

/// Config for keeping a connection alive.
#[derive(Clone, Copy, Debug)]
pub struct KeepAliveConfig {
/// Interval at which to send keep-alives.
///
/// Specifying a [`Duration::ZERO`] disables keep-alives.
///
/// By default, keep-alives are sent every 15 seconds.
pub interval: Duration,

#[cfg(feature = "graphql-transport-ws")]
/// Timeout for waiting a keep-alive response from clients after sending them a keep-alive
/// message.
///
/// Once the timeout is hit, the connection is closed by the server.
///
/// Specifying a [`Duration::ZERO`] disables timeout checking.
///
/// Applicable only for the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new],
/// and does nothing for the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old].
///
/// By default, timeout equals to the [`interval`].
///
/// [`interval`]: Self::interval
/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md
/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md
pub timeout: Duration,
}

impl Default for KeepAliveConfig {
fn default() -> Self {
let interval = Duration::from_secs(15);
Self {
interval,
#[cfg(feature = "graphql-transport-ws")]
timeout: interval,
}
}
}

/// Init defines the requirements for types that can provide connection configurations when
/// ConnectionInit messages are received. Implementations are provided for `ConnectionConfig` and
/// closures that meet the requirements.
Expand Down