diff --git a/socketio/src/asynchronous/client/client.rs b/socketio/src/asynchronous/client/client.rs index 01991056..095e6bc5 100644 --- a/socketio/src/asynchronous/client/client.rs +++ b/socketio/src/asynchronous/client/client.rs @@ -285,6 +285,52 @@ impl Client { .await } + /// When receive server's emitwithack callback event, invoke socket.ack(..) function can react to server with ack signal + /// use futures_util::FutureExt; + /// + /// # Example + /// ``` + /// use futures_util::FutureExt; + /// use rust_socketio::{asynchronous::{ClientBuilder, Client}, Payload}; + /// use serde_json::json; + /// use std::time::Duration; + /// use std::thread; + /// use bytes::Bytes; + /// + /// #[tokio::main] + /// async fn main() { + /// + /// let callback = |payload: Payload, socket: Client| { + /// async move { + /// let byte_test = vec![0x01, 0x02]; + /// let _ = socket.ack(byte_test).await; + /// }.boxed() + /// }; + /// + /// // get a socket that is connected to the admin namespace + /// let socket = ClientBuilder::new("http://localhost:4200") + /// .namespace("/") + /// .on("foo", callback) + /// .on("error", |err, _| { + /// async move { eprintln!("Error: {:#?}", err) }.boxed() + /// }) + /// .connect() + /// .await + /// .expect("Connection failed"); + /// + /// + /// thread::sleep(Duration::from_millis(30000)); + /// socket.disconnect().await.expect("Disconnect failed"); + /// } + /// ``` + #[inline] + pub async fn ack(&self, data: D) -> Result<()> + where + D: Into, + { + self.socket.read().await.ack(&self.nsp, data.into()).await + } + /// Disconnects this client from the server by sending a `socket.io` closing /// packet. /// # Example diff --git a/socketio/src/asynchronous/socket.rs b/socketio/src/asynchronous/socket.rs index 81a9ebcd..9010ce45 100644 --- a/socketio/src/asynchronous/socket.rs +++ b/socketio/src/asynchronous/socket.rs @@ -14,7 +14,7 @@ use std::{ fmt::Debug, pin::Pin, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicI32, Ordering}, Arc, }, }; @@ -24,16 +24,20 @@ pub(crate) struct Socket { engine_client: Arc, connected: Arc, generator: StreamGenerator, + ack_id: Arc, } impl Socket { /// Creates an instance of `Socket`. pub(super) fn new(engine_client: EngineClient) -> Result { let connected = Arc::new(AtomicBool::default()); + let ack_id = Arc::new(AtomicI32::new(-1)); + Ok(Socket { engine_client: Arc::new(engine_client.clone()), connected: connected.clone(), - generator: StreamGenerator::new(Self::stream(engine_client, connected)), + ack_id: ack_id.clone(), + generator: StreamGenerator::new(Self::stream(engine_client, connected, ack_id)), }) } @@ -58,6 +62,9 @@ impl Socket { if self.connected.load(Ordering::Acquire) { self.connected.store(false, Ordering::Release); } + if self.ack_id.load(Ordering::Acquire) != -1 { + self.ack_id.store(-1, Ordering::Release); + } Ok(()) } @@ -89,9 +96,17 @@ impl Socket { self.send(socket_packet).await } + /// Emits to connected other side with given data + pub async fn ack(&self, nsp: &str, data: Payload) -> Result<()> { + let socket_packet = + Packet::ack_from_payload(data, nsp, Some(self.ack_id.load(Ordering::Acquire)))?; + self.send(socket_packet).await + } + fn stream( client: EngineClient, is_connected: Arc, + ack_id: Arc, ) -> Pin> + Send>> { Box::pin(try_stream! { for await received_data in client.clone() { @@ -101,6 +116,10 @@ impl Socket { || packet.packet_id == EnginePacketId::MessageBinary { let packet = Self::handle_engineio_packet(packet, client.clone()).await?; + + if ack_id.load(Ordering::Acquire) != packet.id.unwrap_or(-1) { + ack_id.store(packet.id.unwrap_or(-1), Ordering::Release); + } Self::handle_socketio_packet(&packet, is_connected.clone()); yield packet; diff --git a/socketio/src/client/raw_client.rs b/socketio/src/client/raw_client.rs index 9c8ecef2..37ebdd70 100644 --- a/socketio/src/client/raw_client.rs +++ b/socketio/src/client/raw_client.rs @@ -82,6 +82,42 @@ impl RawClient { Ok(()) } + /// Example code for handling ACK response when calling emitWithAck on the server + /// + /// # Example + /// ``` + /// use rust_socketio::{ClientBuilder, Payload, RawClient}; + /// use std::time::Duration; + /// use std::thread::sleep; + /// + /// + /// let ack_callback = |message: Payload, socket: RawClient| { + /// match message { + /// Payload::Text(values) => println!("{:#?}", values), + /// Payload::Binary(bytes) => println!("Received bytes: {:#?}", bytes), + /// // This is deprecated, use Payload::Text instead + /// Payload::String(str) => println!("{}", str), + /// } + /// socket.ack("foo").unwrap(); + /// }; + /// + /// let mut socket = ClientBuilder::new("http://localhost:4200/") + /// .on("foo", ack_callback) + /// .connect() + /// .expect("connection failed"); + /// + /// + /// + /// sleep(Duration::from_secs(2)); + /// ``` + #[inline] + pub fn ack(&self, data: D) -> Result<()> + where + D: Into, + { + self.socket.ack(&self.nsp, data.into()) + } + /// Sends a message to the server using the underlying `engine.io` protocol. /// This message takes an event, which could either be one of the common /// events like "message" or "error" or a custom event like "foo". But be diff --git a/socketio/src/packet.rs b/socketio/src/packet.rs index e74dedb5..50c95c52 100644 --- a/socketio/src/packet.rs +++ b/socketio/src/packet.rs @@ -88,6 +88,53 @@ impl Packet { } } } + + #[inline] + pub(crate) fn ack_from_payload<'a>( + payload: Payload, + nsp: &'a str, + ack_id: Option, + ) -> Result { + match payload { + Payload::Binary(bin_data) => Ok(Packet::new( + PacketId::BinaryAck, + nsp.to_owned(), + None, + ack_id, + 1, + Some(vec![bin_data]), + )), + #[allow(deprecated)] + Payload::String(str_data) => { + let payload = if serde_json::from_str::(&str_data).is_ok() { + format!("[{str_data}]") + } else { + format!("[\"{str_data}\"]") + }; + + Ok(Packet::new( + PacketId::Ack, + nsp.to_owned(), + Some(payload), + ack_id, + 0, + None, + )) + } + Payload::Text(data) => { + let payload = serde_json::Value::Array(data).to_string(); + + Ok(Packet::new( + PacketId::Ack, + nsp.to_owned(), + Some(payload), + ack_id, + 0, + None, + )) + } + } + } } impl Default for Packet { @@ -671,4 +718,60 @@ mod test { } ) } + + #[test] + fn ack_from_payload_binary() { + let payload = Payload::Binary(Bytes::from_static(&[0, 4, 9])); + let result = Packet::ack_from_payload(payload.clone(), "namespace", None).unwrap(); + assert_eq!( + result, + Packet { + packet_type: PacketId::BinaryAck, + nsp: "namespace".to_owned(), + data: None, + id: None, + attachment_count: 1, + attachments: Some(vec![Bytes::from_static(&[0, 4, 9])]), + } + ) + } + + #[test] + #[allow(deprecated)] + fn ack_from_payload_string() { + let payload = Payload::String("test".to_owned()); + let result = + Packet::ack_from_payload(payload.clone(), "other_namespace", Some(10)).unwrap(); + assert_eq!( + result, + Packet { + packet_type: PacketId::Ack, + nsp: "other_namespace".to_owned(), + data: Some("[\"test\"]".to_owned()), + id: Some(10), + attachment_count: 0, + attachments: None, + } + ) + } + + #[test] + fn ack_from_payload_json() { + let payload = Payload::Text(vec![ + serde_json::json!("String test"), + serde_json::json!({"type":"object"}), + ]); + let result = Packet::ack_from_payload(payload.clone(), "/", Some(10)).unwrap(); + assert_eq!( + result, + Packet { + packet_type: PacketId::Ack, + nsp: "/".to_owned(), + data: Some("[\"String test\",{\"type\":\"object\"}]".to_owned()), + id: Some(10), + attachment_count: 0, + attachments: None, + } + ) + } } diff --git a/socketio/src/socket.rs b/socketio/src/socket.rs index b881bad0..8bb411ed 100644 --- a/socketio/src/socket.rs +++ b/socketio/src/socket.rs @@ -3,6 +3,7 @@ use crate::packet::{Packet, PacketId}; use bytes::Bytes; use rust_engineio::{Client as EngineClient, Packet as EnginePacket, PacketId as EnginePacketId}; use std::convert::TryFrom; +use std::sync::atomic::AtomicI32; use std::sync::{atomic::AtomicBool, Arc}; use std::{fmt::Debug, sync::atomic::Ordering}; @@ -14,15 +15,19 @@ pub(crate) struct Socket { //TODO: 0.4.0 refactor this engine_client: Arc, connected: Arc, + ack_id: Arc, } impl Socket { /// Creates an instance of `Socket`. pub(super) fn new(engine_client: EngineClient) -> Result { + let ack_id = Arc::new(AtomicI32::new(-1)); + Ok(Socket { engine_client: Arc::new(engine_client), connected: Arc::new(AtomicBool::default()), + ack_id: ack_id.clone(), }) } @@ -47,6 +52,9 @@ impl Socket { if self.connected.load(Ordering::Acquire) { self.connected.store(false, Ordering::Release); } + if self.ack_id.load(Ordering::Acquire) != -1 { + self.ack_id.store(-1, Ordering::Release); + } Ok(()) } @@ -78,6 +86,13 @@ impl Socket { self.send(socket_packet) } + /// Emits to connected other side with given data + pub fn ack(&self, nsp: &str, data: Payload) -> Result<()> { + let socket_packet = + Packet::ack_from_payload(data, nsp, Some(self.ack_id.load(Ordering::Acquire)))?; + self.send(socket_packet) + } + pub(crate) fn poll(&self) -> Result> { loop { match self.engine_client.poll() { @@ -86,6 +101,10 @@ impl Socket { || packet.packet_id == EnginePacketId::MessageBinary { let packet = self.handle_engineio_packet(packet)?; + if self.ack_id.load(Ordering::Acquire) != packet.id.unwrap_or(-1) { + self.ack_id + .store(packet.id.unwrap_or(-1), Ordering::Release); + } self.handle_socketio_packet(&packet); return Ok(Some(packet)); } else {