Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ async def call_tool(
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
cancellation_token.link_future(result_future)
actual_tool_output = await result_future
if isinstance(actual_tool_output, ToolResult):
return actual_tool_output
is_error = False
result_str = tool.return_value_as_string(actual_tool_output)
except Exception as e:
Expand Down Expand Up @@ -217,6 +219,9 @@ async def call_tool_stream(
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id))
cancellation_token.link_future(result_future)
actual_tool_output = await result_future
if isinstance(actual_tool_output, ToolResult):
yield actual_tool_output
return
is_error = False
result_str = tool.return_value_as_string(actual_tool_output)
except Exception as e:
Expand Down
98 changes: 75 additions & 23 deletions python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from abc import ABC
from typing import Any, Dict, Generic, Sequence, Type, TypeVar

from autogen_core import CancellationToken
from autogen_core.tools import BaseTool
from autogen_core import CancellationToken, Image
from autogen_core.tools import BaseTool, ImageResultContent, TextResultContent, ToolResult
from autogen_core.utils import schema_to_pydantic_model
from pydantic import BaseModel
from pydantic.networks import AnyUrl
Expand Down Expand Up @@ -37,10 +37,21 @@ class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]):

component_type = "tool"

def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSession | None = None) -> None:
def __init__(
self,
server_params: TServerParams,
tool: Tool,
session: ClientSession | None = None,
max_retries: int = 0,
retry_delay: float = 1.0,
raise_on_error: bool = False,
) -> None:
self._tool = tool
self._server_params = server_params
self._session = session
self._max_retries = max_retries
self._retry_delay = retry_delay
self._raise_on_error = raise_on_error

# Extract name and description
name = tool.name
Expand All @@ -49,12 +60,12 @@ def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSess
# Create the input model from the tool's schema
input_model = schema_to_pydantic_model(tool.inputSchema)

# Use Any as return type since MCP tool returns can vary
return_type: Type[Any] = object
# Use ToolResult as return type
return_type: Type[ToolResult] = ToolResult

super().__init__(input_model, return_type, name, description)

async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> ToolResult:
"""
Run the MCP tool with the provided arguments.

Expand All @@ -63,24 +74,50 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> A
cancellation_token (CancellationToken): Token to signal cancellation.

Returns:
Any: The result of the tool execution.

Raises:
Exception: If the operation is cancelled or the tool execution fails.
ToolResult: The result of the tool execution.
"""
# Convert the input model to a dictionary
# Exclude unset values to avoid sending them to the MCP servers which may cause errors
# for many servers.
kwargs = args.model_dump(exclude_unset=True)

if self._session is not None:
# If a session is provided, use it directly.
session = self._session
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)
exceptions_to_catch: tuple[Type[BaseException], ...]
if hasattr(builtins, "ExceptionGroup"):
exceptions_to_catch = (asyncio.CancelledError, builtins.ExceptionGroup)
else:
exceptions_to_catch = (asyncio.CancelledError,)

async with create_mcp_server_session(self._server_params) as session:
await session.initialize()
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)
attempts = 1 + self._max_retries
for attempt in range(attempts):
try:
if cancellation_token.is_cancelled():
raise asyncio.CancelledError("Operation cancelled")

if self._session is not None:
# If a session is provided, use it directly.
session = self._session
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)

async with create_mcp_server_session(self._server_params) as session:
await session.initialize()
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)

except exceptions_to_catch:
# Re-raise these specific exception types directly.
raise
except Exception as e:
if attempt < attempts - 1:
await asyncio.sleep(self._retry_delay)
continue
if self._raise_on_error:
raise
return ToolResult(
name=self.name,
result=[TextResultContent(content=f"Tool {self.name} failed: {str(e)}")],
is_error=True,
)

raise RuntimeError("Unreachable")

def _normalize_payload_to_content_list(self, payload: Sequence[ContentBlock]) -> list[ContentBlock]:
"""
Expand All @@ -102,7 +139,9 @@ def _normalize_payload_to_content_list(self, payload: Sequence[ContentBlock]) ->
else:
return [TextContent(text=str(payload), type="text")]

async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any:
async def _run(
self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession
) -> ToolResult:
exceptions_to_catch: tuple[Type[BaseException], ...]
if hasattr(builtins, "ExceptionGroup"):
exceptions_to_catch = (asyncio.CancelledError, builtins.ExceptionGroup)
Expand All @@ -117,12 +156,23 @@ async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken
cancellation_token.link_future(result_future)
result = await result_future

normalized_content_list = self._normalize_payload_to_content_list(result.content)

if result.isError:
if result.isError and self._raise_on_error:
normalized_content_list = self._normalize_payload_to_content_list(result.content)
serialized_error_message = self.return_value_as_string(normalized_content_list)
raise Exception(serialized_error_message)
return normalized_content_list

result_parts: list[TextResultContent | ImageResultContent] = []
for content in result.content:
if isinstance(content, TextContent):
result_parts.append(TextResultContent(content=content.text))
elif isinstance(content, ImageContent):
result_parts.append(ImageResultContent(content=Image.from_base64(content.data)))
elif isinstance(content, EmbeddedResource):
result_parts.append(TextResultContent(content=content.model_dump_json()))
else:
result_parts.append(TextResultContent(content=str(content)))

return ToolResult(name=self.name, result=result_parts, is_error=result.isError)

except exceptions_to_catch:
# Re-raise these specific exception types directly.
Expand Down Expand Up @@ -156,8 +206,10 @@ async def from_server_params(cls, server_params: TServerParams, tool_name: str)

return cls(server_params=server_params, tool=matching_tool)

def return_value_as_string(self, value: list[Any]) -> str:
def return_value_as_string(self, value: Any) -> str:
"""Return a string representation of the result."""
if isinstance(value, ToolResult):
return value.to_text()

def serialize_item(item: Any) -> dict[str, Any]:
if isinstance(item, (TextContent, ImageContent, AudioContent)):
Expand Down
39 changes: 36 additions & 3 deletions python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
async def mcp_server_tools(
server_params: McpServerParams,
session: ClientSession | None = None,
max_retries: int = 0,
retry_delay: float = 1.0,
raise_on_error: bool = False,
) -> list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]:
"""Creates a list of MCP tool adapters that can be used with AutoGen agents.

Expand All @@ -36,6 +39,9 @@ async def mcp_server_tools(
session (ClientSession | None): Optional existing session to use. This is used
when you want to reuse an existing connection to the MCP server. The session
will be reused when creating the MCP tool adapters.
max_retries (int, optional): The maximum number of retries for tool execution. Defaults to 0.
retry_delay (float, optional): The delay in seconds between retries. Defaults to 1.0.
raise_on_error (bool, optional): Whether to raise an exception on tool error. Defaults to False.

Returns:
list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]:
Expand Down Expand Up @@ -203,12 +209,39 @@ async def main() -> None:
tools = await session.list_tools()

if isinstance(server_params, StdioServerParams):
return [StdioMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools]
return [
StdioMcpToolAdapter(
server_params=server_params,
tool=tool,
session=session,
max_retries=max_retries,
retry_delay=retry_delay,
raise_on_error=raise_on_error,
)
for tool in tools.tools
]
elif isinstance(server_params, SseServerParams):
return [SseMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools]
return [
SseMcpToolAdapter(
server_params=server_params,
tool=tool,
session=session,
max_retries=max_retries,
retry_delay=retry_delay,
raise_on_error=raise_on_error,
)
for tool in tools.tools
]
elif isinstance(server_params, StreamableHttpServerParams):
return [
StreamableHttpMcpToolAdapter(server_params=server_params, tool=tool, session=session)
StreamableHttpMcpToolAdapter(
server_params=server_params,
tool=tool,
session=session,
max_retries=max_retries,
retry_delay=retry_delay,
raise_on_error=raise_on_error,
)
for tool in tools.tools
]
raise ValueError(f"Unsupported server params type: {type(server_params)}")
41 changes: 37 additions & 4 deletions python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class SseMcpToolAdapterConfig(BaseModel):

server_params: SseServerParams
tool: Tool
max_retries: int = 0
retry_delay: float = 1.0
raise_on_error: bool = False


class SseMcpToolAdapter(
Expand Down Expand Up @@ -41,6 +44,9 @@ class SseMcpToolAdapter(
session (ClientSession, optional): The MCP client session to use. If not provided,
it will create a new session. This is useful for testing or when you want to
manage the session lifecycle yourself.
max_retries (int, optional): The maximum number of retries for tool execution. Defaults to 0.
retry_delay (float, optional): The delay in seconds between retries. Defaults to 1.0.
raise_on_error (bool, optional): Whether to raise an exception on tool error. Defaults to False.

Examples:
Use a remote translation service that implements MCP over SSE to create tools
Expand Down Expand Up @@ -90,8 +96,23 @@ async def main() -> None:
component_config_schema = SseMcpToolAdapterConfig
component_provider_override = "autogen_ext.tools.mcp.SseMcpToolAdapter"

def __init__(self, server_params: SseServerParams, tool: Tool, session: ClientSession | None = None) -> None:
super().__init__(server_params=server_params, tool=tool, session=session)
def __init__(
self,
server_params: SseServerParams,
tool: Tool,
session: ClientSession | None = None,
max_retries: int = 0,
retry_delay: float = 1.0,
raise_on_error: bool = False,
) -> None:
super().__init__(
server_params=server_params,
tool=tool,
session=session,
max_retries=max_retries,
retry_delay=retry_delay,
raise_on_error=raise_on_error,
)

def _to_config(self) -> SseMcpToolAdapterConfig:
"""
Expand All @@ -100,7 +121,13 @@ def _to_config(self) -> SseMcpToolAdapterConfig:
Returns:
SseMcpToolAdapterConfig: The configuration of the adapter.
"""
return SseMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool)
return SseMcpToolAdapterConfig(
server_params=self._server_params,
tool=self._tool,
max_retries=self._max_retries,
retry_delay=self._retry_delay,
raise_on_error=self._raise_on_error,
)

@classmethod
def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self:
Expand All @@ -113,4 +140,10 @@ def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self:
Returns:
SseMcpToolAdapter: An instance of SseMcpToolAdapter.
"""
return cls(server_params=config.server_params, tool=config.tool)
return cls(
server_params=config.server_params,
tool=config.tool,
max_retries=config.max_retries,
retry_delay=config.retry_delay,
raise_on_error=config.raise_on_error,
)
41 changes: 37 additions & 4 deletions python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class StdioMcpToolAdapterConfig(BaseModel):

server_params: StdioServerParams
tool: Tool
max_retries: int = 0
retry_delay: float = 1.0
raise_on_error: bool = False


class StdioMcpToolAdapter(
Expand Down Expand Up @@ -41,15 +44,33 @@ class StdioMcpToolAdapter(
session (ClientSession, optional): The MCP client session to use. If not provided,
a new session will be created. This is useful for testing or when you want to
manage the session lifecycle yourself.
max_retries (int, optional): The maximum number of retries for tool execution. Defaults to 0.
retry_delay (float, optional): The delay in seconds between retries. Defaults to 1.0.
raise_on_error (bool, optional): Whether to raise an exception on tool error. Defaults to False.

See :func:`~autogen_ext.tools.mcp.mcp_server_tools` for examples.
"""

component_config_schema = StdioMcpToolAdapterConfig
component_provider_override = "autogen_ext.tools.mcp.StdioMcpToolAdapter"

def __init__(self, server_params: StdioServerParams, tool: Tool, session: ClientSession | None = None) -> None:
super().__init__(server_params=server_params, tool=tool, session=session)
def __init__(
self,
server_params: StdioServerParams,
tool: Tool,
session: ClientSession | None = None,
max_retries: int = 0,
retry_delay: float = 1.0,
raise_on_error: bool = False,
) -> None:
super().__init__(
server_params=server_params,
tool=tool,
session=session,
max_retries=max_retries,
retry_delay=retry_delay,
raise_on_error=raise_on_error,
)

def _to_config(self) -> StdioMcpToolAdapterConfig:
"""
Expand All @@ -58,7 +79,13 @@ def _to_config(self) -> StdioMcpToolAdapterConfig:
Returns:
StdioMcpToolAdapterConfig: The configuration of the adapter.
"""
return StdioMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool)
return StdioMcpToolAdapterConfig(
server_params=self._server_params,
tool=self._tool,
max_retries=self._max_retries,
retry_delay=self._retry_delay,
raise_on_error=self._raise_on_error,
)

@classmethod
def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self:
Expand All @@ -71,4 +98,10 @@ def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self:
Returns:
StdioMcpToolAdapter: An instance of StdioMcpToolAdapter.
"""
return cls(server_params=config.server_params, tool=config.tool)
return cls(
server_params=config.server_params,
tool=config.tool,
max_retries=config.max_retries,
retry_delay=config.retry_delay,
raise_on_error=config.raise_on_error,
)
Loading