diff --git a/fastapi_mcp/http_tools.py b/fastapi_mcp/http_tools.py index 3102257..aecfbfa 100644 --- a/fastapi_mcp/http_tools.py +++ b/fastapi_mcp/http_tools.py @@ -17,7 +17,9 @@ logger = logging.getLogger("fastapi_mcp") -def resolve_schema_references(schema: Dict[str, Any], openapi_schema: Dict[str, Any]) -> Dict[str, Any]: +def resolve_schema_references( + schema: Dict[str, Any], openapi_schema: Dict[str, Any], top_schema=None +) -> Dict[str, Any]: """ Resolve schema references in OpenAPI schemas. @@ -31,6 +33,9 @@ def resolve_schema_references(schema: Dict[str, Any], openapi_schema: Dict[str, # Make a copy to avoid modifying the input schema schema = schema.copy() + # Create a a definnition prefix for the schema + def_prefix = "#/$defs/" + # Handle $ref directly in the schema if "$ref" in schema: ref_path = schema["$ref"] @@ -41,18 +46,42 @@ def resolve_schema_references(schema: Dict[str, Any], openapi_schema: Dict[str, if model_name in openapi_schema["components"]["schemas"]: # Replace with the resolved schema ref_schema = openapi_schema["components"]["schemas"][model_name].copy() - # Remove the $ref key and merge with the original schema - schema.pop("$ref") - schema.update(ref_schema) + + if top_schema is not None: + # Create the $defs key if it doesn't exist + if "$defs" not in top_schema: + top_schema["$defs"] = {} + + ref_schema = resolve_schema_references(ref_schema, openapi_schema, top_schema=top_schema) + + # Create the definition reference + top_schema["$defs"][model_name] = ref_schema + + # Update the schema with the definition reference + schema["$ref"] = def_prefix + model_name + else: + # Update the schema with the definition reference + schema.pop("$ref") + schema.update(ref_schema) + top_schema = schema + + # Handle anyOf, oneOf, allOf + for key in ["anyOf", "oneOf", "allOf"]: + if key in schema: + for index, item in enumerate(schema[key]): + item = resolve_schema_references(item, openapi_schema, top_schema=top_schema) + schema[key][index] = item # Handle array items if "type" in schema and schema["type"] == "array" and "items" in schema: - schema["items"] = resolve_schema_references(schema["items"], openapi_schema) + schema["items"] = resolve_schema_references(schema["items"], openapi_schema, top_schema=top_schema) # Handle object properties if "properties" in schema: for prop_name, prop_schema in schema["properties"].items(): - schema["properties"][prop_name] = resolve_schema_references(prop_schema, openapi_schema) + schema["properties"][prop_name] = resolve_schema_references( + prop_schema, openapi_schema, top_schema=top_schema + ) return schema @@ -72,9 +101,6 @@ def clean_schema_for_display(schema: Dict[str, Any]) -> Dict[str, Any]: # Remove common internal fields that are not helpful for LLMs fields_to_remove = [ - "allOf", - "anyOf", - "oneOf", "nullable", "discriminator", "readOnly", @@ -481,7 +507,11 @@ async def http_tool_function(kwargs: Dict[str, Any] = Field(default_factory=dict return response.text # Create a proper input schema for the tool - input_schema = {"type": "object", "properties": properties, "title": f"{operation_id}Arguments"} + input_schema = { + "type": "object", + "properties": properties, + "title": f"{operation_id}Arguments", + } if required_props: input_schema["required"] = required_props @@ -561,7 +591,15 @@ def generate_example_from_schema(schema: Dict[str, Any], model_name: Optional[st } elif model_name == "HTTPValidationError": # Create a realistic validation error example - return {"detail": [{"loc": ["body", "name"], "msg": "field required", "type": "value_error.missing"}]} + return { + "detail": [ + { + "loc": ["body", "name"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } # Handle different types schema_type = schema.get("type") diff --git a/tests/test_http_tools.py b/tests/test_http_tools.py index f8193ae..302d1b9 100644 --- a/tests/test_http_tools.py +++ b/tests/test_http_tools.py @@ -86,7 +86,13 @@ def test_resolve_schema_references(): openapi_schema = { "components": { "schemas": { - "Item": {"type": "object", "properties": {"id": {"type": "integer"}, "name": {"type": "string"}}} + "Item": { + "type": "object", + "properties": { + "id": {"type": "integer"}, + "name": {"type": "string"}, + }, + } } } } @@ -141,7 +147,13 @@ def test_create_mcp_tools_from_complex_app(complex_app): assert len(api_tools) == 5, f"Expected 5 API tools, got {len(api_tools)}" # Check for all expected tools with the correct name pattern - tool_operations = ["list_items", "read_item", "create_item", "update_item", "delete_item"] + tool_operations = [ + "list_items", + "read_item", + "create_item", + "update_item", + "delete_item", + ] for operation in tool_operations: matching_tools = [t for t in tools if operation in t.name] assert len(matching_tools) > 0, f"No tool found for operation '{operation}'" diff --git a/tests/test_tool_generation.py b/tests/test_tool_generation.py index 25102d1..e7bcc2d 100644 --- a/tests/test_tool_generation.py +++ b/tests/test_tool_generation.py @@ -1,3 +1,5 @@ +import json + import pytest from fastapi import FastAPI from pydantic import BaseModel @@ -14,6 +16,29 @@ class Item(BaseModel): tags: List[str] = [] +class Task(BaseModel): + id: int + title: str + description: Optional[str] = None + completed: bool = False + required_resources: List[Item] = [] + + +def remove_default_values(schema: dict) -> dict: + if "default" in schema: + schema.pop("default") + + for value in schema.values(): + if isinstance(value, dict): + remove_default_values(value) + + return schema + + +def normalize_json_schema(schema: dict) -> str: + return json.dumps(remove_default_values(schema), sort_keys=True) + + @pytest.fixture def sample_app(): """Create a sample FastAPI app for testing.""" @@ -50,6 +75,14 @@ async def create_item(item: Item): """ return item + @app.get("/tasks/", response_model=List[Task], tags=["tasks"]) + async def list_tasks( + skip: int = 0, + limit: int = 10, + ): + """List all tasks with pagination options.""" + return [] + return app @@ -96,7 +129,10 @@ def test_tool_generation_with_full_schema(sample_app): """Test that MCP tools include full response schema when requested.""" # Create MCP server with full schema for all operations mcp_server = add_mcp_server( - sample_app, serve_tools=True, base_url="http://localhost:8000", describe_full_response_schema=True + sample_app, + serve_tools=True, + base_url="http://localhost:8000", + describe_full_response_schema=True, ) # Extract tools for inspection @@ -109,16 +145,44 @@ def test_tool_generation_with_full_schema(sample_app): continue description = tool.description - # Check that the tool includes information about the Item schema - assert "Item" in description, f"Item schema should be included in the description for {tool.name}" - assert "price" in description, f"Item properties should be included in the description for {tool.name}" + + # Check that the tool includes information about the Item or Task schema + if tool.name == "list_tasks_tasks__get": + model = Task + elif "Item" in description: + model = Item + elif "Task" not in description: + raise ValueError(f"Item or Task schema should be included in the description for {tool.name}") + + assert "price" in description or "required_resources" in description, ( + f"Item or Task properties should be included in the description for {tool.name}" + ) + + # Get the output schema from the description + lines = description.split("\n") + for index, line in enumerate(lines): + if "Output Schema" in line: + index += 2 + break + + # Normalize the output schema + output_schema_str = normalize_json_schema(json.loads("\n".join(lines[index:-1]))) + + # Generate and normalize the model schema + model_schema_str = normalize_json_schema(model.model_json_schema()) + + # Check that the output schema matches the model schema + assert output_schema_str == model_schema_str, f"Output schema does not match model schema for {tool.name}" def test_tool_generation_with_all_responses(sample_app): """Test that MCP tools include all possible responses when requested.""" # Create MCP server with all response status codes mcp_server = add_mcp_server( - sample_app, serve_tools=True, base_url="http://localhost:8000", describe_all_responses=True + sample_app, + serve_tools=True, + base_url="http://localhost:8000", + describe_all_responses=True, ) # Extract tools for inspection