Skip to content

feat: better http server support #199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ tracing-subscriber = { version = "0.3", features = [
async-trait = "0.1"
[[test]]
name = "test_tool_macros"
required-features = ["server"]
required-features = ["server", "client"]
path = "tests/test_tool_macros.rs"

[[test]]
Expand Down
39 changes: 4 additions & 35 deletions crates/rmcp/src/handler/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
error::Error as McpError,
model::*,
service::{Peer, RequestContext, RoleClient, Service, ServiceRole},
service::{RequestContext, RoleClient, Service, ServiceRole},
};

impl<H: ClientHandler> Service<RoleClient> for H {
Expand Down Expand Up @@ -118,47 +118,16 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
std::future::ready(())
}

fn get_peer(&self) -> Option<Peer<RoleClient>>;

fn set_peer(&mut self, peer: Peer<RoleClient>);

fn get_info(&self) -> ClientInfo {
ClientInfo::default()
}
}

/// Do nothing, just store the peer.
impl ClientHandler for Option<Peer<RoleClient>> {
fn get_peer(&self) -> Option<Peer<RoleClient>> {
self.clone()
}

fn set_peer(&mut self, peer: Peer<RoleClient>) {
*self = Some(peer);
}
}

/// Do nothing, even store the peer.
impl ClientHandler for () {
fn get_peer(&self) -> Option<Peer<RoleClient>> {
None
}

fn set_peer(&mut self, peer: Peer<RoleClient>) {
drop(peer);
}
}
/// Do nothing, with default client info.
impl ClientHandler for () {}

/// Do nothing, even store the peer.
/// Do nothing, with a specific client info.
impl ClientHandler for ClientInfo {
fn get_peer(&self) -> Option<Peer<RoleClient>> {
None
}

fn set_peer(&mut self, peer: Peer<RoleClient>) {
drop(peer);
}

fn get_info(&self) -> ClientInfo {
self.clone()
}
Expand Down
14 changes: 5 additions & 9 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
error::Error as McpError,
model::*,
service::{Peer, RequestContext, RoleServer, Service, ServiceRole},
service::{RequestContext, RoleServer, Service, ServiceRole},
};

mod resource;
Expand Down Expand Up @@ -108,6 +108,10 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
request: InitializeRequestParam,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
if context.peer.peer_info().is_none() {
context.peer.set_peer_info(request);
}
let info = self.get_info();
std::future::ready(Ok(self.get_info()))
}
fn complete(
Expand Down Expand Up @@ -210,14 +214,6 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
std::future::ready(())
}

fn get_peer(&self) -> Option<Peer<RoleServer>> {
None
}

fn set_peer(&mut self, peer: Peer<RoleServer>) {
drop(peer);
}

fn get_info(&self) -> ServerInfo {
ServerInfo::default()
}
Expand Down
36 changes: 36 additions & 0 deletions crates/rmcp/src/handler/server/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ impl<'service, S> ToolCallContext<'service, S> {
pub fn name(&self) -> &str {
&self.name
}
pub fn request_context(&self) -> &RequestContext<RoleServer> {
&self.request_context
}
}

pub trait FromToolCallContextPart<'a, S>: Sized {
Expand Down Expand Up @@ -284,6 +287,39 @@ impl<'a, S> FromToolCallContextPart<'a, S> for JsonObject {
}
}

impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Extensions {
fn from_tool_call_context_part(
context: ToolCallContext<'a, S>,
) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> {
let extensions = context.request_context.extensions.clone();
Ok((extensions, context))
}
}

pub struct Extension<T>(pub T);

impl<'a, S, T> FromToolCallContextPart<'a, S> for Extension<T>
where
T: Send + Sync + 'static + Clone,
{
fn from_tool_call_context_part(
context: ToolCallContext<'a, S>,
) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> {
let extension = context
.request_context
.extensions
.get::<T>()
.cloned()
.ok_or_else(|| {
crate::Error::invalid_params(
format!("missing extension {}", std::any::type_name::<T>()),
None,
)
})?;
Ok((Extension(extension), context))
}
}

impl<'s, S> ToolCallContext<'s, S> {
pub fn invoke<H, A>(self, h: H) -> H::Fut
where
Expand Down
22 changes: 15 additions & 7 deletions crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ pub struct Peer<R: ServiceRole> {
tx: mpsc::Sender<PeerSinkMessage<R>>,
request_id_provider: Arc<dyn RequestIdProvider>,
progress_token_provider: Arc<dyn ProgressTokenProvider>,
info: Arc<R::PeerInfo>,
info: Arc<tokio::sync::OnceCell<R::PeerInfo>>,
}

impl<R: ServiceRole> std::fmt::Debug for Peer<R> {
Expand Down Expand Up @@ -333,15 +333,15 @@ impl<R: ServiceRole> Peer<R> {
const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024;
pub(crate) fn new(
request_id_provider: Arc<dyn RequestIdProvider>,
peer_info: R::PeerInfo,
peer_info: Option<R::PeerInfo>,
) -> (Peer<R>, ProxyOutbound<R>) {
let (tx, rx) = mpsc::channel(Self::CLIENT_CHANNEL_BUFFER_SIZE);
(
Self {
tx,
request_id_provider,
progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
info: peer_info.into(),
info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)),
},
rx,
)
Expand Down Expand Up @@ -402,8 +402,16 @@ impl<R: ServiceRole> Peer<R> {
peer: self.clone(),
})
}
pub fn peer_info(&self) -> &R::PeerInfo {
&self.info
pub fn peer_info(&self) -> Option<&R::PeerInfo> {
self.info.get()
}

pub fn set_peer_info(&self, info: R::PeerInfo) {
if self.info.initialized() {
tracing::warn!("trying to set peer info, which is already initialized");
} else {
let _ = self.info.set(info);
}
}

pub fn is_transport_closed(&self) -> bool {
Expand Down Expand Up @@ -469,7 +477,7 @@ pub struct RequestContext<R: ServiceRole> {
pub async fn serve_directly<R, S, T, E, A>(
service: S,
transport: T,
peer_info: R::PeerInfo,
peer_info: Option<R::PeerInfo>,
) -> RunningService<R, S>
where
R: ServiceRole,
Expand All @@ -484,7 +492,7 @@ where
pub async fn serve_directly_with_ct<R, S, T, E, A>(
service: S,
transport: T,
peer_info: R::PeerInfo,
peer_info: Option<R::PeerInfo>,
ct: CancellationToken,
) -> RunningService<R, S>
where
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/src/service/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ where
error,
context: "send initialized notification".into(),
})?;
let (peer, peer_rx) = Peer::new(id_provider, initialize_result);
let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result));
Ok(serve_inner(service, transport, peer, peer_rx, ct).await)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/src/service/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ where
ClientJsonRpcMessage::request(request, id),
)));
};
let (peer, peer_rx) = Peer::new(id_provider, peer_info.params.clone());
let (peer, peer_rx) = Peer::new(id_provider, Some(peer_info.params.clone()));
let context = RequestContext {
ct: ct.child_token(),
id: id.clone(),
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
//!
//! // create transport from std io
//! async fn io() -> Result<(), Box<dyn std::error::Error>> {
//! let client = None.serve((tokio::io::stdin(), tokio::io::stdout())).await?;
//! let client = ().serve((tokio::io::stdin(), tokio::io::stdout())).await?;
//! let tools = client.peer().list_tools(Default::default()).await?;
//! println!("{:?}", tools);
//! Ok(())
Expand Down
28 changes: 25 additions & 3 deletions crates/rmcp/src/transport/sse_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tracing::Instrument;
use crate::{
RoleServer, Service,
model::ClientJsonRpcMessage,
service::{RxJsonRpcMessage, TxJsonRpcMessage},
service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct},
transport::common::axum::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id},
};

Expand Down Expand Up @@ -67,7 +67,7 @@ async fn post_event_handler(
parts: Parts,
Json(mut message): Json<ClientJsonRpcMessage>,
) -> Result<StatusCode, StatusCode> {
tracing::debug!(session_id, ?message, "new client message");
tracing::debug!(session_id, ?parts, ?message, "new client message");
let tx = {
let rg = app.txs.read().await;
rg.get(session_id.as_str())
Expand All @@ -84,9 +84,10 @@ async fn post_event_handler(

async fn sse_handler(
State(app): State<App>,
parts: Parts,
) -> Result<Sse<impl Stream<Item = Result<Event, io::Error>>>, Response<String>> {
let session = session_id();
tracing::info!(%session, "sse connection");
tracing::info!(%session, ?parts, "sse connection");
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
use tokio_util::sync::PollSender;
let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64);
Expand Down Expand Up @@ -300,6 +301,27 @@ impl SseServer {
ct
}

/// This allows you to skip the initialization steps for incoming request.
pub fn with_service_directly<S, F>(mut self, service_provider: F) -> CancellationToken
where
S: Service<RoleServer>,
F: Fn() -> S + Send + 'static,
{
let ct = self.config.ct.clone();
tokio::spawn(async move {
while let Some(transport) = self.next_transport().await {
let service = service_provider();
let ct = self.config.ct.child_token();
tokio::spawn(async move {
let server = serve_directly_with_ct(service, transport, None, ct).await;
server.waiting().await?;
tokio::io::Result::Ok(())
});
}
});
ct
}

pub fn cancel(&self) {
self.config.ct.cancel();
}
Expand Down
16 changes: 2 additions & 14 deletions crates/rmcp/tests/common/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@ use std::{
};

use rmcp::{
ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler,
model::*,
service::{Peer, RequestContext},
ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler, model::*,
service::RequestContext,
};
use serde_json::json;
use tokio::sync::Notify;

#[derive(Clone)]
pub struct TestClientHandler {
pub peer: Option<Peer<RoleClient>>,
pub honor_this_server: bool,
pub honor_all_servers: bool,
pub receive_signal: Arc<Notify>,
Expand All @@ -24,7 +22,6 @@ impl TestClientHandler {
#[allow(dead_code)]
pub fn new(honor_this_server: bool, honor_all_servers: bool) -> Self {
Self {
peer: None,
honor_this_server,
honor_all_servers,
receive_signal: Arc::new(Notify::new()),
Expand All @@ -40,7 +37,6 @@ impl TestClientHandler {
received_messages: Arc<Mutex<Vec<LoggingMessageNotificationParam>>>,
) -> Self {
Self {
peer: None,
honor_this_server,
honor_all_servers,
receive_signal,
Expand All @@ -50,14 +46,6 @@ impl TestClientHandler {
}

impl ClientHandler for TestClientHandler {
fn get_peer(&self) -> Option<Peer<RoleClient>> {
self.peer.clone()
}

fn set_peer(&mut self, peer: Peer<RoleClient>) {
self.peer = Some(peer);
}

async fn create_message(
&self,
params: CreateMessageRequestParam,
Expand Down
12 changes: 1 addition & 11 deletions crates/rmcp/tests/test_notification.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use rmcp::{
ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt,
ClientHandler, ServerHandler, ServiceExt,
model::{
ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam,
},
Expand Down Expand Up @@ -49,7 +49,6 @@ impl ServerHandler for Server {

pub struct Client {
receive_signal: Arc<Notify>,
peer: Option<Peer<RoleClient>>,
}

impl ClientHandler for Client {
Expand All @@ -58,14 +57,6 @@ impl ClientHandler for Client {
tracing::info!("Resource updated: {}", uri);
self.receive_signal.notify_one();
}

fn set_peer(&mut self, peer: Peer<RoleClient>) {
self.peer.replace(peer);
}

fn get_peer(&self) -> Option<Peer<RoleClient>> {
self.peer.clone()
}
}

#[tokio::test]
Expand All @@ -85,7 +76,6 @@ async fn test_server_notification() -> anyhow::Result<()> {
});
let receive_signal = Arc::new(Notify::new());
let client = Client {
peer: Default::default(),
receive_signal: receive_signal.clone(),
}
.serve(client_transport)
Expand Down
Loading
Loading