diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 486719d4..64c3c88b 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -129,7 +129,7 @@ tracing-subscriber = { version = "0.3", features = [ async-trait = "0.1" [[test]] name = "test_tool_macros" -required-features = ["server"] +required-features = ["server", "client"] path = "tests/test_tool_macros.rs" [[test]] diff --git a/crates/rmcp/src/handler/client.rs b/crates/rmcp/src/handler/client.rs index 3e91a7e2..005661ca 100644 --- a/crates/rmcp/src/handler/client.rs +++ b/crates/rmcp/src/handler/client.rs @@ -1,7 +1,7 @@ use crate::{ error::Error as McpError, model::*, - service::{Peer, RequestContext, RoleClient, Service, ServiceRole}, + service::{RequestContext, RoleClient, Service, ServiceRole}, }; impl Service for H { @@ -118,47 +118,16 @@ pub trait ClientHandler: Sized + Send + Sync + 'static { std::future::ready(()) } - fn get_peer(&self) -> Option>; - - fn set_peer(&mut self, peer: Peer); - fn get_info(&self) -> ClientInfo { ClientInfo::default() } } -/// Do nothing, just store the peer. -impl ClientHandler for Option> { - fn get_peer(&self) -> Option> { - self.clone() - } - - fn set_peer(&mut self, peer: Peer) { - *self = Some(peer); - } -} - -/// Do nothing, even store the peer. -impl ClientHandler for () { - fn get_peer(&self) -> Option> { - None - } - - fn set_peer(&mut self, peer: Peer) { - drop(peer); - } -} +/// Do nothing, with default client info. +impl ClientHandler for () {} -/// Do nothing, even store the peer. +/// Do nothing, with a specific client info. impl ClientHandler for ClientInfo { - fn get_peer(&self) -> Option> { - None - } - - fn set_peer(&mut self, peer: Peer) { - drop(peer); - } - fn get_info(&self) -> ClientInfo { self.clone() } diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 98199282..83e9e57f 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -1,7 +1,7 @@ use crate::{ error::Error as McpError, model::*, - service::{Peer, RequestContext, RoleServer, Service, ServiceRole}, + service::{RequestContext, RoleServer, Service, ServiceRole}, }; mod resource; @@ -108,6 +108,10 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { request: InitializeRequestParam, context: RequestContext, ) -> impl Future> + Send + '_ { + if context.peer.peer_info().is_none() { + context.peer.set_peer_info(request); + } + let info = self.get_info(); std::future::ready(Ok(self.get_info())) } fn complete( @@ -210,14 +214,6 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { std::future::ready(()) } - fn get_peer(&self) -> Option> { - None - } - - fn set_peer(&mut self, peer: Peer) { - drop(peer); - } - fn get_info(&self) -> ServerInfo { ServerInfo::default() } diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 12037982..30d88727 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -86,6 +86,9 @@ impl<'service, S> ToolCallContext<'service, S> { pub fn name(&self) -> &str { &self.name } + pub fn request_context(&self) -> &RequestContext { + &self.request_context + } } pub trait FromToolCallContextPart<'a, S>: Sized { @@ -284,6 +287,39 @@ impl<'a, S> FromToolCallContextPart<'a, S> for JsonObject { } } +impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Extensions { + fn from_tool_call_context_part( + context: ToolCallContext<'a, S>, + ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + let extensions = context.request_context.extensions.clone(); + Ok((extensions, context)) + } +} + +pub struct Extension(pub T); + +impl<'a, S, T> FromToolCallContextPart<'a, S> for Extension +where + T: Send + Sync + 'static + Clone, +{ + fn from_tool_call_context_part( + context: ToolCallContext<'a, S>, + ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + let extension = context + .request_context + .extensions + .get::() + .cloned() + .ok_or_else(|| { + crate::Error::invalid_params( + format!("missing extension {}", std::any::type_name::()), + None, + ) + })?; + Ok((Extension(extension), context)) + } +} + impl<'s, S> ToolCallContext<'s, S> { pub fn invoke(self, h: H) -> H::Fut where diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 56edead8..c1a6ecc0 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -303,7 +303,7 @@ pub struct Peer { tx: mpsc::Sender>, request_id_provider: Arc, progress_token_provider: Arc, - info: Arc, + info: Arc>, } impl std::fmt::Debug for Peer { @@ -333,7 +333,7 @@ impl Peer { const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024; pub(crate) fn new( request_id_provider: Arc, - peer_info: R::PeerInfo, + peer_info: Option, ) -> (Peer, ProxyOutbound) { let (tx, rx) = mpsc::channel(Self::CLIENT_CHANNEL_BUFFER_SIZE); ( @@ -341,7 +341,7 @@ impl Peer { tx, request_id_provider, progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()), - info: peer_info.into(), + info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)), }, rx, ) @@ -402,8 +402,16 @@ impl Peer { peer: self.clone(), }) } - pub fn peer_info(&self) -> &R::PeerInfo { - &self.info + pub fn peer_info(&self) -> Option<&R::PeerInfo> { + self.info.get() + } + + pub fn set_peer_info(&self, info: R::PeerInfo) { + if self.info.initialized() { + tracing::warn!("trying to set peer info, which is already initialized"); + } else { + let _ = self.info.set(info); + } } pub fn is_transport_closed(&self) -> bool { @@ -469,7 +477,7 @@ pub struct RequestContext { pub async fn serve_directly( service: S, transport: T, - peer_info: R::PeerInfo, + peer_info: Option, ) -> RunningService where R: ServiceRole, @@ -484,7 +492,7 @@ where pub async fn serve_directly_with_ct( service: S, transport: T, - peer_info: R::PeerInfo, + peer_info: Option, ct: CancellationToken, ) -> RunningService where diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 9c43f83b..ab4c3495 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -174,7 +174,7 @@ where error, context: "send initialized notification".into(), })?; - let (peer, peer_rx) = Peer::new(id_provider, initialize_result); + let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result)); Ok(serve_inner(service, transport, peer, peer_rx, ct).await) } diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index d40241d3..6825282b 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -156,7 +156,7 @@ where ClientJsonRpcMessage::request(request, id), ))); }; - let (peer, peer_rx) = Peer::new(id_provider, peer_info.params.clone()); + let (peer, peer_rx) = Peer::new(id_provider, Some(peer_info.params.clone())); let context = RequestContext { ct: ct.child_token(), id: id.clone(), diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 8f47501a..ec1b03ce 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -57,7 +57,7 @@ //! //! // create transport from std io //! async fn io() -> Result<(), Box> { -//! let client = None.serve((tokio::io::stdin(), tokio::io::stdout())).await?; +//! let client = ().serve((tokio::io::stdin(), tokio::io::stdout())).await?; //! let tools = client.peer().list_tools(Default::default()).await?; //! println!("{:?}", tools); //! Ok(()) diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index 2bc8ec86..4e3e5d43 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -18,7 +18,7 @@ use tracing::Instrument; use crate::{ RoleServer, Service, model::ClientJsonRpcMessage, - service::{RxJsonRpcMessage, TxJsonRpcMessage}, + service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, transport::common::axum::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, }; @@ -67,7 +67,7 @@ async fn post_event_handler( parts: Parts, Json(mut message): Json, ) -> Result { - tracing::debug!(session_id, ?message, "new client message"); + tracing::debug!(session_id, ?parts, ?message, "new client message"); let tx = { let rg = app.txs.read().await; rg.get(session_id.as_str()) @@ -84,9 +84,10 @@ async fn post_event_handler( async fn sse_handler( State(app): State, + parts: Parts, ) -> Result>>, Response> { let session = session_id(); - tracing::info!(%session, "sse connection"); + tracing::info!(%session, ?parts, "sse connection"); use tokio_stream::{StreamExt, wrappers::ReceiverStream}; use tokio_util::sync::PollSender; let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64); @@ -300,6 +301,27 @@ impl SseServer { ct } + /// This allows you to skip the initialization steps for incoming request. + pub fn with_service_directly(mut self, service_provider: F) -> CancellationToken + where + S: Service, + F: Fn() -> S + Send + 'static, + { + let ct = self.config.ct.clone(); + tokio::spawn(async move { + while let Some(transport) = self.next_transport().await { + let service = service_provider(); + let ct = self.config.ct.child_token(); + tokio::spawn(async move { + let server = serve_directly_with_ct(service, transport, None, ct).await; + server.waiting().await?; + tokio::io::Result::Ok(()) + }); + } + }); + ct + } + pub fn cancel(&self) { self.config.ct.cancel(); } diff --git a/crates/rmcp/tests/common/handlers.rs b/crates/rmcp/tests/common/handlers.rs index d2212b63..c769565f 100644 --- a/crates/rmcp/tests/common/handlers.rs +++ b/crates/rmcp/tests/common/handlers.rs @@ -4,16 +4,14 @@ use std::{ }; use rmcp::{ - ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler, - model::*, - service::{Peer, RequestContext}, + ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler, model::*, + service::RequestContext, }; use serde_json::json; use tokio::sync::Notify; #[derive(Clone)] pub struct TestClientHandler { - pub peer: Option>, pub honor_this_server: bool, pub honor_all_servers: bool, pub receive_signal: Arc, @@ -24,7 +22,6 @@ impl TestClientHandler { #[allow(dead_code)] pub fn new(honor_this_server: bool, honor_all_servers: bool) -> Self { Self { - peer: None, honor_this_server, honor_all_servers, receive_signal: Arc::new(Notify::new()), @@ -40,7 +37,6 @@ impl TestClientHandler { received_messages: Arc>>, ) -> Self { Self { - peer: None, honor_this_server, honor_all_servers, receive_signal, @@ -50,14 +46,6 @@ impl TestClientHandler { } impl ClientHandler for TestClientHandler { - fn get_peer(&self) -> Option> { - self.peer.clone() - } - - fn set_peer(&mut self, peer: Peer) { - self.peer = Some(peer); - } - async fn create_message( &self, params: CreateMessageRequestParam, diff --git a/crates/rmcp/tests/test_notification.rs b/crates/rmcp/tests/test_notification.rs index 4d4c0f6e..09dd5e56 100644 --- a/crates/rmcp/tests/test_notification.rs +++ b/crates/rmcp/tests/test_notification.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use rmcp::{ - ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt, + ClientHandler, ServerHandler, ServiceExt, model::{ ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam, }, @@ -49,7 +49,6 @@ impl ServerHandler for Server { pub struct Client { receive_signal: Arc, - peer: Option>, } impl ClientHandler for Client { @@ -58,14 +57,6 @@ impl ClientHandler for Client { tracing::info!("Resource updated: {}", uri); self.receive_signal.notify_one(); } - - fn set_peer(&mut self, peer: Peer) { - self.peer.replace(peer); - } - - fn get_peer(&self) -> Option> { - self.peer.clone() - } } #[tokio::test] @@ -85,7 +76,6 @@ async fn test_server_notification() -> anyhow::Result<()> { }); let receive_signal = Arc::new(Notify::new()); let client = Client { - peer: Default::default(), receive_signal: receive_signal.clone(), } .serve(client_transport) diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 2e7e214c..669839b8 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use rmcp::{ - ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt, + ClientHandler, ServerHandler, ServiceExt, handler::server::tool::ToolCallContext, model::{CallToolRequestParam, ClientInfo}, tool, @@ -242,22 +242,12 @@ fn test_optional_field_schema_generation_via_macro() { // Define a dummy client handler #[derive(Debug, Clone, Default)] -struct DummyClientHandler { - peer: Option>, -} +struct DummyClientHandler {} impl ClientHandler for DummyClientHandler { fn get_info(&self) -> ClientInfo { ClientInfo::default() } - - fn set_peer(&mut self, peer: Peer) { - self.peer = Some(peer); - } - - fn get_peer(&self) -> Option> { - self.peer.clone() - } } #[tokio::test] diff --git a/examples/servers/src/axum.rs b/examples/servers/src/axum.rs index e1146265..f70320e2 100644 --- a/examples/servers/src/axum.rs +++ b/examples/servers/src/axum.rs @@ -21,7 +21,7 @@ async fn main() -> anyhow::Result<()> { let ct = SseServer::serve(BIND_ADDRESS.parse()?) .await? - .with_service(Counter::new); + .with_service_directly(Counter::new); tokio::signal::ctrl_c().await?; ct.cancel();