diff --git a/README.md b/README.md index d76d3d267..316000a52 100644 --- a/README.md +++ b/README.md @@ -814,7 +814,11 @@ async def main(): The SDK includes [authorization support](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization) for connecting to protected MCP servers: ```python -from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.auth import ( + OAuthClientProvider, + TokenExchangeProvider, + TokenStorage, +) from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -851,6 +855,24 @@ async def main(): callback_handler=lambda: ("auth_code", None), ) + # For machine-to-machine scenarios, use ClientCredentialsProvider + # instead of OAuthClientProvider. + + # If you already have a user token from another provider, you can + # exchange it for an MCP token using the token_exchange grant + # implemented by TokenExchangeProvider. + token_exchange_auth = TokenExchangeProvider( + server_url="https://api.example.com", + client_metadata=OAuthClientMetadata( + client_name="My Client", + redirect_uris=["http://localhost:3000/callback"], + grant_types=["client_credentials", "token_exchange"], + response_types=["code"], + ), + storage=CustomTokenStorage(), + subject_token_supplier=lambda: "user_token", + ) + # Use with streamable HTTP client async with streamablehttp_client( "https://api.example.com/mcp", auth=oauth_auth diff --git a/docs/api.md b/docs/api.md index 3f696af54..3291f5c01 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1 +1,5 @@ +The Python SDK exposes the entire `mcp` package for use in your own projects. +It includes an OAuth server implementation with support for the RFC 8693 +`token_exchange` grant type. + ::: mcp diff --git a/docs/index.md b/docs/index.md index 42ad9ca0c..dc0ffea32 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,3 +3,7 @@ This is the MCP Server implementation in Python. It only contains the [API Reference](api.md) for the time being. + +The built-in OAuth server supports the RFC 8693 `token_exchange` grant type, +allowing clients to exchange user tokens from external providers for MCP +access tokens. diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 6e16f8b9d..fd5ffdd24 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -238,6 +238,36 @@ async def exchange_refresh_token( """Exchange refresh token""" raise NotImplementedError("Not supported") + async def exchange_token( + self, + client: OAuthClientInformationFull, + subject_token: str, + subject_token_type: str, + actor_token: str | None, + actor_token_type: str | None, + scope: list[str] | None, + audience: str | None, + resource: str | None, + ) -> OAuthToken: + """Exchange an external token for an MCP access token.""" + raise NotImplementedError("Token exchange is not supported") + + async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken: + """Exchange client credentials for an access token.""" + token = f"mcp_{secrets.token_hex(32)}" + self.tokens[token] = AccessToken( + token=token, + client_id=client.client_id, + scopes=scopes, + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scopes), + ) + async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: """Revoke a token.""" if token in self.tokens: diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 4e777d600..b3a9e6bb0 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -17,7 +17,6 @@ import anyio import httpx -from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, @@ -49,6 +48,54 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None ... +def _get_authorization_base_url(server_url: str) -> str: + """ + Return the authorization base URL for ``server_url``. + + Per MCP spec 2.3.2, the path component must be discarded so that + ``https://api.example.com/v1/mcp`` becomes ``https://api.example.com``. + """ + from urllib.parse import urlparse, urlunparse + + parsed = urlparse(server_url) + # Remove path component + return urlunparse((parsed.scheme, parsed.netloc, "", "", "", "")) + + +async def _discover_oauth_metadata(server_url: str) -> OAuthMetadata | None: + """ + Discover OAuth metadata from the server's well-known endpoint. + """ + + # Extract base URL per MCP spec + auth_base_url = _get_authorization_base_url(server_url) + url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server") + headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION} + + async with httpx.AsyncClient() as client: + try: + response = await client.get(url, headers=headers) + if response.status_code == 404: + return None + response.raise_for_status() + metadata_json = response.json() + logger.debug(f"OAuth metadata discovered: {metadata_json}") + return OAuthMetadata.model_validate(metadata_json) + except Exception: + # Retry without MCP header for CORS compatibility + try: + response = await client.get(url) + if response.status_code == 404: + return None + response.raise_for_status() + metadata_json = response.json() + logger.debug(f"OAuth metadata discovered (no MCP header): {metadata_json}") + return OAuthMetadata.model_validate(metadata_json) + except Exception: + logger.exception("Failed to discover OAuth metadata") + return None + + class OAuthClientProvider(httpx.Auth): """ Authentication for httpx using anyio. @@ -108,50 +155,6 @@ def _generate_code_challenge(self, code_verifier: str) -> str: digest = hashlib.sha256(code_verifier.encode()).digest() return base64.urlsafe_b64encode(digest).decode().rstrip("=") - def _get_authorization_base_url(self, server_url: str) -> str: - """ - Extract base URL by removing path component. - - Per MCP spec 2.3.2: https://api.example.com/v1/mcp -> https://api.example.com - """ - from urllib.parse import urlparse, urlunparse - - parsed = urlparse(server_url) - # Remove path component - return urlunparse((parsed.scheme, parsed.netloc, "", "", "", "")) - - async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None: - """ - Discover OAuth metadata from server's well-known endpoint. - """ - # Extract base URL per MCP spec - auth_base_url = self._get_authorization_base_url(server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server") - headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} - - async with httpx.AsyncClient() as client: - try: - response = await client.get(url, headers=headers) - if response.status_code == 404: - return None - response.raise_for_status() - metadata_json = response.json() - logger.debug(f"OAuth metadata discovered: {metadata_json}") - return OAuthMetadata.model_validate(metadata_json) - except Exception: - # Retry without MCP header for CORS compatibility - try: - response = await client.get(url) - if response.status_code == 404: - return None - response.raise_for_status() - metadata_json = response.json() - logger.debug(f"OAuth metadata discovered (no MCP header): {metadata_json}") - return OAuthMetadata.model_validate(metadata_json) - except Exception: - logger.exception("Failed to discover OAuth metadata") - return None - async def _register_oauth_client( self, server_url: str, @@ -162,13 +165,13 @@ async def _register_oauth_client( Register OAuth client with server. """ if not metadata: - metadata = await self._discover_oauth_metadata(server_url) + metadata = await _discover_oauth_metadata(server_url) if metadata and metadata.registration_endpoint: registration_url = str(metadata.registration_endpoint) else: # Use fallback registration endpoint - auth_base_url = self._get_authorization_base_url(server_url) + auth_base_url = _get_authorization_base_url(server_url) registration_url = urljoin(auth_base_url, "/register") # Handle default scope @@ -303,7 +306,7 @@ async def _perform_oauth_flow(self) -> None: # Discover OAuth metadata if not self._metadata: - self._metadata = await self._discover_oauth_metadata(self.server_url) + self._metadata = await _discover_oauth_metadata(self.server_url) # Ensure client registration client_info = await self._get_or_register_client() @@ -317,7 +320,7 @@ async def _perform_oauth_flow(self) -> None: auth_url_base = str(self._metadata.authorization_endpoint) else: # Use fallback authorization endpoint - auth_base_url = self._get_authorization_base_url(self.server_url) + auth_base_url = _get_authorization_base_url(self.server_url) auth_url_base = urljoin(auth_base_url, "/authorize") # Build authorization URL @@ -362,7 +365,7 @@ async def _exchange_code_for_token(self, auth_code: str, client_info: OAuthClien token_url = str(self._metadata.token_endpoint) else: # Use fallback token endpoint - auth_base_url = self._get_authorization_base_url(self.server_url) + auth_base_url = _get_authorization_base_url(self.server_url) token_url = urljoin(auth_base_url, "/token") token_data = { @@ -422,7 +425,7 @@ async def _refresh_access_token(self) -> bool: token_url = str(self._metadata.token_endpoint) else: # Use fallback token endpoint - auth_base_url = self._get_authorization_base_url(self.server_url) + auth_base_url = _get_authorization_base_url(self.server_url) token_url = urljoin(auth_base_url, "/token") refresh_data = { @@ -468,3 +471,245 @@ async def _refresh_access_token(self) -> bool: except Exception: logger.exception("Token refresh failed") return False + + +class ClientCredentialsProvider(httpx.Auth): + """HTTPX auth using the OAuth2 client credentials grant.""" + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + timeout: float = 300.0, + ): + self.server_url = server_url + self.client_metadata = client_metadata + self.storage = storage + self.timeout = timeout + + self._current_tokens: OAuthToken | None = None + self._metadata: OAuthMetadata | None = None + self._client_info: OAuthClientInformationFull | None = None + self._token_expiry_time: float | None = None + + self._token_lock = anyio.Lock() + + async def _register_oauth_client( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + metadata: OAuthMetadata | None = None, + ) -> OAuthClientInformationFull: + if not metadata: + metadata = await _discover_oauth_metadata(server_url) + + if metadata and metadata.registration_endpoint: + registration_url = str(metadata.registration_endpoint) + else: + auth_base_url = _get_authorization_base_url(server_url) + registration_url = urljoin(auth_base_url, "/register") + + if client_metadata.scope is None and metadata and metadata.scopes_supported is not None: + client_metadata.scope = " ".join(metadata.scopes_supported) + + registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) + + async with httpx.AsyncClient() as client: + response = await client.post( + registration_url, + json=registration_data, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code not in (200, 201): + raise httpx.HTTPStatusError( + f"Registration failed: {response.status_code}", + request=response.request, + response=response, + ) + + return OAuthClientInformationFull.model_validate(response.json()) + + def _has_valid_token(self) -> bool: + if not self._current_tokens or not self._current_tokens.access_token: + return False + + if self._token_expiry_time and time.time() > self._token_expiry_time: + return False + return True + + async def _validate_token_scopes(self, token_response: OAuthToken) -> None: + if not token_response.scope: + return + + requested_scopes: set[str] = set() + if self.client_metadata.scope: + requested_scopes = set(self.client_metadata.scope.split()) + returned_scopes = set(token_response.scope.split()) + unauthorized_scopes = returned_scopes - requested_scopes + if unauthorized_scopes: + raise Exception(f"Server granted unauthorized scopes: {unauthorized_scopes}.") + else: + granted = set(token_response.scope.split()) + logger.debug( + "No explicit scopes requested, accepting server-granted scopes: %s", + granted, + ) + + async def initialize(self) -> None: + self._current_tokens = await self.storage.get_tokens() + self._client_info = await self.storage.get_client_info() + + async def _get_or_register_client(self) -> OAuthClientInformationFull: + if not self._client_info: + self._client_info = await self._register_oauth_client(self.server_url, self.client_metadata, self._metadata) + await self.storage.set_client_info(self._client_info) + return self._client_info + + async def _request_token(self) -> None: + if not self._metadata: + self._metadata = await _discover_oauth_metadata(self.server_url) + + client_info = await self._get_or_register_client() + + if self._metadata and self._metadata.token_endpoint: + token_url = str(self._metadata.token_endpoint) + else: + auth_base_url = _get_authorization_base_url(self.server_url) + token_url = urljoin(auth_base_url, "/token") + + token_data = { + "grant_type": "client_credentials", + "client_id": client_info.client_id, + } + + if client_info.client_secret: + token_data["client_secret"] = client_info.client_secret + + if self.client_metadata.scope: + token_data["scope"] = self.client_metadata.scope + + async with httpx.AsyncClient() as client: + response = await client.post( + token_url, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if response.status_code != 200: + raise Exception(f"Token request failed: {response.status_code} {response.text}") + + token_response = OAuthToken.model_validate(response.json()) + await self._validate_token_scopes(token_response) + + if token_response.expires_in: + self._token_expiry_time = time.time() + token_response.expires_in + else: + self._token_expiry_time = None + + await self.storage.set_tokens(token_response) + self._current_tokens = token_response + + async def ensure_token(self) -> None: + async with self._token_lock: + if self._has_valid_token(): + return + await self._request_token() + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + if not self._has_valid_token(): + await self.initialize() + await self.ensure_token() + + if self._current_tokens and self._current_tokens.access_token: + request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}" + + response = yield request + + if response.status_code == 401: + self._current_tokens = None + + +class TokenExchangeProvider(ClientCredentialsProvider): + """OAuth2 token exchange based on RFC 8693.""" + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + subject_token_supplier: Callable[[], Awaitable[str]], + subject_token_type: str = "access_token", + actor_token_supplier: Callable[[], Awaitable[str]] | None = None, + actor_token_type: str | None = None, + audience: str | None = None, + resource: str | None = None, + timeout: float = 300.0, + ): + super().__init__(server_url, client_metadata, storage, timeout) + self.subject_token_supplier = subject_token_supplier + self.subject_token_type = subject_token_type + self.actor_token_supplier = actor_token_supplier + self.actor_token_type = actor_token_type + self.audience = audience + self.resource = resource + + async def _request_token(self) -> None: + if not self._metadata: + self._metadata = await _discover_oauth_metadata(self.server_url) + + client_info = await self._get_or_register_client() + + if self._metadata and self._metadata.token_endpoint: + token_url = str(self._metadata.token_endpoint) + else: + auth_base_url = _get_authorization_base_url(self.server_url) + token_url = urljoin(auth_base_url, "/token") + + subject_token = await self.subject_token_supplier() + actor_token = await self.actor_token_supplier() if self.actor_token_supplier else None + + token_data = { + "grant_type": "token_exchange", + "client_id": client_info.client_id, + "subject_token": subject_token, + "subject_token_type": self.subject_token_type, + } + + if client_info.client_secret: + token_data["client_secret"] = client_info.client_secret + + if actor_token: + token_data["actor_token"] = actor_token + if self.actor_token_type: + token_data["actor_token_type"] = self.actor_token_type + if self.audience: + token_data["audience"] = self.audience + if self.resource: + token_data["resource"] = self.resource + if self.client_metadata.scope: + token_data["scope"] = self.client_metadata.scope + + async with httpx.AsyncClient() as client: + response = await client.post( + token_url, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if response.status_code != 200: + raise Exception(f"Token request failed: {response.status_code} {response.text}") + + token_response = OAuthToken.model_validate(response.json()) + await self._validate_token_scopes(token_response) + + if token_response.expires_in: + self._token_expiry_time = time.time() + token_response.expires_in + else: + self._token_expiry_time = None + + await self.storage.set_tokens(token_response) + self._current_tokens = token_response diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 39ac34d8a..4d27d2931 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -173,9 +173,11 @@ async def _handle_sse_event( session_message = SessionMessage(message) await read_stream_writer.send(session_message) - # Call resumption token callback if we have an ID - if sse.id and resumption_callback: - await resumption_callback(sse.id) + # Call resumption token callback if we have an ID. Only update + # the resumption token on notifications to avoid overwriting it + # with the token from the final response. + if sse.id and resumption_callback and not isinstance(message.root, JSONRPCResponse | JSONRPCError): + await resumption_callback(sse.id.strip()) # If this is a response or error return True indicating completion # Otherwise, return False to continue listening @@ -221,7 +223,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" headers = self._prepare_request_headers(ctx.headers) if ctx.metadata and ctx.metadata.resumption_token: - headers[LAST_EVENT_ID] = ctx.metadata.resumption_token + headers[LAST_EVENT_ID] = ctx.metadata.resumption_token.strip() else: raise ResumptionError("Resumption request requires a resumption token") diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 61e403aca..b211e238f 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -68,11 +68,22 @@ async def handle(self, request: Request) -> Response: ), status_code=400, ) - if set(client_metadata.grant_types) != {"authorization_code", "refresh_token"}: + grant_types_set: set[str] = set(client_metadata.grant_types) + valid_sets = [ + {"authorization_code", "refresh_token"}, + {"client_credentials"}, + {"token_exchange"}, + {"client_credentials", "token_exchange"}, + ] + + if grant_types_set not in valid_sets: return PydanticJSONResponse( content=RegistrationErrorResponse( error="invalid_client_metadata", - error_description="grant_types must be authorization_code " "and refresh_token", + error_description=( + "grant_types must be authorization_code and refresh_token " + "or client_credentials or token exchange or client_credentials and token_exchange" + ), ), status_code=400, ) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index d73455200..3ade11452 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -36,16 +36,39 @@ class RefreshTokenRequest(BaseModel): client_secret: str | None = None +class ClientCredentialsRequest(BaseModel): + # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 + grant_type: Literal["client_credentials"] + scope: str | None = Field(None, description="Optional scope parameter") + client_id: str + client_secret: str | None = None + + +class TokenExchangeRequest(BaseModel): + """RFC 8693 token exchange request.""" + + grant_type: Literal["token_exchange"] + subject_token: str = Field(..., description="Token to exchange") + subject_token_type: str = Field(..., description="Type of the subject token") + actor_token: str | None = Field(None, description="Optional actor token") + actor_token_type: str | None = Field(None, description="Type of the actor token if provided") + resource: str | None = None + audience: str | None = None + scope: str | None = None + client_id: str + client_secret: str | None = None + + class TokenRequest( RootModel[ Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, + AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest | TokenExchangeRequest, Field(discriminator="grant_type"), ] ] ): root: Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, + AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest | TokenExchangeRequest, Field(discriminator="grant_type"), ] @@ -183,10 +206,49 @@ async def handle(self, request: Request): ) ) + case ClientCredentialsRequest(): + scopes = ( + token_request.scope.split(" ") + if token_request.scope + else client_info.scope.split(" ") + if client_info.scope + else [] + ) + try: + tokens = await self.provider.exchange_client_credentials(client_info, scopes) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) + + case TokenExchangeRequest(): + scopes = token_request.scope.split(" ") if token_request.scope else [] + try: + tokens = await self.provider.exchange_token( + client_info, + token_request.subject_token, + token_request.subject_token_type, + token_request.actor_token, + token_request.actor_token_type, + scopes, + token_request.audience, + token_request.resource, + ) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) + case RefreshTokenRequest(): refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) if refresh_token is None or refresh_token.client_id != token_request.client_id: - # if token belongs to different client, pretend it doesn't exist + # if token belongs to a different client, pretend it doesn't exist return self.response( TokenErrorResponse( error="invalid_grant", diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index da18d7a71..eb824b6a7 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -77,6 +77,7 @@ class AuthorizeError(Exception): "unauthorized_client", "unsupported_grant_type", "invalid_scope", + "invalid_target", ] @@ -238,6 +239,24 @@ async def exchange_refresh_token( """ ... + async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken: + """Exchange client credentials for an access token.""" + ... + + async def exchange_token( + self, + client: OAuthClientInformationFull, + subject_token: str, + subject_token_type: str, + actor_token: str | None, + actor_token_type: str | None, + scope: list[str] | None, + audience: str | None, + resource: str | None, + ) -> OAuthToken: + """Exchange an external token for an MCP access token.""" + ... + async def load_access_token(self, token: str) -> AccessTokenT | None: """ Loads an access token by its token. diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 8647334e0..58a5d2093 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -159,7 +159,12 @@ def build_metadata( scopes_supported=client_registration_options.valid_scopes, response_types_supported=["code"], response_modes_supported=None, - grant_types_supported=["authorization_code", "refresh_token"], + grant_types_supported=[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ], token_endpoint_auth_methods_supported=["client_secret_post"], token_endpoint_auth_signing_alg_values_supported=None, service_documentation=service_documentation_url, diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 4d2d57221..fb862f248 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -13,6 +13,7 @@ class OAuthToken(BaseModel): expires_in: int | None = None scope: str | None = None refresh_token: str | None = None + issued_token_type: str | None = None @field_validator("token_type", mode="before") @classmethod @@ -46,8 +47,15 @@ class OAuthClientMetadata(BaseModel): # client_secret_post; # ie: we do not support client_secret_basic token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" - # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: list[Literal["authorization_code", "refresh_token"]] = [ + # grant_types: this implementation supports authorization_code, refresh_token, client_credentials, & token_exchange + grant_types: list[ + Literal[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ] + ] = [ "authorization_code", "refresh_token", ] @@ -115,8 +123,18 @@ class OAuthMetadata(BaseModel): scopes_supported: list[str] | None = None response_types_supported: list[str] = ["code"] response_modes_supported: list[Literal["query", "fragment"]] | None = None - grant_types_supported: list[str] | None = None - token_endpoint_auth_methods_supported: list[str] | None = None + grant_types_supported: ( + list[ + Literal[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ] + ] + | None + ) = None + token_endpoint_auth_methods_supported: list[Literal["none", "client_secret_post"]] | None = None token_endpoint_auth_signing_alg_values_supported: None = None service_documentation: AnyHttpUrl | None = None ui_locales_supported: list[str] | None = None diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 294986acb..9eba940ad 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -312,7 +312,10 @@ async def send_notification( message=JSONRPCMessage(jsonrpc_notification), metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) - await self._write_stream.send(session_message) + try: + await self._write_stream.send(session_message) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logging.debug("Discarding notification due to closed stream") async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): @@ -368,7 +371,9 @@ async def _receive_loop(self) -> None: data="", ), ) + session_message = SessionMessage(message=JSONRPCMessage(error_response)) + await self._write_stream.send(session_message) elif isinstance(message.message.root, JSONRPCNotification): @@ -398,16 +403,14 @@ async def _receive_loop(self) -> None: await self._handle_incoming(notification) except Exception as e: # For other validation errors, log and continue - logging.warning( - f"Failed to validate notification: {e}. " f"Message was: {message.message.root}" - ) + logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}") else: # Response or error stream = self._response_streams.pop(message.message.root.id, None) if stream: await stream.send(message.message.root) else: await self._handle_incoming( - RuntimeError("Received response with an unknown " f"request ID: {message}") + RuntimeError(f"Received response with an unknown request ID: {message}") ) # after the read stream is closed, we need to send errors diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index de4eb70af..f19183399 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2,6 +2,7 @@ Tests for OAuth client authentication implementation. """ +import asyncio import base64 import hashlib import time @@ -10,10 +11,15 @@ import httpx import pytest -from inline_snapshot import snapshot from pydantic import AnyHttpUrl, AnyUrl -from mcp.client.auth import OAuthClientProvider +from mcp.client.auth import ( + ClientCredentialsProvider, + OAuthClientProvider, + TokenExchangeProvider, + _discover_oauth_metadata, + _get_authorization_base_url, +) from mcp.server.auth.routes import build_metadata from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions from mcp.shared.auth import ( @@ -60,6 +66,18 @@ def client_metadata(): ) +@pytest.fixture +def client_credentials_metadata(): + return OAuthClientMetadata( + redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], + client_name="CC Client", + grant_types=["client_credentials"], + response_types=["code"], + scope="read write", + token_endpoint_auth_method="client_secret_post", + ) + + @pytest.fixture def oauth_metadata(): return OAuthMetadata( @@ -69,7 +87,12 @@ def oauth_metadata(): registration_endpoint=AnyHttpUrl("https://auth.example.com/register"), scopes_supported=["read", "write", "admin"], response_types_supported=["code"], - grant_types_supported=["authorization_code", "refresh_token"], + grant_types_supported=[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ], code_challenge_methods_supported=["S256"], ) @@ -115,6 +138,25 @@ async def mock_callback_handler() -> tuple[str, str | None]: ) +@pytest.fixture +async def client_credentials_provider(client_credentials_metadata, mock_storage): + return ClientCredentialsProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_credentials_metadata, + storage=mock_storage, + ) + + +@pytest.fixture +async def token_exchange_provider(client_credentials_metadata, mock_storage): + return TokenExchangeProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_credentials_metadata, + storage=mock_storage, + subject_token_supplier=lambda: asyncio.sleep(0, result="user_token"), + ) + + class TestOAuthClientProvider: """Test OAuth client provider functionality.""" @@ -126,7 +168,8 @@ async def test_init(self, oauth_provider, client_metadata, mock_storage): assert oauth_provider.storage == mock_storage assert oauth_provider.timeout == 300.0 - def test_generate_code_verifier(self, oauth_provider): + @pytest.mark.anyio + async def test_generate_code_verifier(self, oauth_provider): """Test PKCE code verifier generation.""" verifier = oauth_provider._generate_code_verifier() @@ -162,16 +205,13 @@ async def test_generate_code_challenge(self, oauth_provider): async def test_get_authorization_base_url(self, oauth_provider): """Test authorization base URL extraction.""" # Test with path - assert oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com" + assert _get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com" # Test with no path - assert oauth_provider._get_authorization_base_url("https://api.example.com") == "https://api.example.com" + assert _get_authorization_base_url("https://api.example.com") == "https://api.example.com" # Test with port - assert ( - oauth_provider._get_authorization_base_url("https://api.example.com:8080/path/to/mcp") - == "https://api.example.com:8080" - ) + assert _get_authorization_base_url("https://api.example.com:8080/path/to/mcp") == "https://api.example.com:8080" @pytest.mark.anyio async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata): @@ -187,7 +227,7 @@ async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metad mock_response.json.return_value = metadata_response mock_client.get.return_value = mock_response - result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") + result = await _discover_oauth_metadata("https://api.example.com/v1/mcp") assert result is not None assert result.authorization_endpoint == oauth_metadata.authorization_endpoint @@ -209,7 +249,7 @@ async def test_discover_oauth_metadata_not_found(self, oauth_provider): mock_response.status_code = 404 mock_client.get.return_value = mock_response - result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") + result = await _discover_oauth_metadata("https://api.example.com/v1/mcp") assert result is None @@ -232,7 +272,7 @@ async def test_discover_oauth_metadata_cors_fallback(self, oauth_provider, oauth mock_response_success, # Second call succeeds ] - result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") + result = await _discover_oauth_metadata("https://api.example.com/v1/mcp") assert result is not None assert mock_client.get.call_count == 2 @@ -280,7 +320,7 @@ async def test_register_oauth_client_fallback_endpoint(self, oauth_provider, oau mock_client.post.return_value = mock_response # Mock metadata discovery to return None (fallback) - with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None): + with patch("mcp.client.auth._discover_oauth_metadata", return_value=None): result = await oauth_provider._register_oauth_client( "https://api.example.com/v1/mcp", oauth_provider.client_metadata, @@ -307,7 +347,7 @@ async def test_register_oauth_client_failure(self, oauth_provider): mock_client.post.return_value = mock_response # Mock metadata discovery to return None (fallback) - with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None): + with patch("mcp.client.auth._discover_oauth_metadata", return_value=None): with pytest.raises(httpx.HTTPStatusError): await oauth_provider._register_oauth_client( "https://api.example.com/v1/mcp", @@ -880,18 +920,100 @@ def test_build_metadata( revocation_options=RevocationOptions(enabled=True), ) - assert metadata == snapshot( - OAuthMetadata( - issuer=AnyHttpUrl(issuer_url), - authorization_endpoint=AnyHttpUrl(authorization_endpoint), - token_endpoint=AnyHttpUrl(token_endpoint), - registration_endpoint=AnyHttpUrl(registration_endpoint), - scopes_supported=["read", "write", "admin"], - grant_types_supported=["authorization_code", "refresh_token"], - token_endpoint_auth_methods_supported=["client_secret_post"], - service_documentation=AnyHttpUrl(service_documentation_url), - revocation_endpoint=AnyHttpUrl(revocation_endpoint), - revocation_endpoint_auth_methods_supported=["client_secret_post"], - code_challenge_methods_supported=["S256"], - ) + expected = OAuthMetadata( + issuer=AnyHttpUrl(issuer_url), + authorization_endpoint=AnyHttpUrl(authorization_endpoint), + token_endpoint=AnyHttpUrl(token_endpoint), + registration_endpoint=AnyHttpUrl(registration_endpoint), + scopes_supported=["read", "write", "admin"], + grant_types_supported=[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ], + token_endpoint_auth_methods_supported=["client_secret_post"], + service_documentation=AnyHttpUrl(service_documentation_url), + revocation_endpoint=AnyHttpUrl(revocation_endpoint), + revocation_endpoint_auth_methods_supported=["client_secret_post"], + code_challenge_methods_supported=["S256"], ) + + assert metadata == expected + + +class TestClientCredentialsProvider: + @pytest.mark.anyio + async def test_request_token_success( + self, + client_credentials_provider, + oauth_metadata, + oauth_client_info, + oauth_token, + ): + client_credentials_provider._metadata = oauth_metadata + client_credentials_provider._client_info = oauth_client_info + + token_json = oauth_token.model_dump(by_alias=True, mode="json") + token_json.pop("refresh_token", None) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = token_json + mock_client.post.return_value = mock_response + + await client_credentials_provider.ensure_token() + + mock_client.post.assert_called_once() + assert client_credentials_provider._current_tokens.access_token == oauth_token.access_token + + @pytest.mark.anyio + async def test_async_auth_flow(self, client_credentials_provider, oauth_token): + client_credentials_provider._current_tokens = oauth_token + client_credentials_provider._token_expiry_time = time.time() + 3600 + + request = httpx.Request("GET", "https://api.example.com/data") + mock_response = Mock() + mock_response.status_code = 200 + + auth_flow = client_credentials_provider.async_auth_flow(request) + updated_request = await auth_flow.__anext__() + assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}" + try: + await auth_flow.asend(mock_response) + except StopAsyncIteration: + pass + + +class TestTokenExchangeProvider: + @pytest.mark.anyio + async def test_request_token_success( + self, + token_exchange_provider, + oauth_metadata, + oauth_client_info, + oauth_token, + ): + token_exchange_provider._metadata = oauth_metadata + token_exchange_provider._client_info = oauth_client_info + + token_json = oauth_token.model_dump(by_alias=True, mode="json") + token_json.pop("refresh_token", None) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = token_json + mock_client.post.return_value = mock_response + + await token_exchange_provider.ensure_token() + + mock_client.post.assert_called_once() + assert token_exchange_provider._current_tokens.access_token == oauth_token.access_token diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py index 97edb651e..209bafd99 100644 --- a/tests/issues/test_malformed_input.py +++ b/tests/issues/test_malformed_input.py @@ -74,6 +74,7 @@ async def test_malformed_initialize_request_does_not_crash_server(): method="tools/call", # params=None # Missing required params ) + another_request_message = SessionMessage(message=JSONRPCMessage(another_malformed_request)) await read_send_stream.send(another_request_message) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 5db5d58c2..cd55d3a4c 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -20,6 +20,7 @@ AuthorizationParams, OAuthAuthorizationServerProvider, RefreshToken, + TokenError, construct_redirect_uri, ) from mcp.server.auth.routes import ( @@ -160,6 +161,49 @@ async def exchange_refresh_token( refresh_token=new_refresh_token, ) + async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken: + access_token = f"access_{secrets.token_hex(32)}" + self.tokens[access_token] = AccessToken( + token=access_token, + client_id=client.client_id, + scopes=scopes, + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=access_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scopes), + ) + + async def exchange_token( + self, + client: OAuthClientInformationFull, + subject_token: str, + subject_token_type: str, + actor_token: str | None, + actor_token_type: str | None, + scope: list[str] | None, + audience: str | None, + resource: str | None, + ) -> OAuthToken: + if subject_token == "bad_token": + raise TokenError("invalid_grant", "invalid subject token") + + access_token = f"exchanged_{secrets.token_hex(32)}" + self.tokens[access_token] = AccessToken( + token=access_token, + client_id=client.client_id, + scopes=scope or ["read"], + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=access_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scope or ["read"]), + ) + async def load_access_token(self, token: str) -> AccessToken | None: token_info = self.tokens.get(token) @@ -354,6 +398,8 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): assert metadata["grant_types_supported"] == [ "authorization_code", "refresh_token", + "client_credentials", + "token_exchange", ] assert metadata["service_documentation"] == "https://docs.example.com/" @@ -928,7 +974,28 @@ async def test_client_registration_invalid_grant_type(self, test_client: httpx.A error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token" + assert error_data["error_description"] == ( + "grant_types must be authorization_code and refresh_token " + "or client_credentials or token exchange or " + "client_credentials and token_exchange" + ) + + @pytest.mark.anyio + async def test_client_registration_client_credentials(self, test_client: httpx.AsyncClient): + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "CC Client", + "grant_types": ["client_credentials"], + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + + assert response.status_code == 201, response.content + client_info = response.json() + assert client_info["grant_types"] == ["client_credentials"] class TestAuthorizeEndpointErrors: @@ -1201,3 +1268,102 @@ async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, reg # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["client_credentials"]}], + indirect=True, + ) + async def test_client_credentials_token(self, test_client: httpx.AsyncClient, registered_client): + response = await test_client.post( + "/token", + data={ + "grant_type": "client_credentials", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "scope": "read write", + }, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + + @pytest.mark.anyio + async def test_metadata_includes_token_exchange(self, test_client: httpx.AsyncClient): + response = await test_client.get("/.well-known/oauth-authorization-server") + assert response.status_code == 200 + metadata = response.json() + assert "token_exchange" in metadata["grant_types_supported"] + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["token_exchange"]}], + indirect=True, + ) + async def test_token_exchange_success(self, test_client: httpx.AsyncClient, registered_client): + response = await test_client.post( + "/token", + data={ + "grant_type": "token_exchange", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "subject_token": "good_token", + "subject_token_type": "access_token", + }, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["token_exchange"]}], + indirect=True, + ) + async def test_token_exchange_invalid_subject(self, test_client: httpx.AsyncClient, registered_client): + response = await test_client.post( + "/token", + data={ + "grant_type": "token_exchange", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "subject_token": "bad_token", + "subject_token_type": "access_token", + }, + ) + assert response.status_code == 400 + data = response.json() + assert data["error"] == "invalid_grant" + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["client_credentials", "token_exchange"]}], + indirect=True, + ) + async def test_client_credentials_and_token_exchange(self, test_client: httpx.AsyncClient, registered_client): + cc_response = await test_client.post( + "/token", + data={ + "grant_type": "client_credentials", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "scope": "read write", + }, + ) + assert cc_response.status_code == 200 + + te_response = await test_client.post( + "/token", + data={ + "grant_type": "token_exchange", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "subject_token": "good_token", + "subject_token_type": "access_token", + }, + ) + assert te_response.status_code == 200 diff --git a/tests/server/fastmcp/resources/test_file_resources.py b/tests/server/fastmcp/resources/test_file_resources.py index ec3c85d8d..1ff9a3cb5 100644 --- a/tests/server/fastmcp/resources/test_file_resources.py +++ b/tests/server/fastmcp/resources/test_file_resources.py @@ -100,18 +100,21 @@ async def test_missing_file_error(self, temp_file: Path): with pytest.raises(ValueError, match="Error reading file"): await resource.read() - @pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") - @pytest.mark.anyio - async def test_permission_error(self, temp_file: Path): - """Test reading a file without permissions.""" - temp_file.chmod(0o000) # Remove all permissions - try: - resource = FileResource( - uri=FileUrl(temp_file.as_uri()), - name="test", - path=temp_file, - ) - with pytest.raises(ValueError, match="Error reading file"): - await resource.read() - finally: - temp_file.chmod(0o644) # Restore permissions + +@pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") +@pytest.mark.anyio +async def test_permission_error(temp_file: Path): + """Test reading a file without permissions.""" + if os.geteuid() == 0: + pytest.skip("Permission test not reliable when running as root") + temp_file.chmod(0o000) # Remove all permissions + try: + resource = FileResource( + uri=FileUrl(temp_file.as_uri()), + name="test", + path=temp_file, + ) + with pytest.raises(ValueError, match="Error reading file"): + await resource.read() + finally: + temp_file.chmod(0o644) # Restore permissions