[Docs/Help/Feature Request]: Getting authentication details in tool calls #153

Open
wolf-sigma opened this issue May 3, 2025 · 8 comments

wolf-sigma commented May 3, 2025

Disclaimer

I haven't tried the recent OAuth samples (my source isn't fully set up to support them yet), but from reading them I'm fairly sure this issue applies there too. I also haven't tried implementing a custom transport.

Goal

Get authentication details (specifically the authenticated identity) visible to my tool methods.

What I've done

My source supports API keys: plain authentication, with no token exchange or the like. Using the MCP inspector, I've set up the connection to send the key as a Bearer token. I've successfully implemented an Axum middleware that handles the authentication and rejects the request with 401 Unauthorized when it fails (example below).

The problem

Because of the abstractions between Axum and the rest of the MCP framework, there doesn't appear to be a clean way to pass any details from the middleware (or anything else from the router, for that matter) to the MCP server.

My approach was to assign an Identity object to the MCP toolbox struct and assign it "at some point". But this was just my initial thought. I don't care as long as I can get it from the tool calls.

Ex:

#[derive(Clone)]
pub struct Counter {
    counter: Arc<Mutex<i32>>,
    identity: Identity,
}

#[tool(tool_box)]
impl Counter {
    #[allow(dead_code)]
    pub fn new() -> Self {
        let identity = get_identity(); // <- unclear how!
        Self {
            counter: Arc::new(Mutex::new(0)),
            // Either set it here or make it optional and assign it after construction
            identity,
        }
    }

    #[tool(description = "Increment the counter by 1")]
    async fn increment(&self) -> Result<CallToolResult, McpError> {
        let mut counter = self.counter.lock().await;
        *counter += 1;

        // Pass a clone: `self` is only borrowed, so the field cannot be moved out
        do_something_with_identity(self.identity.clone());

        Ok(CallToolResult::success(vec![Content::text(
            counter.to_string(),
        )]))
    }
}

Since there's no obvious way when constructing the service, I tried assigning it during initialization. For the calls in ServerHandler that have it, the context parameter doesn't expose - publicly or privately - info about the Axum Router. I tried using the Extensions from the SDK (those available in context for some of the calls) but I couldn't find a way to set them from the Axum "side".

The question

Am I missing something obvious?

If not, I think this would be a critical feature. I'm happy to implement it if there's consensus on how (or if not, I can propose a solution).

Example code that sets Axum extensions

I can provide a full example if it helps.

pub async fn auth(mut req: Request, next: Next) -> Result<Response, StatusCode> {
    // This middleware works as expected

    let auth_header = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?;

    // Actual authentication happens here
    if let Some(identity) = authorize_current_user(auth_header).await {
        // Set an Axum extension for the request
        req.extensions_mut().insert(identity);

        Ok(next.run(req).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}

async fn server_entrypoint() -> Result<()> {
    let mcp_config = MCPServerConfig::default();
    tracing::info!("Config for MCP Server: {:?}", mcp_config);

    let bind_address = format!("{}:{}", mcp_config.mcp_bind_address, mcp_config.mcp_bind_port);

    let config = SseServerConfig {
        bind: bind_address.parse()?,
        sse_path: "/sse".to_string(),
        post_path: "/message".to_string(),
        ct: tokio_util::sync::CancellationToken::new(),
        sse_keep_alive: None
    };
    
    let (sse_server, router) = SseServer::new(config);
    
    let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?;

    let ct = sse_server.config.ct.child_token();
    
    // Add auth middleware
    let router = router
        .route_layer(middleware::from_fn(auth));
    
    let server = axum::serve(listener, router).with_graceful_shutdown(async move {
        ct.cancelled().await;
        tracing::info!("SSE server cancelled");
    });
    
    tokio::spawn(async move {
        if let Err(e) = server.await {
            tracing::error!(error = %e, "sse server shutdown with error");
        }
    });
    
    let ct = sse_server.with_service(service::Counter::new);
    
    graceful_shutdown().await;
    ct.cancel();

    Ok(())
}
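
For context, authorize_current_user above receives the raw Authorization header value, so it has to separate the scheme from the key before looking anything up. A minimal sketch of that step, using only the standard library; `bearer_token` is a hypothetical helper name, not something provided by axum or rmcp:

```rust
/// Extracts the credential from an `Authorization` header value of the
/// form `Bearer <token>`. Returns `None` for other schemes or an empty token.
/// Hypothetical helper for illustration; not part of axum or rmcp.
fn bearer_token(header_value: &str) -> Option<&str> {
    // Split into scheme and credential at the first space.
    let (scheme, token) = header_value.split_once(' ')?;
    // HTTP authentication scheme names are case-insensitive.
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    // Tolerate extra whitespace around the credential, but reject an empty one.
    let token = token.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
```

Inside the middleware, a `None` result would map to the same StatusCode::UNAUTHORIZED the other failure paths already return.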
@wolf-sigma
Author

Some related issues/prs:

#61 #56 #84

Here's also an example of me trying to pull the Axum Extensions in the context parameters:

#[tool(tool_box)]
impl ServerHandler for Counter {
    // Override default
    fn ping(
        &self,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = std::result::Result<(), McpError>> + Send + '_ {

        let ext = &context.meta.0;
        tracing::info!("Ping received; Extension {:?} Identity: {:?}", ext, context.extensions.get::<Identity>());

        tracing::info!("Found context {:?}", context);

        std::future::ready(Ok(()))
    }
}

The logs for that call look something like this:

2025-05-03T00:44:57.594154Z DEBUG rmcp::transport::sse_server: new client message session_id="b1217529e050da2682d7e751a4e37a3c" Request(JsonRpcRequest { jsonrpc: JsonRpcVersion2_0, id: Number(0), request: InitializeRequest(Request { method: InitializeResultMethod, params: InitializeRequestParam { protocol_version: ProtocolVersion("2024-11-05"), capabilities: ClientCapabilities { experimental: None, roots: Some(RootsCapabilities { list_changed: Some(true) }), sampling: Some({}) }, client_info: Implementation { name: "mcp-inspector", version: "0.10.2" } }, extensions: Extensions }) })
2025-05-03T00:44:57.594374Z  INFO mcp_server_bin::service::counter: Ping received; Extension {} Identity: None
2025-05-03T00:44:57.594386Z  INFO mcp_server_bin::service::counter: Found request InitializeRequestParam { protocol_version: ProtocolVersion("2024-11-05"), capabilities: ClientCapabilities { experimental: None, roots: Some(RootsCapabilities { list_changed: Some(true) }), sampling: Some({}) }, client_info: Implementation { name: "mcp-inspector", version: "0.10.2" } }
2025-05-03T00:44:57.594395Z  INFO mcp_server_bin::service::counter: Found context RequestContext { ct: CancellationToken { is_cancelled: false }, id: Number(0), meta: Meta({}), extensions: Extensions, peer: PeerSink { tx: Sender { chan: Tx { inner: Chan { tx: Tx { block_tail: 0x13102f600, tail_position: 0 }, semaphore: Semaphore { semaphore: Semaphore { permits: 1024 }, bound: 1024 }, rx_waker: AtomicWaker, tx_count: 2, rx_fields: "..." } } }, is_client: false } }

@4t145
Collaborator

4t145 commented May 3, 2025

We don't clone the extensions from the axum request into the extensions of the rmcp request. We should probably do that; I will add a patch for it.

Update: I'd like to extract the common axum part into one file after the axum streamable HTTP transport is merged.
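
To illustrate the idea of cloning extensions across the boundary, here is a rough standard-library-only sketch. Every type in it (`Extensions`, `HttpRequest`, `RequestContext`, `Identity`) is a simplified stand-in written for this example, not the real axum, http, or rmcp type, and `into_context` is a hypothetical name for the transport step that would do the copy.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;

/// Simplified stand-in for a type-keyed extensions map.
/// Values are stored behind `Arc`, so cloning the map is cheap.
#[derive(Clone, Default)]
struct Extensions {
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Stores one value per type, replacing any previous value of that type.
    fn insert<T: Any + Send + Sync>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Arc::new(value));
    }

    /// Looks up the value stored for type `T`, if any.
    fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
        let stored = self.map.get(&TypeId::of::<T>())?;
        (**stored).downcast_ref::<T>()
    }
}

/// What the auth middleware would insert (stand-in for the issue's `Identity`).
#[derive(Clone, Debug, PartialEq)]
struct Identity {
    user: String,
}

/// Stand-in for the HTTP request as the middleware sees it.
struct HttpRequest {
    extensions: Extensions,
}

/// Stand-in for the context a protocol handler receives.
struct RequestContext {
    extensions: Extensions,
}

/// Hypothetical transport step: carry the HTTP extensions over to the
/// protocol-level context so handlers can read what middleware stored.
fn into_context(req: &HttpRequest) -> RequestContext {
    RequestContext {
        extensions: req.extensions.clone(),
    }
}
```

With a step like this in the transport, a lookup such as context.extensions.get::<Identity>() in the ping override above would have something to find instead of returning None.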

4t145 self-assigned this May 3, 2025
@wolf-sigma
Author

Sounds good - thanks a ton @4t145

@4t145
Collaborator

4t145 commented May 9, 2025

#150 is merged. You can check this example https://github.com/modelcontextprotocol/rust-sdk/pull/163/files#diff-e8a7b3352088986e9a3c5b83cfea597a505d62b2881bb6f0b6dbc07ca6518e9d to see if this PR solves your problem.

@sahra-karakoc

sahra-karakoc commented May 12, 2025

@4t145 Seems that we still don't have access to the headers inside tool calls while using the tool macro, right?

@4t145
Collaborator

4t145 commented May 13, 2025

@sahra-karakoc You can get what you want by implementing FromToolCallContextPart and adding the type as an argument of the tool function, so it is extracted automatically. We also need to implement it for Extensions (I forgot that).
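
For readers unfamiliar with the extractor pattern described here, a rough sketch of the general shape follows. It deliberately does not reproduce the actual FromToolCallContextPart signature: `ToolCallContext`, `FromContext`, `call_tool`, and `whoami` are all hypothetical names made up for this example, and the header map is a plain HashMap rather than a real HTTP type.

```rust
use std::collections::HashMap;

/// Stand-in for the per-call context a tool invocation carries.
struct ToolCallContext {
    headers: HashMap<String, String>,
}

/// Hypothetical extractor trait: a type that can build itself from the context.
trait FromContext: Sized {
    fn from_context(ctx: &ToolCallContext) -> Result<Self, String>;
}

/// The caller's identity, here just the API key taken from the header.
#[derive(Debug, PartialEq)]
struct Identity(String);

impl FromContext for Identity {
    fn from_context(ctx: &ToolCallContext) -> Result<Self, String> {
        let value = ctx
            .headers
            .get("authorization")
            .ok_or_else(|| "missing authorization header".to_string())?;
        let key = value
            .strip_prefix("Bearer ")
            .ok_or_else(|| "expected a Bearer credential".to_string())?;
        Ok(Identity(key.to_string()))
    }
}

/// A tool written against the extracted type; it never touches the context.
fn whoami(identity: Identity) -> String {
    format!("caller: {}", identity.0)
}

/// Roughly what a macro-generated wrapper could do: extract the argument
/// from the context, then invoke the tool with it.
fn call_tool<A, R>(ctx: &ToolCallContext, tool: impl Fn(A) -> R) -> Result<R, String>
where
    A: FromContext,
{
    let arg = A::from_context(ctx)?;
    Ok(tool(arg))
}
```

The point of the design is that the tool function declares what it needs as an argument type, and the dispatch layer is responsible for producing it or failing the call.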

@sahra-karakoc

@4t145 lmk if this works #175

@4t145
Collaborator

4t145 commented May 19, 2025

@wolf-sigma @sahra-karakoc I exposed the extensions in #199. You can get them by extracting Extension(http::request::Parts).
