From d6c3da93d8d907c0c24de68c1d17b07999146031 Mon Sep 17 00:00:00 2001 From: vbarda Date: Mon, 12 May 2025 16:45:35 -0400 Subject: [PATCH 1/3] fastmcp: allow passing Tool directly to .add_tool --- src/mcp/server/fastmcp/server.py | 27 +++++++++++++++----- src/mcp/server/fastmcp/tools/tool_manager.py | 26 ++++++++++++++++--- tests/server/fastmcp/test_tool_manager.py | 27 +++++++++++++++++++- 3 files changed, 69 insertions(+), 11 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c31f29d4c..2b891f2cd 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -10,7 +10,7 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Generic, Literal +from typing import Any, Generic, Literal, overload import anyio import pydantic_core @@ -37,7 +37,7 @@ from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager -from mcp.server.fastmcp.tools import ToolManager +from mcp.server.fastmcp.tools import Tool, ToolManager from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger from mcp.server.fastmcp.utilities.types import Image from mcp.server.lowlevel.helper_types import ReadResourceContents @@ -315,12 +315,24 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent logger.error(f"Error reading resource {uri}: {e}") raise ResourceError(str(e)) + @overload + def add_tool(self, fn: Tool) -> None: ... + + @overload def add_tool( self, fn: AnyFunction, name: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, + ) -> None: ... + + def add_tool( + self, + fn: AnyFunction | Tool, + name: str | None = None, + description: str | None = None, + annotations: ToolAnnotations | None = None, ) -> None: """Add a tool to the server. @@ -328,14 +340,17 @@ def add_tool( with the Context type annotation. See the @tool decorator for examples. Args: - fn: The function to register as a tool + fn: The function to register as a tool or a Tool instance name: Optional name for the tool (defaults to function name) description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information """ - self._tool_manager.add_tool( - fn, name=name, description=description, annotations=annotations - ) + if isinstance(fn, Tool): + self._tool_manager.add_tool(fn) + else: + self._tool_manager.add_tool( + fn, name=name, description=description, annotations=annotations + ) def tool( self, diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index cfdaeb350..d77a17b36 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool @@ -31,17 +31,35 @@ def list_tools(self) -> list[Tool]: """List all registered tools.""" return list(self._tools.values()) + @overload + def add_tool( + self, + fn: Tool, + ) -> Tool: ... + + @overload def add_tool( self, fn: Callable[..., Any], name: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, + ) -> Tool: ... + + def add_tool( + self, + fn: Callable[..., Any] | Tool, + name: str | None = None, + description: str | None = None, + annotations: ToolAnnotations | None = None, ) -> Tool: """Add a tool to the server.""" - tool = Tool.from_function( - fn, name=name, description=description, annotations=annotations - ) + if isinstance(fn, Tool): + tool = fn + else: + tool = Tool.from_function( + fn, name=name, description=description, annotations=annotations + ) existing = self._tools.get(tool.name) if existing: if self.warn_on_duplicate_tools: diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index e36a09d54..7dfdc01c9 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -6,7 +6,8 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.exceptions import ToolError -from mcp.server.fastmcp.tools import ToolManager +from mcp.server.fastmcp.tools import Tool, ToolManager +from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata from mcp.server.session import ServerSessionT from mcp.shared.context import LifespanContextT from mcp.types import ToolAnnotations @@ -31,6 +32,30 @@ def add(a: int, b: int) -> int: assert tool.parameters["properties"]["a"]["type"] == "integer" assert tool.parameters["properties"]["b"]["type"] == "integer" + def test_add_tool_directly(self): + manager = ToolManager() + + def add(a: int, b: int) -> int: + return a + b + + class AddArguments(ArgModelBase): + a: int + b: int + + fn_metadata = FuncMetadata(arg_model=AddArguments) + + original_tool = Tool( + name="add", + description="Add two numbers.", + fn=add, + fn_metadata=fn_metadata, + is_async=False, + parameters=AddArguments.model_json_schema(), + ) + manager.add_tool(original_tool) + saved_tool = manager.get_tool("add") + assert saved_tool == original_tool + @pytest.mark.anyio async def test_async_function(self): """Test registering and running an async function.""" From 09934af7d7131f0bc925083d6119140536d369d8 Mon Sep 17 00:00:00 2001 From: vbarda Date: Mon, 12 May 2025 20:46:26 -0400 Subject: [PATCH 2/3] lint --- tests/server/fastmcp/test_tool_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 7dfdc01c9..b113b9bdb 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -51,6 +51,8 @@ class AddArguments(ArgModelBase): fn_metadata=fn_metadata, is_async=False, parameters=AddArguments.model_json_schema(), + context_kwarg=None, + annotations=None, ) manager.add_tool(original_tool) saved_tool = manager.get_tool("add") From 140e8a5b42d27e1897dcc1c4e98a12a9cf8f8515 Mon Sep 17 00:00:00 2001 From: vbarda Date: Tue, 13 May 2025 11:57:29 -0400 Subject: [PATCH 3/3] code review --- src/mcp/server/fastmcp/server.py | 27 ++++--------- src/mcp/server/fastmcp/tools/tool_manager.py | 42 +++++++------------- tests/server/fastmcp/test_tool_manager.py | 4 +- 3 files changed, 24 insertions(+), 49 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 2b891f2cd..c1ea5ca33 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -10,7 +10,7 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Generic, Literal, overload +from typing import Any, Generic, Literal import anyio import pydantic_core @@ -315,24 +315,16 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent logger.error(f"Error reading resource {uri}: {e}") raise ResourceError(str(e)) - @overload - def add_tool(self, fn: Tool) -> None: ... + def add_tool_instance(self, tool: Tool) -> None: + """Add a Tool instance to the server.""" + self._tool_manager.add_tool_instance(tool) - @overload def add_tool( self, fn: AnyFunction, name: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, - ) -> None: ... - - def add_tool( - self, - fn: AnyFunction | Tool, - name: str | None = None, - description: str | None = None, - annotations: ToolAnnotations | None = None, ) -> None: """Add a tool to the server. @@ -340,17 +332,14 @@ def add_tool( with the Context type annotation. See the @tool decorator for examples. Args: - fn: The function to register as a tool or a Tool instance + fn: The function to register as a tool name: Optional name for the tool (defaults to function name) description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information """ - if isinstance(fn, Tool): - self._tool_manager.add_tool(fn) - else: - self._tool_manager.add_tool( - fn, name=name, description=description, annotations=annotations - ) + self._tool_manager.add_tool( + fn, name=name, description=description, annotations=annotations + ) def tool( self, diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index d77a17b36..6f7d0d9f2 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from collections.abc import Callable -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool @@ -31,42 +31,28 @@ def list_tools(self) -> list[Tool]: """List all registered tools.""" return list(self._tools.values()) - @overload - def add_tool( - self, - fn: Tool, - ) -> Tool: ... + def add_tool_instance(self, tool: Tool) -> Tool: + """Add a Tool instance to the server.""" + existing = self._tools.get(tool.name) + if existing: + if self.warn_on_duplicate_tools: + logger.warning(f"Tool already exists: {tool.name}") + return existing + self._tools[tool.name] = tool + return tool - @overload def add_tool( self, fn: Callable[..., Any], name: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, - ) -> Tool: ... - - def add_tool( - self, - fn: Callable[..., Any] | Tool, - name: str | None = None, - description: str | None = None, - annotations: ToolAnnotations | None = None, ) -> Tool: """Add a tool to the server.""" - if isinstance(fn, Tool): - tool = fn - else: - tool = Tool.from_function( - fn, name=name, description=description, annotations=annotations - ) - existing = self._tools.get(tool.name) - if existing: - if self.warn_on_duplicate_tools: - logger.warning(f"Tool already exists: {tool.name}") - return existing - self._tools[tool.name] = tool - return tool + tool = Tool.from_function( + fn, name=name, description=description, annotations=annotations + ) + return self.add_tool_instance(tool) async def call_tool( self, diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index b113b9bdb..9ae73da05 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -32,7 +32,7 @@ def add(a: int, b: int) -> int: assert tool.parameters["properties"]["a"]["type"] == "integer" assert tool.parameters["properties"]["b"]["type"] == "integer" - def test_add_tool_directly(self): + def test_add_tool_instance(self): manager = ToolManager() def add(a: int, b: int) -> int: @@ -54,7 +54,7 @@ class AddArguments(ArgModelBase): context_kwarg=None, annotations=None, ) - manager.add_tool(original_tool) + manager.add_tool_instance(original_tool) saved_tool = manager.get_tool("add") assert saved_tool == original_tool