diff --git a/README.md b/README.md index e0a4dfb..f7e515a 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,37 @@ async def get_server_time() -> str: return datetime.now().isoformat() ``` +By default, all API endpoints are turned into an MCP Tool. However, it is possible to explicitly declare what type of MCP object you want to turn an API endpoint into. Current this only works for `mcp_tool`, but will be extended to include `mcp_resource`, `mcp_prompt`, and `mcp_sample` in the future. + +```python +from fastapi import FastAPI +from fastapi_mcp import add_mcp_server + +app = FastAPI() + +@app.get("/items/{item_id}", response_model=Item, tags=["items", "mcp_tool"]) +async def get_item(item_id: int): + """Get an item by ID.""" + return {"item_id": item_id} + +mcp_server = add_mcp_server( + app, # Your FastAPI app + mount_path="/mcp", # Where to mount the MCP server + name="My API MCP", # Name for the MCP server +) +``` + +In some cases you may want to exclude certain endpoints from being turned into MCP objects. This can be done by setting the `exclude_untagged` parameter to `True`: + +```python +mcp_server = add_mcp_server( + app, # Your FastAPI app + mount_path="/mcp", # Where to mount the MCP server + name="My API MCP", # Name for the MCP server + exclude_untagged=True, # Exclude all endpoints that don't have the "mcp_tool" tag +) +``` + ## Examples See the [examples](examples) directory for complete examples. diff --git a/fastapi_mcp/http_tools.py b/fastapi_mcp/http_tools.py index 3102257..6588d30 100644 --- a/fastapi_mcp/http_tools.py +++ b/fastapi_mcp/http_tools.py @@ -6,6 +6,7 @@ import json import logging +from enum import Enum from typing import Any, Dict, List, Optional import httpx @@ -17,6 +18,13 @@ logger = logging.getLogger("fastapi_mcp") +class MCPType(Enum): + TOOL = "mcp_tool" + RESOURCE = "mcp_resource" + SAMPLE = "mcp_sample" + PROMPT = "mcp_prompt" + + def resolve_schema_references(schema: Dict[str, Any], openapi_schema: Dict[str, Any]) -> Dict[str, Any]: """ Resolve schema references in OpenAPI schemas. @@ -106,6 +114,7 @@ def create_mcp_tools_from_openapi( base_url: Optional[str] = None, describe_all_responses: bool = False, describe_full_response_schema: bool = False, + exclude_untagged: bool = False, ) -> None: """ Create MCP tools from a FastAPI app's OpenAPI schema. @@ -116,6 +125,7 @@ def create_mcp_tools_from_openapi( base_url: Base URL for API requests (defaults to http://localhost:$PORT) describe_all_responses: Whether to include all possible response schemas in tool descriptions describe_full_response_schema: Whether to include full response schema in tool descriptions + exclude_untagged: Whether to exclude tools that do not have MCP type tags (mcp_tool, mcp_resource, mcp_sample, mcp_prompt) """ # Get OpenAPI schema from FastAPI app openapi_schema = get_openapi( @@ -144,6 +154,12 @@ def create_mcp_tools_from_openapi( base_url = base_url[:-1] # Process each path in the OpenAPI schema + mcp_types = { + MCPType.TOOL.value, + MCPType.RESOURCE.value, + MCPType.SAMPLE.value, + MCPType.PROMPT.value, + } for path, path_item in openapi_schema.get("paths", {}).items(): for method, operation in path_item.items(): # Skip non-HTTP methods @@ -155,8 +171,20 @@ def create_mcp_tools_from_openapi( if not operation_id: continue + # If we do not create tools by default, + # Skip registering tools unless they are explicitly allowed + tags = mcp_types.intersection(set(operation.get("tags", []))) + tag = MCPType(tags.pop()) if len(tags) >= 1 else None + if len(tags) > 1: + logger.warning(f"Operation {operation_id} has multiple MCP types. Using {tag}, but found {tags}") + if tag is None: + if exclude_untagged: + continue + else: + tag = MCPType.TOOL + # Create MCP tool for this operation - create_http_tool( + create_http_mcp_call( mcp_server=mcp_server, base_url=base_url, path=path, @@ -170,10 +198,11 @@ def create_mcp_tools_from_openapi( openapi_schema=openapi_schema, describe_all_responses=describe_all_responses, describe_full_response_schema=describe_full_response_schema, + mcp_type=tag, ) -def create_http_tool( +def create_http_mcp_call( mcp_server: FastMCP, base_url: str, path: str, @@ -187,12 +216,13 @@ def create_http_tool( openapi_schema: Dict[str, Any], describe_all_responses: bool, describe_full_response_schema: bool, + mcp_type: MCPType, ) -> None: """ - Create an MCP tool that makes an HTTP request to a FastAPI endpoint. + Create an MCP resource, tool, sample, or prompt that makes an HTTP request to a FastAPI endpoint. Args: - mcp_server: The MCP server to add the tool to + mcp_server: The MCP server to add calls to base_url: Base URL for API requests path: API endpoint path method: HTTP method @@ -203,13 +233,14 @@ def create_http_tool( request_body: OpenAPI request body responses: OpenAPI responses openapi_schema: The full OpenAPI schema - describe_all_responses: Whether to include all possible response schemas in tool descriptions - describe_full_response_schema: Whether to include full response schema in tool descriptions + describe_all_responses: Whether to include all possible response schemas in descriptions + describe_full_response_schema: Whether to include full response schema in descriptions + mcp_type: MCP type. """ - # Build tool description - tool_description = f"{summary}" if summary else f"{method.upper()} {path}" + # Build call description + call_description = f"{summary}" if summary else f"{method.upper()} {path}" if description: - tool_description += f"\n\n{description}" + call_description += f"\n\n{description}" # Add response schema information to description if responses: @@ -300,6 +331,7 @@ def create_http_tool( not example_response and display_schema.get("type") == "array" and items_model_name == "Item" + and mcp_type == MCPType.TOOL ): example_response = [ { @@ -351,7 +383,7 @@ def create_http_tool( response_info += json.dumps(display_schema, indent=2) response_info += "\n```" - tool_description += response_info + call_description += response_info # Organize parameters by type path_params = [] @@ -436,7 +468,7 @@ def create_http_tool( required_props.append(param_name) # Function to dynamically call the API endpoint - async def http_tool_function(kwargs: Dict[str, Any] = Field(default_factory=dict)): + async def http_function(kwargs: Dict[str, Any] = Field(default_factory=dict)): # Prepare URL with path parameters url = f"{base_url}{path}" for param_name, _ in path_params: @@ -480,25 +512,32 @@ async def http_tool_function(kwargs: Dict[str, Any] = Field(default_factory=dict except ValueError: return response.text - # Create a proper input schema for the tool - input_schema = {"type": "object", "properties": properties, "title": f"{operation_id}Arguments"} - - if required_props: - input_schema["required"] = required_props - # Set the function name and docstring - http_tool_function.__name__ = operation_id - http_tool_function.__doc__ = tool_description + http_function.__name__ = operation_id + http_function.__doc__ = call_description + + if mcp_type == MCPType.TOOL: + # Create a proper input schema for the tool + input_schema = { + "type": "object", + "properties": properties, + "title": f"{operation_id}Arguments", + } + + if required_props: + input_schema["required"] = required_props - # Monkey patch the function's schema for MCP tool creation - # TODO: Maybe revise this hacky approach - http_tool_function._input_schema = input_schema # type: ignore + # Monkey patch the function's schema for MCP tool creation + # TODO: Maybe revise this hacky approach + http_function._input_schema = input_schema # type: ignore - # Add tool to the MCP server with the enhanced schema - tool = mcp_server._tool_manager.add_tool(http_tool_function, name=operation_id, description=tool_description) + # Add tool to the MCP server with the enhanced schema + tool = mcp_server._tool_manager.add_tool(http_function, name=operation_id, description=call_description) - # Update the tool's parameters to use our custom schema instead of the auto-generated one - tool.parameters = input_schema + # Update the tool's parameters to use our custom schema instead of the auto-generated one + tool.parameters = input_schema + else: + raise NotImplementedError(f"MCP type {mcp_type} not implemented") def extract_model_examples_from_components( @@ -561,7 +600,15 @@ def generate_example_from_schema(schema: Dict[str, Any], model_name: Optional[st } elif model_name == "HTTPValidationError": # Create a realistic validation error example - return {"detail": [{"loc": ["body", "name"], "msg": "field required", "type": "value_error.missing"}]} + return { + "detail": [ + { + "loc": ["body", "name"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } # Handle different types schema_type = schema.get("type") diff --git a/fastapi_mcp/server.py b/fastapi_mcp/server.py index db65adb..8a3a0eb 100644 --- a/fastapi_mcp/server.py +++ b/fastapi_mcp/server.py @@ -56,6 +56,7 @@ def mount_mcp_server( base_url: Optional[str] = None, describe_all_responses: bool = False, describe_full_response_schema: bool = False, + exclude_untagged: bool = False, ) -> None: """ Mount an MCP server to a FastAPI app. @@ -68,6 +69,7 @@ def mount_mcp_server( base_url: Base URL for API requests describe_all_responses: Whether to include all possible response schemas in tool descriptions. Recommended to keep False, as the LLM will probably derive if there is an error. describe_full_response_schema: Whether to include full json schema for responses in tool descriptions. Recommended to keep False, as examples are more LLM friendly, and save tokens. + exclude_untagged: Whether to exclude untagged endpoints from being served as tools. """ # Normalize mount path if not mount_path.startswith("/"): @@ -99,6 +101,7 @@ async def handle_mcp_connection(request: Request): base_url, describe_all_responses=describe_all_responses, describe_full_response_schema=describe_full_response_schema, + exclude_untagged=exclude_untagged, ) @@ -112,6 +115,7 @@ def add_mcp_server( base_url: Optional[str] = None, describe_all_responses: bool = False, describe_full_response_schema: bool = False, + exclude_untagged: bool = False, ) -> FastMCP: """ Add an MCP server to a FastAPI app. @@ -142,6 +146,7 @@ def add_mcp_server( base_url, describe_all_responses=describe_all_responses, describe_full_response_schema=describe_full_response_schema, + exclude_untagged=exclude_untagged, ) return mcp_server diff --git a/tests/test_tool_generation.py b/tests/test_tool_generation.py index 25102d1..d351fe5 100644 --- a/tests/test_tool_generation.py +++ b/tests/test_tool_generation.py @@ -50,13 +50,24 @@ async def create_item(item: Item): """ return item + @app.post("/limited/", response_model=Item, tags=["items", "mcp_tool"]) + async def create_limited_item(item: Item): + """ + Create a new item. + + Returns the created item. + """ + return item + return app def test_tool_generation_basic(sample_app): """Test that MCP tools are properly generated with default settings.""" # Create MCP server and register tools - mcp_server = add_mcp_server(sample_app, serve_tools=True, base_url="http://localhost:8000") + mcp_server = add_mcp_server( + sample_app, serve_tools=True, base_url="http://localhost:8000" + ) # Extract tools for inspection tools = mcp_server._tool_manager.list_tools() @@ -76,27 +87,42 @@ def test_tool_generation_basic(sample_app): continue # With describe_all_responses=False by default, description should only include success response code - assert "200" in tool.description, f"Expected success response code in description for {tool.name}" - assert "422" not in tool.description, f"Expected not to see 422 response in tool description for {tool.name}" + assert ( + "200" in tool.description + ), f"Expected success response code in description for {tool.name}" + assert ( + "422" not in tool.description + ), f"Expected not to see 422 response in tool description for {tool.name}" # With describe_full_response_schema=False by default, description should not include the full output schema, only an example - assert "Example Response" in tool.description, f"Expected example response in description for {tool.name}" - assert "Output Schema" not in tool.description, ( - f"Expected not to see output schema in description for {tool.name}" - ) + assert ( + "Example Response" in tool.description + ), f"Expected example response in description for {tool.name}" + assert ( + "Output Schema" not in tool.description + ), f"Expected not to see output schema in description for {tool.name}" # Verify specific parameters are present in the appropriate tools - list_items_tool = next((t for t in tools if t.name == "list_items_items__get"), None) + list_items_tool = next( + (t for t in tools if t.name == "list_items_items__get"), None + ) assert list_items_tool is not None, "list_items tool not found" - assert "skip" in list_items_tool.parameters["properties"], "Expected 'skip' parameter" - assert "limit" in list_items_tool.parameters["properties"], "Expected 'limit' parameter" + assert ( + "skip" in list_items_tool.parameters["properties"] + ), "Expected 'skip' parameter" + assert ( + "limit" in list_items_tool.parameters["properties"] + ), "Expected 'limit' parameter" def test_tool_generation_with_full_schema(sample_app): """Test that MCP tools include full response schema when requested.""" # Create MCP server with full schema for all operations mcp_server = add_mcp_server( - sample_app, serve_tools=True, base_url="http://localhost:8000", describe_full_response_schema=True + sample_app, + serve_tools=True, + base_url="http://localhost:8000", + describe_full_response_schema=True, ) # Extract tools for inspection @@ -110,15 +136,22 @@ def test_tool_generation_with_full_schema(sample_app): description = tool.description # Check that the tool includes information about the Item schema - assert "Item" in description, f"Item schema should be included in the description for {tool.name}" - assert "price" in description, f"Item properties should be included in the description for {tool.name}" + assert ( + "Item" in description + ), f"Item schema should be included in the description for {tool.name}" + assert ( + "price" in description + ), f"Item properties should be included in the description for {tool.name}" def test_tool_generation_with_all_responses(sample_app): """Test that MCP tools include all possible responses when requested.""" # Create MCP server with all response status codes mcp_server = add_mcp_server( - sample_app, serve_tools=True, base_url="http://localhost:8000", describe_all_responses=True + sample_app, + serve_tools=True, + base_url="http://localhost:8000", + describe_all_responses=True, ) # Extract tools for inspection @@ -130,8 +163,12 @@ def test_tool_generation_with_all_responses(sample_app): if tool.name == "handle_mcp_connection_mcp_get": continue - assert "200" in tool.description, f"Expected success response code in description for {tool.name}" - assert "422" in tool.description, f"Expected 422 response code in description for {tool.name}" + assert ( + "200" in tool.description + ), f"Expected success response code in description for {tool.name}" + assert ( + "422" in tool.description + ), f"Expected 422 response code in description for {tool.name}" def test_tool_generation_with_all_responses_and_full_schema(sample_app): @@ -154,15 +191,23 @@ def test_tool_generation_with_all_responses_and_full_schema(sample_app): if tool.name == "handle_mcp_connection_mcp_get": continue - assert "200" in tool.description, f"Expected success response code in description for {tool.name}" - assert "422" in tool.description, f"Expected 422 response code in description for {tool.name}" - assert "Output Schema" in tool.description, f"Expected output schema in description for {tool.name}" + assert ( + "200" in tool.description + ), f"Expected success response code in description for {tool.name}" + assert ( + "422" in tool.description + ), f"Expected 422 response code in description for {tool.name}" + assert ( + "Output Schema" in tool.description + ), f"Expected output schema in description for {tool.name}" def test_custom_tool_addition(sample_app): """Test that custom tools can be added alongside API tools.""" # Create MCP server with API tools - mcp_server = add_mcp_server(sample_app, serve_tools=True, base_url="http://localhost:8000") + mcp_server = add_mcp_server( + sample_app, serve_tools=True, base_url="http://localhost:8000" + ) # Get initial tool count initial_tool_count = len(mcp_server._tool_manager.list_tools()) @@ -177,12 +222,38 @@ async def custom_tool() -> str: tools = mcp_server._tool_manager.list_tools() # Verify we have one more tool than before - assert len(tools) == initial_tool_count + 1, f"Expected {initial_tool_count + 1} tools, got {len(tools)}" + assert ( + len(tools) == initial_tool_count + 1 + ), f"Expected {initial_tool_count + 1} tools, got {len(tools)}" # Find both API tools and custom tools - list_items_tool = next((t for t in tools if t.name == "list_items_items__get"), None) + list_items_tool = next( + (t for t in tools if t.name == "list_items_items__get"), None + ) assert list_items_tool is not None, "API tool (list_items) not found" custom_tool_def = next((t for t in tools if t.name == "custom_tool"), None) assert custom_tool_def is not None, "Custom tool not found" - assert custom_tool_def.description == "A custom tool for testing.", "Custom tool description not preserved" + assert ( + custom_tool_def.description == "A custom tool for testing." + ), "Custom tool description not preserved" + + +def test_tool_addition_with_tags(sample_app): + """Test that custom tools can be added alongside API tools.""" + # Create MCP server with API tools + mcp_server = add_mcp_server( + sample_app, + serve_tools=True, + base_url="http://localhost:8000", + exclude_untagged=True, + ) + + # Get initial tool count + tool_count = len(mcp_server._tool_manager.list_tools()) + + # Make sure there is exactly 1 tool + assert tool_count == 1, f"Expected exactly 1 tool, got {tool_count}" + + # Extract tools for inspection + tools = mcp_server._tool_manager.list_tools()