diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index bf0e0b33..3da42bca 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -80,6 +80,9 @@ impl<'service, S> ToolCallContext<'service, S> { pub fn name(&self) -> &str { &self.name } + pub fn request_context(&self) -> &RequestContext { + &self.request_context + } } pub trait FromToolCallContextPart<'a, S>: Sized { diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index 7dce63c1..18790c1a 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration}; use axum::{ Json, Router, extract::{Query, State}, - http::{StatusCode, request::Parts}, + http::{self, StatusCode, request::Parts}, response::{ Response, sse::{Event, KeepAlive, Sse}, @@ -74,7 +74,8 @@ async fn post_event_handler( .ok_or(StatusCode::NOT_FOUND)? .clone() }; - message.insert_extension(parts); + let headers_to_insert: http::HeaderMap = parts.headers.clone(); + message.insert_extension(headers_to_insert); if tx.send(message).await.is_err() { tracing::error!("send message error"); return Err(StatusCode::GONE); diff --git a/crates/rmcp/src/transport/streamable_http_server/axum.rs b/crates/rmcp/src/transport/streamable_http_server/axum.rs index 385df578..0f7bed73 100644 --- a/crates/rmcp/src/transport/streamable_http_server/axum.rs +++ b/crates/rmcp/src/transport/streamable_http_server/axum.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration}; use axum::{ Json, Router, extract::State, - http::{HeaderMap, HeaderValue, StatusCode, request::Parts}, + http::{self, HeaderMap, HeaderValue, StatusCode, request::Parts}, response::{ IntoResponse, Response, sse::{Event, KeepAlive, Sse}, @@ -84,8 +84,8 @@ async fn post_handler( .ok_or((StatusCode::NOT_FOUND, "session not found").into_response())?; session.handle().clone() }; - // inject request part - message.insert_extension(parts); + let headers_to_insert: http::HeaderMap = parts.headers.clone(); + message.insert_extension(headers_to_insert); match &message { ClientJsonRpcMessage::Request(_) | ClientJsonRpcMessage::BatchRequest(_) => { let receiver = handle.establish_request_wise_channel().await.map_err(|e| { diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index 12aa8a4a..b8c8771c 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -1,12 +1,12 @@ use std::sync::Arc; +use crate::common::extractor::ReqHeaders; use rmcp::{ Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars, service::RequestContext, tool, }; use serde_json::json; use tokio::sync::Mutex; - #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct StructRequest { pub a: i32, @@ -81,6 +81,18 @@ impl Counter { (a + b).to_string(), )])) } + #[tool(description = "Get the request headers")] + fn get_headers(&self, ReqHeaders(headers): ReqHeaders) -> Result { + let mut header_strings = Vec::new(); + for (name, value) in headers.iter() { + if let Ok(value_str) = value.to_str() { + header_strings.push(format!("{}: {}", name, value_str)); + } + } + Ok(CallToolResult::success(vec![Content::text( + header_strings.join("\n"), + )])) + } } const_string!(Echo = "echo"); #[tool(tool_box)] diff --git a/examples/servers/src/common/extractor.rs b/examples/servers/src/common/extractor.rs new file mode 100644 index 00000000..d016cf84 --- /dev/null +++ b/examples/servers/src/common/extractor.rs @@ -0,0 +1,20 @@ +use axum::http::HeaderMap; +use rmcp::Error as McpError; +use rmcp::handler::server::tool::{FromToolCallContextPart, ToolCallContext}; + +#[derive(Debug)] +pub struct ReqHeaders(pub HeaderMap); + +impl<'a, S> FromToolCallContextPart<'a, S> for ReqHeaders { + fn from_tool_call_context_part( + context: ToolCallContext<'a, S>, + ) -> Result<(Self, ToolCallContext<'a, S>), McpError> { + match context.request_context().extensions.get::() { + Some(headers) => Ok((ReqHeaders(headers.clone()), context)), + None => Err(McpError::internal_error( + "HTTP headers not found in context.", + None, + )), + } + } +} diff --git a/examples/servers/src/common/mod.rs b/examples/servers/src/common/mod.rs index 5919bccd..c84dc821 100644 --- a/examples/servers/src/common/mod.rs +++ b/examples/servers/src/common/mod.rs @@ -1,3 +1,4 @@ pub mod calculator; pub mod counter; +pub mod extractor; pub mod generic_service;