diff --git a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py index 40b1ce47d991..0046c76e30a1 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py @@ -116,6 +116,8 @@ async def call_tool( result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id)) cancellation_token.link_future(result_future) actual_tool_output = await result_future + if isinstance(actual_tool_output, ToolResult): + return actual_tool_output is_error = False result_str = tool.return_value_as_string(actual_tool_output) except Exception as e: @@ -217,6 +219,9 @@ async def call_tool_stream( result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token, call_id=call_id)) cancellation_token.link_future(result_future) actual_tool_output = await result_future + if isinstance(actual_tool_output, ToolResult): + yield actual_tool_output + return is_error = False result_str = tool.return_value_as_string(actual_tool_output) except Exception as e: diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py index 314dbe15e7ba..693ee83de6d3 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py @@ -4,8 +4,8 @@ from abc import ABC from typing import Any, Dict, Generic, Sequence, Type, TypeVar -from autogen_core import CancellationToken -from autogen_core.tools import BaseTool +from autogen_core import CancellationToken, Image +from autogen_core.tools import BaseTool, ImageResultContent, TextResultContent, ToolResult from autogen_core.utils import schema_to_pydantic_model from pydantic import BaseModel from pydantic.networks import AnyUrl @@ -37,10 +37,21 @@ class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]): component_type = "tool" - def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSession | None = None) -> None: + def __init__( + self, + server_params: TServerParams, + tool: Tool, + session: ClientSession | None = None, + max_retries: int = 0, + retry_delay: float = 1.0, + raise_on_error: bool = False, + ) -> None: self._tool = tool self._server_params = server_params self._session = session + self._max_retries = max_retries + self._retry_delay = retry_delay + self._raise_on_error = raise_on_error # Extract name and description name = tool.name @@ -49,12 +60,12 @@ def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSess # Create the input model from the tool's schema input_model = schema_to_pydantic_model(tool.inputSchema) - # Use Any as return type since MCP tool returns can vary - return_type: Type[Any] = object + # Use ToolResult as return type + return_type: Type[ToolResult] = ToolResult super().__init__(input_model, return_type, name, description) - async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any: + async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> ToolResult: """ Run the MCP tool with the provided arguments. @@ -63,24 +74,50 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> A cancellation_token (CancellationToken): Token to signal cancellation. Returns: - Any: The result of the tool execution. - - Raises: - Exception: If the operation is cancelled or the tool execution fails. + ToolResult: The result of the tool execution. """ # Convert the input model to a dictionary # Exclude unset values to avoid sending them to the MCP servers which may cause errors # for many servers. kwargs = args.model_dump(exclude_unset=True) - if self._session is not None: - # If a session is provided, use it directly. - session = self._session - return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session) + exceptions_to_catch: tuple[Type[BaseException], ...] + if hasattr(builtins, "ExceptionGroup"): + exceptions_to_catch = (asyncio.CancelledError, builtins.ExceptionGroup) + else: + exceptions_to_catch = (asyncio.CancelledError,) - async with create_mcp_server_session(self._server_params) as session: - await session.initialize() - return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session) + attempts = 1 + self._max_retries + for attempt in range(attempts): + try: + if cancellation_token.is_cancelled(): + raise asyncio.CancelledError("Operation cancelled") + + if self._session is not None: + # If a session is provided, use it directly. + session = self._session + return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session) + + async with create_mcp_server_session(self._server_params) as session: + await session.initialize() + return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session) + + except exceptions_to_catch: + # Re-raise these specific exception types directly. + raise + except Exception as e: + if attempt < attempts - 1: + await asyncio.sleep(self._retry_delay) + continue + if self._raise_on_error: + raise + return ToolResult( + name=self.name, + result=[TextResultContent(content=f"Tool {self.name} failed: {str(e)}")], + is_error=True, + ) + + raise RuntimeError("Unreachable") def _normalize_payload_to_content_list(self, payload: Sequence[ContentBlock]) -> list[ContentBlock]: """ @@ -102,7 +139,9 @@ def _normalize_payload_to_content_list(self, payload: Sequence[ContentBlock]) -> else: return [TextContent(text=str(payload), type="text")] - async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any: + async def _run( + self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession + ) -> ToolResult: exceptions_to_catch: tuple[Type[BaseException], ...] if hasattr(builtins, "ExceptionGroup"): exceptions_to_catch = (asyncio.CancelledError, builtins.ExceptionGroup) @@ -117,12 +156,23 @@ async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken cancellation_token.link_future(result_future) result = await result_future - normalized_content_list = self._normalize_payload_to_content_list(result.content) - - if result.isError: + if result.isError and self._raise_on_error: + normalized_content_list = self._normalize_payload_to_content_list(result.content) serialized_error_message = self.return_value_as_string(normalized_content_list) raise Exception(serialized_error_message) - return normalized_content_list + + result_parts: list[TextResultContent | ImageResultContent] = [] + for content in result.content: + if isinstance(content, TextContent): + result_parts.append(TextResultContent(content=content.text)) + elif isinstance(content, ImageContent): + result_parts.append(ImageResultContent(content=Image.from_base64(content.data))) + elif isinstance(content, EmbeddedResource): + result_parts.append(TextResultContent(content=content.model_dump_json())) + else: + result_parts.append(TextResultContent(content=str(content))) + + return ToolResult(name=self.name, result=result_parts, is_error=result.isError) except exceptions_to_catch: # Re-raise these specific exception types directly. @@ -156,8 +206,10 @@ async def from_server_params(cls, server_params: TServerParams, tool_name: str) return cls(server_params=server_params, tool=matching_tool) - def return_value_as_string(self, value: list[Any]) -> str: + def return_value_as_string(self, value: Any) -> str: """Return a string representation of the result.""" + if isinstance(value, ToolResult): + return value.to_text() def serialize_item(item: Any) -> dict[str, Any]: if isinstance(item, (TextContent, ImageContent, AudioContent)): diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py index 66f8e7b7e7b3..0992de10aec6 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py @@ -10,6 +10,9 @@ async def mcp_server_tools( server_params: McpServerParams, session: ClientSession | None = None, + max_retries: int = 0, + retry_delay: float = 1.0, + raise_on_error: bool = False, ) -> list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]: """Creates a list of MCP tool adapters that can be used with AutoGen agents. @@ -36,6 +39,9 @@ async def mcp_server_tools( session (ClientSession | None): Optional existing session to use. This is used when you want to reuse an existing connection to the MCP server. The session will be reused when creating the MCP tool adapters. + max_retries (int, optional): The maximum number of retries for tool execution. Defaults to 0. + retry_delay (float, optional): The delay in seconds between retries. Defaults to 1.0. + raise_on_error (bool, optional): Whether to raise an exception on tool error. Defaults to False. Returns: list[StdioMcpToolAdapter | SseMcpToolAdapter | StreamableHttpMcpToolAdapter]: @@ -203,12 +209,39 @@ async def main() -> None: tools = await session.list_tools() if isinstance(server_params, StdioServerParams): - return [StdioMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools] + return [ + StdioMcpToolAdapter( + server_params=server_params, + tool=tool, + session=session, + max_retries=max_retries, + retry_delay=retry_delay, + raise_on_error=raise_on_error, + ) + for tool in tools.tools + ] elif isinstance(server_params, SseServerParams): - return [SseMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools] + return [ + SseMcpToolAdapter( + server_params=server_params, + tool=tool, + session=session, + max_retries=max_retries, + retry_delay=retry_delay, + raise_on_error=raise_on_error, + ) + for tool in tools.tools + ] elif isinstance(server_params, StreamableHttpServerParams): return [ - StreamableHttpMcpToolAdapter(server_params=server_params, tool=tool, session=session) + StreamableHttpMcpToolAdapter( + server_params=server_params, + tool=tool, + session=session, + max_retries=max_retries, + retry_delay=retry_delay, + raise_on_error=raise_on_error, + ) for tool in tools.tools ] raise ValueError(f"Unsupported server params type: {type(server_params)}") diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py index c77ec8607422..347b9fbf4a84 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py @@ -13,6 +13,9 @@ class SseMcpToolAdapterConfig(BaseModel): server_params: SseServerParams tool: Tool + max_retries: int = 0 + retry_delay: float = 1.0 + raise_on_error: bool = False class SseMcpToolAdapter( @@ -41,6 +44,9 @@ class SseMcpToolAdapter( session (ClientSession, optional): The MCP client session to use. If not provided, it will create a new session. This is useful for testing or when you want to manage the session lifecycle yourself. + max_retries (int, optional): The maximum number of retries for tool execution. Defaults to 0. + retry_delay (float, optional): The delay in seconds between retries. Defaults to 1.0. + raise_on_error (bool, optional): Whether to raise an exception on tool error. Defaults to False. Examples: Use a remote translation service that implements MCP over SSE to create tools @@ -90,8 +96,23 @@ async def main() -> None: component_config_schema = SseMcpToolAdapterConfig component_provider_override = "autogen_ext.tools.mcp.SseMcpToolAdapter" - def __init__(self, server_params: SseServerParams, tool: Tool, session: ClientSession | None = None) -> None: - super().__init__(server_params=server_params, tool=tool, session=session) + def __init__( + self, + server_params: SseServerParams, + tool: Tool, + session: ClientSession | None = None, + max_retries: int = 0, + retry_delay: float = 1.0, + raise_on_error: bool = False, + ) -> None: + super().__init__( + server_params=server_params, + tool=tool, + session=session, + max_retries=max_retries, + retry_delay=retry_delay, + raise_on_error=raise_on_error, + ) def _to_config(self) -> SseMcpToolAdapterConfig: """ @@ -100,7 +121,13 @@ def _to_config(self) -> SseMcpToolAdapterConfig: Returns: SseMcpToolAdapterConfig: The configuration of the adapter. """ - return SseMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + return SseMcpToolAdapterConfig( + server_params=self._server_params, + tool=self._tool, + max_retries=self._max_retries, + retry_delay=self._retry_delay, + raise_on_error=self._raise_on_error, + ) @classmethod def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self: @@ -113,4 +140,10 @@ def _from_config(cls, config: SseMcpToolAdapterConfig) -> Self: Returns: SseMcpToolAdapter: An instance of SseMcpToolAdapter. """ - return cls(server_params=config.server_params, tool=config.tool) + return cls( + server_params=config.server_params, + tool=config.tool, + max_retries=config.max_retries, + retry_delay=config.retry_delay, + raise_on_error=config.raise_on_error, + ) diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py index bbe7c6ca0752..92abeb76c329 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py @@ -13,6 +13,9 @@ class StdioMcpToolAdapterConfig(BaseModel): server_params: StdioServerParams tool: Tool + max_retries: int = 0 + retry_delay: float = 1.0 + raise_on_error: bool = False class StdioMcpToolAdapter( @@ -41,6 +44,9 @@ class StdioMcpToolAdapter( session (ClientSession, optional): The MCP client session to use. If not provided, a new session will be created. This is useful for testing or when you want to manage the session lifecycle yourself. + max_retries (int, optional): The maximum number of retries for tool execution. Defaults to 0. + retry_delay (float, optional): The delay in seconds between retries. Defaults to 1.0. + raise_on_error (bool, optional): Whether to raise an exception on tool error. Defaults to False. See :func:`~autogen_ext.tools.mcp.mcp_server_tools` for examples. """ @@ -48,8 +54,23 @@ class StdioMcpToolAdapter( component_config_schema = StdioMcpToolAdapterConfig component_provider_override = "autogen_ext.tools.mcp.StdioMcpToolAdapter" - def __init__(self, server_params: StdioServerParams, tool: Tool, session: ClientSession | None = None) -> None: - super().__init__(server_params=server_params, tool=tool, session=session) + def __init__( + self, + server_params: StdioServerParams, + tool: Tool, + session: ClientSession | None = None, + max_retries: int = 0, + retry_delay: float = 1.0, + raise_on_error: bool = False, + ) -> None: + super().__init__( + server_params=server_params, + tool=tool, + session=session, + max_retries=max_retries, + retry_delay=retry_delay, + raise_on_error=raise_on_error, + ) def _to_config(self) -> StdioMcpToolAdapterConfig: """ @@ -58,7 +79,13 @@ def _to_config(self) -> StdioMcpToolAdapterConfig: Returns: StdioMcpToolAdapterConfig: The configuration of the adapter. """ - return StdioMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + return StdioMcpToolAdapterConfig( + server_params=self._server_params, + tool=self._tool, + max_retries=self._max_retries, + retry_delay=self._retry_delay, + raise_on_error=self._raise_on_error, + ) @classmethod def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self: @@ -71,4 +98,10 @@ def _from_config(cls, config: StdioMcpToolAdapterConfig) -> Self: Returns: StdioMcpToolAdapter: An instance of StdioMcpToolAdapter. """ - return cls(server_params=config.server_params, tool=config.tool) + return cls( + server_params=config.server_params, + tool=config.tool, + max_retries=config.max_retries, + retry_delay=config.retry_delay, + raise_on_error=config.raise_on_error, + ) diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_streamable_http.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_streamable_http.py index a7df719c2703..6ffaf39d70a9 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_streamable_http.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_streamable_http.py @@ -13,6 +13,9 @@ class StreamableHttpMcpToolAdapterConfig(BaseModel): server_params: StreamableHttpServerParams tool: Tool + max_retries: int = 0 + retry_delay: float = 1.0 + raise_on_error: bool = False class StreamableHttpMcpToolAdapter( @@ -42,6 +45,9 @@ class StreamableHttpMcpToolAdapter( session (ClientSession, optional): The MCP client session to use. If not provided, it will create a new session. This is useful for testing or when you want to manage the session lifecycle yourself. + max_retries (int, optional): The maximum number of retries for tool execution. Defaults to 0. + retry_delay (float, optional): The delay in seconds between retries. Defaults to 1.0. + raise_on_error (bool, optional): Whether to raise an exception on tool error. Defaults to False. Examples: Use a remote translation service that implements MCP over Streamable HTTP to @@ -50,11 +56,12 @@ class StreamableHttpMcpToolAdapter( .. code-block:: python import asyncio - from autogen_ext.models.openai import OpenAIChatCompletionClient - from autogen_ext.tools.mcp import StreamableHttpMcpToolAdapter, StreamableHttpServerParams + from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.ui import Console from autogen_core import CancellationToken + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_ext.tools.mcp import StreamableHttpMcpToolAdapter, StreamableHttpServerParams async def main() -> None: @@ -94,9 +101,22 @@ async def main() -> None: component_provider_override = "autogen_ext.tools.mcp.StreamableHttpMcpToolAdapter" def __init__( - self, server_params: StreamableHttpServerParams, tool: Tool, session: ClientSession | None = None + self, + server_params: StreamableHttpServerParams, + tool: Tool, + session: ClientSession | None = None, + max_retries: int = 0, + retry_delay: float = 1.0, + raise_on_error: bool = False, ) -> None: - super().__init__(server_params=server_params, tool=tool, session=session) + super().__init__( + server_params=server_params, + tool=tool, + session=session, + max_retries=max_retries, + retry_delay=retry_delay, + raise_on_error=raise_on_error, + ) def _to_config(self) -> StreamableHttpMcpToolAdapterConfig: """ @@ -105,7 +125,13 @@ def _to_config(self) -> StreamableHttpMcpToolAdapterConfig: Returns: StreamableHttpMcpToolAdapterConfig: The configuration of the adapter. """ - return StreamableHttpMcpToolAdapterConfig(server_params=self._server_params, tool=self._tool) + return StreamableHttpMcpToolAdapterConfig( + server_params=self._server_params, + tool=self._tool, + max_retries=self._max_retries, + retry_delay=self._retry_delay, + raise_on_error=self._raise_on_error, + ) @classmethod def _from_config(cls, config: StreamableHttpMcpToolAdapterConfig) -> Self: @@ -118,4 +144,10 @@ def _from_config(cls, config: StreamableHttpMcpToolAdapterConfig) -> Self: Returns: StreamableHttpMcpToolAdapter: An instance of StreamableHttpMcpToolAdapter. """ - return cls(server_params=config.server_params, tool=config.tool) + return cls( + server_params=config.server_params, + tool=config.tool, + max_retries=config.max_retries, + retry_delay=config.retry_delay, + raise_on_error=config.raise_on_error, + ) diff --git a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py index fbcbccc3e6bb..2b9cd9696a3d 100644 --- a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py +++ b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py @@ -8,7 +8,7 @@ import pytest from _pytest.logging import LogCaptureFixture # type: ignore[import] from autogen_core import CancellationToken -from autogen_core.tools import Workbench +from autogen_core.tools import ToolResult, Workbench from autogen_core.utils import schema_to_pydantic_model from autogen_ext.tools.mcp import ( McpSessionActor, @@ -186,7 +186,9 @@ async def test_mcp_tool_execution( cancellation_token=cancellation_token, ) - assert result == mock_tool_response.content + assert isinstance(result, ToolResult) + assert result.is_error is False + assert result.result[0].content == "test_output" mock_session.initialize.assert_called_once() mock_session.call_tool.assert_called_once() @@ -387,7 +389,9 @@ async def test_sse_tool_execution( cancellation_token=CancellationToken(), ) - assert result == mock_sse_session.call_tool.return_value.content + assert isinstance(result, ToolResult) + assert result.is_error is False + assert result.result[0].content == "test_output" mock_sse_session.initialize.assert_called_once() mock_sse_session.call_tool.assert_called_once() @@ -503,7 +507,9 @@ async def test_streamable_http_tool_execution( cancellation_token=CancellationToken(), ) - assert result == mock_streamable_http_session.call_tool.return_value.content + assert isinstance(result, ToolResult) + assert result.is_error is False + assert result.result[0].content == "test_output" mock_streamable_http_session.initialize.assert_called_once() mock_streamable_http_session.call_tool.assert_called_once() @@ -750,7 +756,10 @@ async def test_lazy_init_and_finalize_cleanup() -> None: actor = workbench._actor # type: ignore[reportPrivateUsage] del workbench - await asyncio.sleep(0.1) + for _ in range(5): + import gc + gc.collect() + await asyncio.sleep(0.5) assert actor._active is False @@ -776,12 +785,18 @@ async def test_del_to_new_event_loop_when_get_event_loop_fails() -> None: def cleanup() -> None: nonlocal workbench del workbench + for _ in range(5): + import gc + gc.collect() t = threading.Thread(target=cleanup) t.start() t.join() - await asyncio.sleep(0.1) + for _ in range(5): + import gc + gc.collect() + await asyncio.sleep(0.5) assert actor._active is False # type: ignore[reportPrivateUsage] @@ -869,11 +884,12 @@ async def test_mcp_tool_adapter_run_error( mock_session.call_tool.return_value = mock_error_tool_response args = {"test_param": "test_value"} - with pytest.raises(Exception) as excinfo: - await adapter._run(args=args, cancellation_token=cancellation_token, session=mock_session) # type: ignore[reportPrivateUsage] + result = await adapter._run(args=args, cancellation_token=cancellation_token, session=mock_session) # type: ignore[reportPrivateUsage] mock_session.call_tool.assert_called_once_with(name=sample_tool.name, arguments=args) - assert adapter.return_value_as_string([TextContent(text="error output", type="text")]) in str(excinfo.value) + assert isinstance(result, ToolResult) + assert result.is_error is True + assert result.result[0].content == "error output" @pytest.mark.asyncio @@ -928,3 +944,105 @@ def test_return_value_as_string_with_resource_link(sample_tool: Tool, sample_ser assert '"type": "resource_link"' in result assert '"name": "test_link"' in result assert '"uri": "http://example.com/"' in result # AnyUrl normalizes with trailing slash + + +@pytest.mark.asyncio +async def test_mcp_tool_adapter_retry_success( + sample_tool: Tool, + sample_server_params: StdioServerParams, + mock_session: AsyncMock, + mock_tool_response: MagicMock, + cancellation_token: CancellationToken, +) -> None: + """Test that McpToolAdapter retries and succeeds if subsequent attempt works.""" + adapter = StdioMcpToolAdapter( + server_params=sample_server_params, tool=sample_tool, session=mock_session, max_retries=2, retry_delay=0.01 + ) + + # First call raises an exception, second call succeeds + mock_session.call_tool.side_effect = [Exception("Temporary network error"), mock_tool_response] + + args = {"test_param": "test"} + result = await adapter.run_json( + args=schema_to_pydantic_model(sample_tool.inputSchema)(**args).model_dump(), + cancellation_token=cancellation_token, + ) + + assert isinstance(result, ToolResult) + assert result.is_error is False + assert result.result[0].content == "test_output" + assert mock_session.call_tool.call_count == 2 + + +@pytest.mark.asyncio +async def test_mcp_tool_adapter_retry_failure( + sample_tool: Tool, + sample_server_params: StdioServerParams, + mock_session: AsyncMock, + cancellation_token: CancellationToken, +) -> None: + """Test that McpToolAdapter retries max_retries times and then returns is_error=True on persistent error.""" + adapter = StdioMcpToolAdapter( + server_params=sample_server_params, tool=sample_tool, session=mock_session, max_retries=2, retry_delay=0.01 + ) + + # Always raise an exception + mock_session.call_tool.side_effect = Exception("Persistent network error") + + args = {"test_param": "test"} + result = await adapter.run_json( + args=schema_to_pydantic_model(sample_tool.inputSchema)(**args).model_dump(), + cancellation_token=cancellation_token, + ) + + assert isinstance(result, ToolResult) + assert result.is_error is True + assert "Persistent network error" in result.result[0].content + assert mock_session.call_tool.call_count == 3 + + +@pytest.mark.asyncio +async def test_mcp_tool_adapter_workbench_integration_error( + sample_tool: Tool, + sample_server_params: StdioServerParams, + mock_session: AsyncMock, + cancellation_token: CancellationToken, +) -> None: + """Test that McpToolAdapter integrated in StaticWorkbench handles execution error and returns is_error=True.""" + adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool, session=mock_session) + mock_session.call_tool.side_effect = Exception("Connection timeout (MCP server at localhost:9000)") + + from autogen_core.tools import StaticWorkbench + + workbench = StaticWorkbench(tools=[adapter]) + + result = await workbench.call_tool( + name=adapter.name, + arguments={"test_param": "test"}, + cancellation_token=cancellation_token, + ) + + assert isinstance(result, ToolResult) + assert result.is_error is True + assert "Connection timeout (MCP server at localhost:9000)" in result.result[0].content + + +@pytest.mark.asyncio +async def test_mcp_tool_adapter_raise_on_error_compatibility( + sample_tool: Tool, + sample_server_params: StdioServerParams, + mock_session: AsyncMock, + cancellation_token: CancellationToken, +) -> None: + """Test that McpToolAdapter raises exceptions if raise_on_error is True.""" + adapter = StdioMcpToolAdapter( + server_params=sample_server_params, tool=sample_tool, session=mock_session, raise_on_error=True + ) + mock_session.call_tool.side_effect = Exception("Fatal connection failure") + + args = {"test_param": "test"} + with pytest.raises(Exception, match="Fatal connection failure"): + await adapter.run_json( + args=schema_to_pydantic_model(sample_tool.inputSchema)(**args).model_dump(), + cancellation_token=cancellation_token, + )