diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..5de0b79c2 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,11 @@ +coverage: + status: + project: + default: + target: 90% # overall coverage threshold + patch: + default: + target: 90% # patch coverage threshold + base: auto + # Only post patch coverage on decreases + only_pulls: true \ No newline at end of file diff --git a/src/strands/_async.py b/src/strands/_async.py new file mode 100644 index 000000000..976487c37 --- /dev/null +++ b/src/strands/_async.py @@ -0,0 +1,31 @@ +"""Private async execution utilities.""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import Awaitable, Callable, TypeVar + +T = TypeVar("T") + + +def run_async(async_func: Callable[[], Awaitable[T]]) -> T: + """Run an async function in a separate thread to avoid event loop conflicts. + + This utility handles the common pattern of running async code from sync contexts + by using ThreadPoolExecutor to isolate the async execution. + + Args: + async_func: A callable that returns an awaitable + + Returns: + The result of the async function + """ + + async def execute_async() -> T: + return await async_func() + + def execute() -> T: + return asyncio.run(execute_async()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(execute) + return future.result() diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 1de75cfd2..3c735f23b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,13 +9,12 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ -import asyncio import json import logging import random import warnings -from concurrent.futures import ThreadPoolExecutor from typing import ( + TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, @@ -32,7 +31,11 @@ from pydantic import BaseModel from .. import _identifier +from .._async import run_async from ..event_loop.event_loop import event_loop_cycle + +if TYPE_CHECKING: + from ..experimental.tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -167,12 +170,7 @@ async def acall() -> ToolResult: return tool_results[0] - def tcall() -> ToolResult: - return asyncio.run(acall()) - - with ThreadPoolExecutor() as executor: - future = executor.submit(tcall) - tool_result = future.result() + tool_result = run_async(acall) if record_direct_tool_call is not None: should_record_direct_tool_call = record_direct_tool_call @@ -215,7 +213,7 @@ def __init__( self, model: Union[Model, str, None] = None, messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], Any]]] = None, + tools: Optional[list[Union[str, dict[str, str], "ToolProvider", Any]]] = None, system_prompt: Optional[str] = None, structured_output_model: Optional[Type[BaseModel]] = None, callback_handler: Optional[ @@ -248,6 +246,7 @@ def __init__( - File paths (e.g., "/path/to/tool.py") - Imported Python modules (e.g., from strands_tools import current_time) - Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"}) + - ToolProvider instances for managed tool collections - Functions decorated with `@strands.tool` decorator. If provided, only these tools will be available. If None, all tools will be available. @@ -423,17 +422,11 @@ def __call__( - state: The final state of the event loop - structured_output: Parsed structured output when structured_output_model was specified """ - - def execute() -> AgentResult: - return asyncio.run( - self.invoke_async( - prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs - ) + return run_async( + lambda: self.invoke_async( + prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs ) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + ) async def invoke_async( self, @@ -505,13 +498,8 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> category=DeprecationWarning, stacklevel=2, ) - - def execute() -> T: - return asyncio.run(self.structured_output_async(output_model, prompt)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + + return run_async(lambda: self.structured_output_async(output_model, prompt)) async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. @@ -529,6 +517,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu Raises: ValueError: If no conversation history or prompt is provided. + - """ if self._interrupt_state.activated: raise RuntimeError("cannot call structured output during interrupt") @@ -583,6 +572,25 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu finally: self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + def cleanup(self) -> None: + """Clean up resources used by the agent. + + This method cleans up all tool providers that require explicit cleanup, + such as MCP clients. It should be called when the agent is no longer needed + to ensure proper resource cleanup. + + Note: This method uses a "belt and braces" approach with automatic cleanup + through finalizers as a fallback, but explicit cleanup is recommended. + """ + self.tool_registry.cleanup() + + def __del__(self) -> None: + """Clean up resources when agent is garbage collected.""" + # __del__ is called even when an exception is thrown in the constructor, + # so there is no guarantee tool_registry was set.. + if hasattr(self, "tool_registry"): + self.tool_registry.cleanup() + async def stream_async( self, prompt: AgentInput = None, diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index 86618c153..188c80c69 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -3,6 +3,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ +from . import tools from .agent_config import config_to_agent -__all__ = ["config_to_agent"] +__all__ = ["config_to_agent", "tools"] diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index d08f89cf9..f65afb57d 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -18,8 +18,6 @@ import jsonschema from jsonschema import ValidationError -from ..agent import Agent - # JSON Schema for agent configuration AGENT_CONFIG_SCHEMA = { "$schema": "http://json-schema.org/draft-07/schema#", @@ -53,7 +51,7 @@ _VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) -def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Agent: +def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Any: """Create an Agent from a configuration file or dictionary. This function supports tools that can be loaded declaratively (file paths, module names, @@ -134,5 +132,8 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A # Override with any additional kwargs provided agent_kwargs.update(kwargs) + # Import Agent at runtime to avoid circular imports + from ..agent import Agent + # Create and return Agent return Agent(**agent_kwargs) diff --git a/src/strands/experimental/tools/__init__.py b/src/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..ad693f8ac --- /dev/null +++ b/src/strands/experimental/tools/__init__.py @@ -0,0 +1,5 @@ +"""Experimental tools package.""" + +from .tool_provider import ToolProvider + +__all__ = ["ToolProvider"] diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py new file mode 100644 index 000000000..2c79ceafc --- /dev/null +++ b/src/strands/experimental/tools/tool_provider.py @@ -0,0 +1,52 @@ +"""Tool provider interface.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Sequence + +if TYPE_CHECKING: + from ...types.tools import AgentTool + + +class ToolProvider(ABC): + """Interface for providing tools with lifecycle management. + + Provides a way to load a collection of tools and clean them up + when done, with lifecycle managed by the agent. + """ + + @abstractmethod + async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: + """Load and return the tools in this provider. + + Args: + **kwargs: Additional arguments for future compatibility. + + Returns: + List of tools that are ready to use. + """ + pass + + @abstractmethod + def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider. + + Args: + consumer_id: Unique identifier for the consumer. + **kwargs: Additional arguments for future compatibility. + """ + pass + + @abstractmethod + def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider. + + This method must be idempotent - calling it multiple times with the same ID + should have no additional effect after the first call. + + Provider may clean up resources when no consumers remain. + + Args: + consumer_id: Unique identifier for the consumer. + **kwargs: Additional arguments for future compatibility. + """ + pass diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 07e63577d..1628a8a9d 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -3,15 +3,14 @@ Provides minimal foundation for multi-agent patterns (Swarm, Graph). """ -import asyncio import logging import warnings from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum from typing import Any, Union +from .._async import run_async from ..agent import AgentResult from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage @@ -199,12 +198,7 @@ def __call__( invocation_state.update(kwargs) warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) - def execute() -> MultiAgentResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) def serialize_state(self) -> dict[str, Any]: """Return a JSON-serializable snapshot of the orchestrator state.""" diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 1dbbfc3af..0aaa6c7a3 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -18,12 +18,12 @@ import copy import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Optional, Tuple from opentelemetry import trace as trace_api +from .._async import run_async from ..agent import Agent from ..agent.state import AgentState from ..telemetry import get_tracer @@ -399,12 +399,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7542b1b85..3d9dc00c8 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -17,13 +17,14 @@ import json import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Tuple from opentelemetry import trace as trace_api -from ..agent import Agent, AgentResult +from .._async import run_async +from ..agent import Agent +from ..agent.agent_result import AgentResult from ..agent.state import AgentState from ..telemetry import get_tracer from ..tools.decorator import tool @@ -254,12 +255,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index d95c54fed..cfa841c46 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -7,7 +7,7 @@ """ from .mcp_agent_tool import MCPAgentTool -from .mcp_client import MCPClient +from .mcp_client import MCPClient, ToolFilters from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index acc48443c..af0c069a1 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -28,26 +28,29 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None: + def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None: """Initialize a new MCPAgentTool instance. Args: mcp_tool: The MCP tool to adapt mcp_client: The MCP server connection to use for tool invocation + name_override: Optional name to use for the agent tool (for disambiguation) + If None, uses the original MCP tool name """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client + self._agent_tool_name = name_override or mcp_tool.name @property def tool_name(self) -> str: """Get the name of the tool. Returns: - str: The name of the MCP tool + str: The agent-facing name of the tool (may be disambiguated) """ - return self.mcp_tool.name + return self._agent_tool_name @property def tool_spec(self) -> ToolSpec: @@ -63,7 +66,7 @@ def tool_spec(self) -> ToolSpec: spec: ToolSpec = { "inputSchema": {"json": self.mcp_tool.inputSchema}, - "name": self.mcp_tool.name, + "name": self.tool_name, # Use agent-facing name in spec "description": description, } @@ -100,7 +103,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw result = await self.mcp_client.call_tool_async( tool_use_id=tool_use["toolUseId"], - name=self.tool_name, + name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], ) yield ToolResultEvent(result) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 8148e149a..baeed9d13 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -16,7 +16,7 @@ from concurrent import futures from datetime import timedelta from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast +from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast import anyio from mcp import ClientSession, ListToolsResult @@ -25,11 +25,13 @@ from mcp.types import EmbeddedResource as MCPEmbeddedResource from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent +from typing_extensions import Protocol, TypedDict +from ...experimental.tools import ToolProvider from ...types import PaginatedList -from ...types.exceptions import MCPClientInitializationError +from ...types.exceptions import MCPClientInitializationError, ToolProviderException from ...types.media import ImageFormat -from ...types.tools import ToolResultContent, ToolResultStatus +from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool from .mcp_instrumentation import mcp_instrumentation from .mcp_types import MCPToolResult, MCPTransport @@ -38,6 +40,26 @@ T = TypeVar("T") + +class _ToolFilterCallback(Protocol): + def __call__(self, tool: AgentTool, **kwargs: Any) -> bool: ... + + +_ToolMatcher = str | Pattern[str] | _ToolFilterCallback + + +class ToolFilters(TypedDict, total=False): + """Filters for controlling which MCP tools are loaded and available. + + Tools are filtered in this order: + 1. If 'allowed' is specified, only tools matching these patterns are included + 2. Tools matching 'rejected' patterns are then excluded + """ + + allowed: list[_ToolMatcher] + rejected: list[_ToolMatcher] + + MIME_TO_FORMAT: Dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", @@ -53,7 +75,7 @@ ) -class MCPClient: +class MCPClient(ToolProvider): """Represents a connection to a Model Context Protocol (MCP) server. This class implements a context manager pattern for efficient connection management, @@ -63,17 +85,32 @@ class MCPClient: The connection runs in a background thread to avoid blocking the main application thread while maintaining communication with the MCP service. When structured content is available from MCP tools, it will be returned as the last item in the content array of the ToolResult. + + Warning: + This class implements the experimental ToolProvider interface and its methods + are subject to change. """ - def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30): + def __init__( + self, + transport_callable: Callable[[], MCPTransport], + *, + startup_timeout: int = 30, + tool_filters: ToolFilters | None = None, + prefix: str | None = None, + ): """Initialize a new MCP Server connection. Args: transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple startup_timeout: Timeout after which MCP server initialization should be cancelled Defaults to 30. + tool_filters: Optional filters to apply to tools. + prefix: Optional prefix for tool names. """ self._startup_timeout = startup_timeout + self._tool_filters = tool_filters + self._prefix = prefix mcp_instrumentation() self._session_id = uuid.uuid4() @@ -87,6 +124,9 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti self._background_thread: threading.Thread | None = None self._background_thread_session: ClientSession | None = None self._background_thread_event_loop: AbstractEventLoop | None = None + self._loaded_tools: list[MCPAgentTool] | None = None + self._tool_provider_started = False + self._consumers: set[Any] = set() def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -137,6 +177,101 @@ def start(self) -> "MCPClient": raise MCPClientInitializationError("the client initialization failed") from e return self + # ToolProvider interface methods (experimental, as ToolProvider is experimental) + async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: + """Load and return tools from the MCP server. + + This method implements the ToolProvider interface by loading tools + from the MCP server and caching them for reuse. + + Args: + **kwargs: Additional arguments for future compatibility. + + Returns: + List of AgentTool instances from the MCP server. + """ + logger.debug( + "started=<%s>, cached_tools=<%s> | loading tools", + self._tool_provider_started, + self._loaded_tools is not None, + ) + + if not self._tool_provider_started: + try: + logger.debug("starting MCP client") + self.start() + self._tool_provider_started = True + logger.debug("MCP client started successfully") + except Exception as e: + logger.error("error=<%s> | failed to start MCP client", e) + raise ToolProviderException(f"Failed to start MCP client: {e}") from e + + if self._loaded_tools is None: + logger.debug("loading tools from MCP server") + self._loaded_tools = [] + pagination_token = None + page_count = 0 + + while True: + logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) + # Use constructor defaults for prefix and filters in load_tools + paginated_tools = self.list_tools_sync( + pagination_token, prefix=self._prefix, tool_filters=self._tool_filters + ) + + # Tools are already filtered by list_tools_sync, so add them all + for tool in paginated_tools: + self._loaded_tools.append(tool) + + logger.debug( + "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", + page_count, + len(paginated_tools), + len(self._loaded_tools), + ) + + pagination_token = paginated_tools.pagination_token + page_count += 1 + + if pagination_token is None: + break + + logger.debug("final_tools=<%d> | loading complete", len(self._loaded_tools)) + + return self._loaded_tools + + def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider. + + Synchronous to prevent GC deadlocks when called from Agent finalizers. + """ + self._consumers.add(consumer_id) + logger.debug("added provider consumer, count=%d", len(self._consumers)) + + def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider. + + This method is idempotent - calling it multiple times with the same ID + has no additional effect after the first call. + + Synchronous to prevent GC deadlocks when called from Agent finalizers. + Uses existing synchronous stop() method for safe cleanup. + """ + self._consumers.discard(consumer_id) + logger.debug("removed provider consumer, count=%d", len(self._consumers)) + + if not self._consumers and self._tool_provider_started: + logger.debug("no consumers remaining, cleaning up") + try: + self.stop(None, None, None) # Existing sync method - safe for finalizers + self._tool_provider_started = False + self._loaded_tools = None + except Exception as e: + logger.error("error=<%s> | failed to cleanup MCP client", e) + raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e + + # MCP-specific methods + def stop( self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] ) -> None: @@ -187,13 +322,28 @@ async def _set_close_event() -> None: self._background_thread_session = None self._background_thread_event_loop = None self._session_id = uuid.uuid4() + self._loaded_tools = None + self._tool_provider_started = False + self._consumers = set() - def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]: + def list_tools_sync( + self, + pagination_token: str | None = None, + prefix: str | None = None, + tool_filters: ToolFilters | None = None, + ) -> PaginatedList[MCPAgentTool]: """Synchronously retrieves the list of available tools from the MCP server. This method calls the asynchronous list_tools method on the MCP session and adapts the returned tools to the AgentTool interface. + Args: + pagination_token: Optional token for pagination + prefix: Optional prefix to apply to tool names. If None, uses constructor default. + If explicitly provided (including empty string), overrides constructor default. + tool_filters: Optional filters to apply to tools. If None, uses constructor default. + If explicitly provided (including empty dict), overrides constructor default. + Returns: List[AgentTool]: A list of available tools adapted to the AgentTool interface """ @@ -201,13 +351,29 @@ def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedLi if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + effective_prefix = self._prefix if prefix is None else prefix + effective_filters = self._tool_filters if tool_filters is None else tool_filters + async def _list_tools_async() -> ListToolsResult: return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) - mcp_tools = [MCPAgentTool(tool, self) for tool in list_tools_response.tools] + mcp_tools = [] + for tool in list_tools_response.tools: + # Apply prefix if specified + if effective_prefix: + prefixed_name = f"{effective_prefix}_{tool.name}" + mcp_tool = MCPAgentTool(tool, self, name_override=prefixed_name) + logger.debug("tool_rename=<%s->%s> | renamed tool", tool.name, prefixed_name) + else: + mcp_tool = MCPAgentTool(tool, self) + + # Apply filters if specified + if self._should_include_tool_with_filters(mcp_tool, effective_filters): + mcp_tools.append(mcp_tool) + self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) @@ -530,5 +696,40 @@ def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures. raise MCPClientInitializationError("the client session was not initialized") return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + def _should_include_tool(self, tool: MCPAgentTool) -> bool: + """Check if a tool should be included based on constructor filters.""" + return self._should_include_tool_with_filters(tool, self._tool_filters) + + def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: Optional[ToolFilters]) -> bool: + """Check if a tool should be included based on provided filters.""" + if not filters: + return True + + # Apply allowed filter + if "allowed" in filters: + if not self._matches_patterns(tool, filters["allowed"]): + return False + + # Apply rejected filter + if "rejected" in filters: + if self._matches_patterns(tool, filters["rejected"]): + return False + + return True + + def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolMatcher]) -> bool: + """Check if tool matches any of the given patterns.""" + for pattern in patterns: + if callable(pattern): + if pattern(tool): + return True + elif isinstance(pattern, Pattern): + if pattern.match(tool.mcp_tool.name): + return True + elif isinstance(pattern, str): + if pattern == tool.mcp_tool.name: + return True + return False + def _is_session_active(self) -> bool: return self._background_thread is not None and self._background_thread.is_alive() diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 4f85d1168..c80b80f64 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -8,16 +8,19 @@ import logging import os import sys +import uuid import warnings from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence from typing_extensions import TypedDict, cast from strands.tools.decorator import DecoratedFunctionTool +from .._async import run_async +from ..experimental.tools import ToolProvider from ..types.tools import AgentTool, ToolSpec from .loader import load_tool_from_string, load_tools_from_module from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec @@ -36,6 +39,8 @@ def __init__(self) -> None: self.registry: Dict[str, AgentTool] = {} self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None + self._tool_providers: List[ToolProvider] = [] + self._registry_id = str(uuid.uuid4()) def process_tools(self, tools: List[Any]) -> List[str]: """Process tools list. @@ -118,6 +123,20 @@ def add_tool(tool: Any) -> None: elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): for t in tool: add_tool(t) + + # Case 5: ToolProvider + elif isinstance(tool, ToolProvider): + self._tool_providers.append(tool) + tool.add_consumer(self._registry_id) + + async def get_tools() -> Sequence[AgentTool]: + return await tool.load_tools() + + provider_tools = run_async(get_tools) + + for provider_tool in provider_tools: + self.register_tool(provider_tool) + tool_names.append(provider_tool.tool_name) else: logger.warning("tool=<%s> | unrecognized tool specification", tool) @@ -655,3 +674,20 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: logger.warning("tool_name=<%s> | failed to create function tool | %s", name, e) return tools + + def cleanup(self, **kwargs: Any) -> None: + """Synchronously clean up all tool providers in this registry.""" + # Attempt cleanup of all providers even if one fails to minimize resource leakage + exceptions = [] + for provider in self._tool_providers: + try: + provider.remove_consumer(self._registry_id) + logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) + except Exception as e: + exceptions.append(e) + logger.error( + "provider=<%s>, error=<%s> | failed to remove provider consumer", type(provider).__name__, e + ) + + if exceptions: + raise exceptions[0] diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 5b17ba6e7..349c6b0de 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -77,14 +77,20 @@ class SessionException(Exception): pass +class ToolProviderException(Exception): + """Exception raised when a tool provider fails to load or cleanup tools.""" + + pass + + class StructuredOutputException(Exception): - """Exception raised when structured output validation fails after maximum retry attempts.""" + """Exception raised when structured output validation fails after maximum retry attempts.""" - def __init__(self, message: str): - """Initialize the exception with details about the failure. + def __init__(self, message: str): + """Initialize the exception with details about the failure. - Args: - message: The error message describing the structured output failure - """ - self.message = message - super().__init__(message) + Args: + message: The error message describing the structured output failure + """ + self.message = message + super().__init__(message) diff --git a/tests/fixtures/mock_agent_tool.py b/tests/fixtures/mock_agent_tool.py new file mode 100644 index 000000000..eed33731f --- /dev/null +++ b/tests/fixtures/mock_agent_tool.py @@ -0,0 +1,27 @@ +from typing import Any + +from strands.types.content import ToolUse +from strands.types.tools import AgentTool, ToolSpec + + +class MockAgentTool(AgentTool): + """Mock AgentTool implementation for testing.""" + + def __init__(self, name: str): + super().__init__() + self._tool_name = name + + @property + def tool_name(self) -> str: + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + return ToolSpec(name=self._tool_name, description="Mock tool", input_schema={}) + + @property + def tool_type(self) -> str: + return "mock" + + def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any): + yield f"Mock result for {self._tool_name}" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 9d490c0de..ec4b9deb7 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -894,10 +894,6 @@ def test_agent_tool_names(tools, agent): assert actual == expected -def test_agent__del__(agent): - del agent - - def test_agent_init_with_no_model_or_model_id(): agent = Agent() assert agent.model is not None @@ -2065,3 +2061,10 @@ def test_agent_tool_caller_interrupt(user): exp_message = r"cannot directly call tool during interrupt" with pytest.raises(RuntimeError, match=exp_message): agent.tool.test_tool() + + +def test_agent_del_before_tool_registry_set(): + """Test that Agent.__del__ doesn't fail if called before tool_registry is set.""" + agent = Agent() + del agent.tool_registry + agent.__del__() # Should not raise diff --git a/tests/strands/experimental/tools/__init__.py b/tests/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/test_async.py b/tests/strands/test_async.py new file mode 100644 index 000000000..2a98a953c --- /dev/null +++ b/tests/strands/test_async.py @@ -0,0 +1,25 @@ +"""Tests for _async module.""" + +import pytest + +from strands._async import run_async + + +def test_run_async_with_return_value(): + """Test run_async returns correct value.""" + + async def async_with_value(): + return 42 + + result = run_async(async_with_value) + assert result == 42 + + +def test_run_async_exception_propagation(): + """Test that exceptions are properly propagated.""" + + async def async_with_exception(): + raise ValueError("test exception") + + with pytest.raises(ValueError, match="test exception"): + run_async(async_with_exception) diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py new file mode 100644 index 000000000..9cb90167d --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -0,0 +1,826 @@ +"""Unit tests for MCPClient ToolProvider functionality.""" + +import re +from unittest.mock import MagicMock, patch + +import pytest +from mcp.types import Tool as MCPTool + +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool +from strands.tools.mcp.mcp_client import ToolFilters +from strands.types import PaginatedList +from strands.types.exceptions import ToolProviderException + + +@pytest.fixture +def mock_transport(): + """Create a mock transport callable.""" + + def transport(): + read_stream = MagicMock() + write_stream = MagicMock() + return read_stream, write_stream + + return transport + + +@pytest.fixture +def mock_mcp_tool(): + """Create a mock MCP tool.""" + tool = MagicMock() + tool.name = "test_tool" + return tool + + +@pytest.fixture +def mock_agent_tool(mock_mcp_tool): + """Create a mock MCPAgentTool.""" + agent_tool = MagicMock(spec=MCPAgentTool) + agent_tool.tool_name = "test_tool" + agent_tool.mcp_tool = mock_mcp_tool + return agent_tool + + +def create_mock_tool(tool_name: str, mcp_tool_name: str | None = None) -> MagicMock: + """Helper to create mock tools with specific names.""" + tool = MagicMock(spec=MCPAgentTool) + tool.tool_name = tool_name + tool.tool_spec = { + "name": tool_name, + "description": f"Description for {tool_name}", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + tool.mcp_tool = MagicMock(spec=MCPTool) + tool.mcp_tool.name = mcp_tool_name or tool_name + tool.mcp_tool.description = f"Description for {tool_name}" + return tool + + +def test_init_with_tool_filters_and_prefix(mock_transport): + """Test initialization with tool filters and prefix.""" + filters = {"allowed": ["tool1"]} + prefix = "test_prefix" + + client = MCPClient(mock_transport, tool_filters=filters, prefix=prefix) + + assert client._tool_filters == filters + assert client._prefix == prefix + assert client._loaded_tools is None + assert client._tool_provider_started is False + + +@pytest.mark.asyncio +async def test_load_tools_starts_client_when_not_started(mock_transport, mock_agent_tool): + """Test that load_tools starts the client when not already started.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_called_once() + assert client._tool_provider_started is True + assert len(tools) == 1 + assert tools[0] is mock_agent_tool + + +@pytest.mark.asyncio +async def test_load_tools_does_not_start_client_when_already_started(mock_transport, mock_agent_tool): + """Test that load_tools does not start client when already started.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_not_called() + assert len(tools) == 1 + + +@pytest.mark.asyncio +async def test_load_tools_raises_exception_on_client_start_failure(mock_transport): + """Test that load_tools raises ToolProviderException when client start fails.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start: + mock_start.side_effect = Exception("Client start failed") + + with pytest.raises(ToolProviderException, match="Failed to start MCP client: Client start failed"): + await client.load_tools() + + +@pytest.mark.asyncio +async def test_load_tools_caches_tools(mock_transport, mock_agent_tool): + """Test that load_tools caches tools and doesn't reload them.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + # First call + tools1 = await client.load_tools() + # Second call + tools2 = await client.load_tools() + + # Client should only be called once + mock_list_tools.assert_called_once() + assert tools1 is tools2 + + +@pytest.mark.asyncio +async def test_load_tools_handles_pagination(mock_transport): + """Test that load_tools handles pagination correctly.""" + tool1 = create_mock_tool("tool1") + tool2 = create_mock_tool("tool2") + + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock pagination: first page returns tool1 with next token, second page returns tool2 with no token + mock_list_tools.side_effect = [ + PaginatedList([tool1], token="page2"), + PaginatedList([tool2], token=None), + ] + + tools = await client.load_tools() + + # Should have called list_tools_sync twice + assert mock_list_tools.call_count == 2 + # First call with no token, second call with "page2" token + mock_list_tools.assert_any_call(None, prefix=None, tool_filters=None) + mock_list_tools.assert_any_call("page2", prefix=None, tool_filters=None) + + assert len(tools) == 2 + assert tools[0] is tool1 + assert tools[1] is tool2 + + +@pytest.mark.asyncio +async def test_allowed_filter_string_match(mock_transport): + """Test allowed filter with string matching.""" + tool1 = create_mock_tool("allowed_tool") + + filters: ToolFilters = {"allowed": ["allowed_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock list_tools_sync to return filtered results (simulating the filtering) + mock_list_tools.return_value = PaginatedList([tool1]) # Only allowed tool + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "allowed_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_regex_match(mock_transport): + """Test allowed filter with regex matching.""" + tool1 = create_mock_tool("echo_tool") + + filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock list_tools_sync to return filtered results + mock_list_tools.return_value = PaginatedList([tool1]) # Only echo tool + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "echo_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_callable_match(mock_transport): + """Test allowed filter with callable matching.""" + tool1 = create_mock_tool("short") + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 10 + + filters: ToolFilters = {"allowed": [short_names_only]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock list_tools_sync to return filtered results + mock_list_tools.return_value = PaginatedList([tool1]) # Only short tool + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "short" + + +@pytest.mark.asyncio +async def test_rejected_filter_string_match(mock_transport): + """Test rejected filter with string matching.""" + tool1 = create_mock_tool("good_tool") + + filters: ToolFilters = {"rejected": ["bad_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock list_tools_sync to return filtered results + mock_list_tools.return_value = PaginatedList([tool1]) # Only good tool + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "good_tool" + + +@pytest.mark.asyncio +async def test_prefix_renames_tools(mock_transport): + """Test that prefix properly renames tools.""" + # Create a mock MCP tool (not MCPAgentTool) + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_name" + + client = MCPClient(mock_transport, prefix="prefix") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "prefix_original_name" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call list_tools_sync directly to test prefix functionality + result = client.list_tools_sync(prefix="prefix") + + # Should create MCPAgentTool with prefixed name + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="prefix_original_name") + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_add_consumer(mock_transport): + """Test adding a provider consumer.""" + client = MCPClient(mock_transport) + + client.add_consumer("consumer1") + + assert "consumer1" in client._consumers + assert len(client._consumers) == 1 + + +def test_remove_consumer_without_cleanup(mock_transport): + """Test removing a provider consumer without triggering cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._consumers.add("consumer2") + client._tool_provider_started = True + + client.remove_consumer("consumer1") + + assert "consumer1" not in client._consumers + assert "consumer2" in client._consumers + assert client._tool_provider_started is True # Should not cleanup yet + + +def test_remove_consumer_with_cleanup(mock_transport): + """Test removing the last provider consumer triggers cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + client._loaded_tools = [MagicMock()] + + with patch.object(client, "stop") as mock_stop: + client.remove_consumer("consumer1") + + assert len(client._consumers) == 0 + assert client._tool_provider_started is False + assert client._loaded_tools is None + mock_stop.assert_called_once_with(None, None, None) + + +def test_remove_consumer_cleanup_failure(mock_transport): + """Test that remove_consumer raises ToolProviderException when cleanup fails.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + + with patch.object(client, "stop") as mock_stop: + mock_stop.side_effect = Exception("Cleanup failed") + + with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): + client.remove_consumer("consumer1") + + +def test_mcp_client_reuse_across_multiple_agents(mock_transport): + """Test that a single MCPClient can be used across multiple agents.""" + from strands import Agent + + tool1 = create_mock_tool(tool_name="shared_echo", mcp_tool_name="echo") + client = MCPClient(mock_transport, tool_filters={"allowed": ["echo"]}, prefix="shared") + + with ( + patch.object(client, "list_tools_sync") as mock_list_tools, + patch.object(client, "start") as mock_start, + patch.object(client, "stop") as mock_stop, + ): + mock_list_tools.return_value = PaginatedList([tool1]) + + # Create two agents with the same client + agent_1 = Agent(tools=[client]) + agent_2 = Agent(tools=[client]) + + # Both agents should have the same tool + assert "shared_echo" in agent_1.tool_names + assert "shared_echo" in agent_2.tool_names + assert agent_1.tool_names == agent_2.tool_names + + # Client should only be started once + mock_start.assert_called_once() + + # First agent cleanup - client should remain active + agent_1.cleanup() + mock_stop.assert_not_called() # Should not stop yet + + # Second agent should still work + assert "shared_echo" in agent_2.tool_names + + # Final cleanup when last agent is removed + agent_2.cleanup() + mock_stop.assert_called_once() # Now it should stop + + +def test_list_tools_sync_prefix_override_constructor_default(mock_transport): + """Test that list_tools_sync can override constructor prefix.""" + # Create a mock MCP tool + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_tool" + + # Client with constructor prefix + client = MCPClient(mock_transport, prefix="constructor") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "override_original_tool" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call with override prefix + result = client.list_tools_sync(prefix="override") + + # Should use override prefix, not constructor prefix + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="override_original_tool") + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_list_tools_sync_prefix_override_with_empty_string(mock_transport): + """Test that list_tools_sync can override constructor prefix with empty string.""" + # Create a mock MCP tool + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_tool" + + # Client with constructor prefix + client = MCPClient(mock_transport, prefix="constructor") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "original_tool" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call with empty string prefix (should override constructor default) + result = client.list_tools_sync(prefix="") + + # Should use no prefix (empty string overrides constructor) + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client) + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_list_tools_sync_prefix_uses_constructor_default_when_none(mock_transport): + """Test that list_tools_sync uses constructor prefix when None is passed.""" + # Create a mock MCP tool + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_tool" + + # Client with constructor prefix + client = MCPClient(mock_transport, prefix="constructor") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "constructor_original_tool" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call with None prefix (should use constructor default) + result = client.list_tools_sync(prefix=None) + + # Should use constructor prefix + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="constructor_original_tool") + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_list_tools_sync_tool_filters_override_constructor_default(mock_transport): + """Test that list_tools_sync can override constructor tool_filters.""" + # Create mock tools + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + # Client with constructor filters that would allow both + constructor_filters: ToolFilters = {"allowed": ["allowed_tool", "rejected_tool"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters) + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [MagicMock(name="allowed_tool"), MagicMock(name="rejected_tool")] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation to return our test tools + mock_agent_tool_class.side_effect = [tool1, tool2] + + # Override filters to only allow one tool + override_filters: ToolFilters = {"allowed": ["allowed_tool"]} + result = client.list_tools_sync(tool_filters=override_filters) + + # Should only include the allowed tool based on override filters + assert len(result) == 1 + assert result[0] is tool1 + + +def test_list_tools_sync_tool_filters_override_with_empty_dict(mock_transport): + """Test that list_tools_sync can override constructor filters with empty dict.""" + # Create mock tools + tool1 = create_mock_tool("tool1") + tool2 = create_mock_tool("tool2") + + # Client with constructor filters that would reject tools + constructor_filters: ToolFilters = {"rejected": ["tool1", "tool2"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters) + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [MagicMock(name="tool1"), MagicMock(name="tool2")] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation to return our test tools + mock_agent_tool_class.side_effect = [tool1, tool2] + + # Override with empty filters (should allow all tools) + result = client.list_tools_sync(tool_filters={}) + + # Should include both tools since empty filters allow everything + assert len(result) == 2 + assert result[0] is tool1 + assert result[1] is tool2 + + +def test_list_tools_sync_tool_filters_uses_constructor_default_when_none(mock_transport): + """Test that list_tools_sync uses constructor filters when None is passed.""" + # Create mock tools + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + # Client with constructor filters + constructor_filters: ToolFilters = {"allowed": ["allowed_tool"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters) + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [MagicMock(name="allowed_tool"), MagicMock(name="rejected_tool")] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation to return our test tools + mock_agent_tool_class.side_effect = [tool1, tool2] + + # Call with None filters (should use constructor default) + result = client.list_tools_sync(tool_filters=None) + + # Should only include allowed tool based on constructor filters + assert len(result) == 1 + assert result[0] is tool1 + + +def test_list_tools_sync_combined_prefix_and_filter_overrides(mock_transport): + """Test that list_tools_sync can override both prefix and filters simultaneously.""" + # Client with constructor defaults + constructor_filters: ToolFilters = {"allowed": ["echo_tool", "other_tool"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters, prefix="constructor") + + # Create mock tools + mock_echo_tool = MagicMock() + mock_echo_tool.name = "echo_tool" + mock_other_tool = MagicMock() + mock_other_tool.name = "other_tool" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_echo_tool, mock_other_tool] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_echo_tool + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_other_tool + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Override both prefix and filters + override_filters: ToolFilters = {"allowed": ["echo_tool"]} + result = client.list_tools_sync(prefix="override", tool_filters=override_filters) + + # Verify prefix override: should use "override" not "constructor" + calls = mock_agent_tool_class.call_args_list + assert len(calls) == 2 + + # First tool should have override prefix + args1, kwargs1 = calls[0] + assert args1 == (mock_echo_tool, client) + assert kwargs1 == {"name_override": "override_echo_tool"} + + # Second tool should have override prefix + args2, kwargs2 = calls[1] + assert args2 == (mock_other_tool, client) + assert kwargs2 == {"name_override": "override_other_tool"} + + # Verify filter override: should only include echo_tool based on override filters + assert len(result) == 1 + assert result[0] is mock_agent_tool1 + + +def test_list_tools_sync_direct_usage_without_constructor_defaults(mock_transport): + """Test direct usage of list_tools_sync without constructor defaults.""" + # Client without constructor defaults + client = MCPClient(mock_transport) + + # Create mock tools + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_tool1, mock_tool2] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_tool1 + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_tool2 + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Direct usage with explicit parameters + filters: ToolFilters = {"allowed": ["tool1"]} + result = client.list_tools_sync(prefix="direct", tool_filters=filters) + + # Verify prefix is applied + calls = mock_agent_tool_class.call_args_list + assert len(calls) == 2 + + # Should create tools with direct prefix + args1, kwargs1 = calls[0] + assert args1 == (mock_tool1, client) + assert kwargs1 == {"name_override": "direct_tool1"} + + args2, kwargs2 = calls[1] + assert args2 == (mock_tool2, client) + assert kwargs2 == {"name_override": "direct_tool2"} + + # Verify filtering: should only include tool1 + assert len(result) == 1 + assert result[0] is mock_agent_tool1 + + +def test_list_tools_sync_regex_filter_override(mock_transport): + """Test list_tools_sync with regex filter override.""" + # Client without constructor filters + client = MCPClient(mock_transport) + + # Create mock tools + mock_echo_tool = MagicMock() + mock_echo_tool.name = "echo_command" + mock_list_tool = MagicMock() + mock_list_tool.name = "list_files" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_echo_tool, mock_list_tool] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_echo_tool + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_list_tool + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Use regex filter to match only echo tools + regex_filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + result = client.list_tools_sync(tool_filters=regex_filters) + + # Should create both tools + assert mock_agent_tool_class.call_count == 2 + + # Should only include echo tool (regex matches "echo_command") + assert len(result) == 1 + assert result[0] is mock_agent_tool1 + + +def test_list_tools_sync_callable_filter_override(mock_transport): + """Test list_tools_sync with callable filter override.""" + # Client without constructor filters + client = MCPClient(mock_transport) + + # Create mock tools + mock_short_tool = MagicMock() + mock_short_tool.name = "short" + mock_long_tool = MagicMock() + mock_long_tool.name = "very_long_tool_name" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_short_tool, mock_long_tool] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_short_tool + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_long_tool + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Use callable filter for short names only + def short_names_only(tool) -> bool: + return len(tool.mcp_tool.name) <= 10 + + callable_filters: ToolFilters = {"allowed": [short_names_only]} + result = client.list_tools_sync(tool_filters=callable_filters) + + # Should create both tools + assert mock_agent_tool_class.call_count == 2 + + # Should only include short tool (name length <= 10) + assert len(result) == 1 + assert result[0] is mock_agent_tool1 diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/tools/mcp/test_mcp_instrumentation.py index 2c730624e..85d533403 100644 --- a/tests/strands/tools/mcp/test_mcp_instrumentation.py +++ b/tests/strands/tools/mcp/test_mcp_instrumentation.py @@ -340,6 +340,21 @@ def __getattr__(self, name): class TestMCPInstrumentation: + def test_mcp_instrumentation_called_on_client_init(self): + """Test that mcp_instrumentation is called when MCPClient is initialized.""" + with patch("strands.tools.mcp.mcp_client.mcp_instrumentation") as mock_instrumentation: + # Mock transport + def mock_transport(): + read_stream = AsyncMock() + write_stream = AsyncMock() + return read_stream, write_stream + + # Create MCPClient instance - should call mcp_instrumentation + MCPClient(mock_transport) + + # Verify mcp_instrumentation was called + mock_instrumentation.assert_called_once() + def test_mcp_instrumentation_idempotent_with_multiple_clients(self): """Test that mcp_instrumentation is only called once even with multiple MCPClient instances.""" diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index ee0098adc..1bd4ef13f 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -260,3 +260,148 @@ def test_register_strands_tools_module_non_callable_function(): " Tool tool_with_spec_but_non_callable_function function is not callable", ): tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_non_callable_function"]) + + +def test_tool_registry_cleanup_with_mcp_client(): + """Test that ToolRegistry cleanup properly handles MCP clients without orphaning threads.""" + from unittest.mock import AsyncMock, MagicMock + + from strands.tools.mcp import MCPClient + + # Create a mock MCP client that simulates a real tool provider + mock_transport = MagicMock() + mock_client = MCPClient(mock_transport) + + # Mock the client to avoid actual network operations + mock_client.load_tools = AsyncMock(return_value=[]) + + registry = ToolRegistry() + + # Use process_tools to properly register the client + registry.process_tools([mock_client]) + + # Verify the client was registered as a consumer + assert registry._registry_id in mock_client._consumers + + # Test cleanup calls remove_consumer + registry.cleanup() + + # Verify cleanup was attempted + assert registry._registry_id not in mock_client._consumers + + +def test_tool_registry_cleanup_exception_handling(): + """Test that ToolRegistry cleanup attempts all providers even if some fail.""" + from unittest.mock import MagicMock + + # Create mock providers - one that fails, one that succeeds + failing_provider = MagicMock() + failing_provider.remove_consumer.side_effect = Exception("Cleanup failed") + + working_provider = MagicMock() + + registry = ToolRegistry() + registry._tool_providers = [failing_provider, working_provider] + + # Cleanup should attempt both providers and raise the first exception + with pytest.raises(Exception, match="Cleanup failed"): + registry.cleanup() + + # Verify both providers were attempted + failing_provider.remove_consumer.assert_called_once() + working_provider.remove_consumer.assert_called_once() + + +def test_tool_registry_cleanup_idempotent(): + """Test that ToolRegistry cleanup is idempotent.""" + from unittest.mock import AsyncMock, MagicMock + + from strands.experimental.tools import ToolProvider + + provider = MagicMock(spec=ToolProvider) + provider.load_tools = AsyncMock(return_value=[]) + + registry = ToolRegistry() + + # Use process_tools to properly register the provider + registry.process_tools([provider]) + + # First cleanup should call remove_consumer + registry.cleanup() + provider.remove_consumer.assert_called_once_with(registry._registry_id) + + # Reset mock call count + provider.remove_consumer.reset_mock() + + # Second cleanup should call remove_consumer again (not idempotent yet) + # This test documents current behavior - registry cleanup is not idempotent + registry.cleanup() + provider.remove_consumer.assert_called_once_with(registry._registry_id) + + +def test_tool_registry_process_tools_exception_after_add_consumer(): + """Test that tool provider is still tracked for cleanup even if load_tools fails.""" + from unittest.mock import AsyncMock, MagicMock + + from strands.experimental.tools import ToolProvider + + # Create a mock tool provider that fails during load_tools + mock_provider = MagicMock(spec=ToolProvider) + mock_provider.add_consumer = MagicMock() + mock_provider.remove_consumer = MagicMock() + + async def failing_load_tools(): + raise Exception("Failed to load tools") + + mock_provider.load_tools = AsyncMock(side_effect=failing_load_tools) + + registry = ToolRegistry() + + # Processing should fail but provider should still be tracked + with pytest.raises(ValueError, match="Failed to load tool"): + registry.process_tools([mock_provider]) + + # Verify provider was added to registry for cleanup tracking + assert mock_provider in registry._tool_providers + + # Verify add_consumer was called before the failure + mock_provider.add_consumer.assert_called_once_with(registry._registry_id) + + # Cleanup should still work + registry.cleanup() + mock_provider.remove_consumer.assert_called_once_with(registry._registry_id) + + +def test_tool_registry_add_consumer_before_load_tools(): + """Test that add_consumer is called before load_tools to ensure cleanup tracking.""" + from unittest.mock import AsyncMock, MagicMock + + from strands.experimental.tools import ToolProvider + + # Create a mock tool provider that tracks call order + mock_provider = MagicMock(spec=ToolProvider) + call_order = [] + + def track_add_consumer(*args, **kwargs): + call_order.append("add_consumer") + + async def track_load_tools(*args, **kwargs): + call_order.append("load_tools") + return [] + + mock_provider.add_consumer.side_effect = track_add_consumer + mock_provider.load_tools = AsyncMock(side_effect=track_load_tools) + + registry = ToolRegistry() + + # Process the tool provider + registry.process_tools([mock_provider]) + + # Verify add_consumer was called before load_tools + assert call_order == ["add_consumer", "load_tools"] + + # Verify the provider was added to the registry for cleanup + assert mock_provider in registry._tool_providers + + # Verify add_consumer was called with the registry ID + mock_provider.add_consumer.assert_called_once_with(registry._registry_id) diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py new file mode 100644 index 000000000..fdf4abb0a --- /dev/null +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -0,0 +1,328 @@ +"""Unit tests for ToolRegistry ToolProvider functionality.""" + +from unittest.mock import patch + +import pytest + +from strands.experimental.tools.tool_provider import ToolProvider +from strands.tools.registry import ToolRegistry +from tests.fixtures.mock_agent_tool import MockAgentTool + + +class MockToolProvider(ToolProvider): + """Mock ToolProvider for testing.""" + + def __init__(self, tools=None, cleanup_error=None): + self._tools = tools or [] + self._cleanup_error = cleanup_error + self.cleanup_called = False + self.remove_consumer_called = False + self.remove_consumer_id = None + self.add_consumer_called = False + self.add_consumer_id = None + + async def load_tools(self): + return self._tools + + def cleanup(self): + self.cleanup_called = True + if self._cleanup_error: + raise self._cleanup_error + + def add_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + def remove_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + if self._cleanup_error: + raise self._cleanup_error + + +@pytest.fixture +def mock_run_async(): + """Fixture for mocking strands.tools.registry.run_async.""" + with patch("strands.tools.registry.run_async") as mock: + yield mock + + +@pytest.fixture +def mock_agent_tool(): + """Fixture factory for creating MockAgentTool instances.""" + return MockAgentTool + + +class TestToolRegistryToolProvider: + """Test ToolRegistry integration with ToolProvider.""" + + def test_process_tools_with_tool_provider(self, mock_run_async, mock_agent_tool): + """Test that process_tools handles ToolProvider correctly.""" + # Create mock tools + mock_tool1 = mock_agent_tool("provider_tool_1") + mock_tool2 = mock_agent_tool("provider_tool_2") + + # Create mock provider + provider = MockToolProvider([mock_tool1, mock_tool2]) + + registry = ToolRegistry() + + # Mock run_async to return the tools directly + mock_run_async.return_value = [mock_tool1, mock_tool2] + + tool_names = registry.process_tools([provider]) + + # Verify run_async was called with the provider's load_tools method + mock_run_async.assert_called_once() + + # Verify tools were registered + assert "provider_tool_1" in tool_names + assert "provider_tool_2" in tool_names + assert len(tool_names) == 2 + + # Verify provider was tracked + assert provider in registry._tool_providers + + # Verify tools are in registry + assert registry.registry["provider_tool_1"] is mock_tool1 + assert registry.registry["provider_tool_2"] is mock_tool2 + + def test_process_tools_with_multiple_providers(self, mock_run_async, mock_agent_tool): + """Test that process_tools handles multiple ToolProviders.""" + # Create mock tools for first provider + mock_tool1 = mock_agent_tool("provider1_tool") + provider1 = MockToolProvider([mock_tool1]) + + # Create mock tools for second provider + mock_tool2 = mock_agent_tool("provider2_tool") + provider2 = MockToolProvider([mock_tool2]) + + registry = ToolRegistry() + + # Mock run_async to return appropriate tools for each call + mock_run_async.side_effect = [[mock_tool1], [mock_tool2]] + + tool_names = registry.process_tools([provider1, provider2]) + + # Verify run_async was called twice + assert mock_run_async.call_count == 2 + + # Verify all tools were registered + assert "provider1_tool" in tool_names + assert "provider2_tool" in tool_names + assert len(tool_names) == 2 + + # Verify both providers were tracked + assert provider1 in registry._tool_providers + assert provider2 in registry._tool_providers + assert len(registry._tool_providers) == 2 + + def test_process_tools_with_mixed_tools_and_providers(self, mock_run_async, mock_agent_tool): + """Test that process_tools handles mix of regular tools and providers.""" + # Create regular tool + regular_tool = mock_agent_tool("regular_tool") + + # Create provider tool + provider_tool = mock_agent_tool("provider_tool") + provider = MockToolProvider([provider_tool]) + + registry = ToolRegistry() + + mock_run_async.return_value = [provider_tool] + + tool_names = registry.process_tools([regular_tool, provider]) + + # Verify both tools were registered + assert "regular_tool" in tool_names + assert "provider_tool" in tool_names + assert len(tool_names) == 2 + + # Verify only provider was tracked + assert provider in registry._tool_providers + assert len(registry._tool_providers) == 1 + + def test_process_tools_with_empty_provider(self, mock_run_async): + """Test that process_tools handles provider with no tools.""" + provider = MockToolProvider([]) # Empty tools list + + registry = ToolRegistry() + + mock_run_async.return_value = [] + + tool_names = registry.process_tools([provider]) + + # Verify no tools were registered + assert not tool_names + + # Verify provider was still tracked + assert provider in registry._tool_providers + + def test_tool_providers_public_access(self): + """Test that tool_providers can be accessed directly.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + # Verify direct access works + assert len(registry._tool_providers) == 2 + assert provider1 in registry._tool_providers + assert provider2 in registry._tool_providers + + def test_tool_providers_empty_by_default(self): + """Test that tool_providers is empty by default.""" + registry = ToolRegistry() + + assert not registry._tool_providers + assert isinstance(registry._tool_providers, list) + + def test_process_tools_provider_load_exception(self, mock_run_async): + """Test that process_tools handles exceptions from provider.load_tools().""" + provider = MockToolProvider() + + registry = ToolRegistry() + + # Make load_tools raise an exception + mock_run_async.side_effect = Exception("Load tools failed") + + # Should raise the exception from load_tools + with pytest.raises(Exception, match="Load tools failed"): + registry.process_tools([provider]) + + # Provider should still be tracked even if load_tools failed + assert provider in registry._tool_providers + + def test_tool_provider_tracking_persistence(self, mock_run_async, mock_agent_tool): + """Test that tool providers are tracked across multiple process_tools calls.""" + provider1 = MockToolProvider([mock_agent_tool("tool1")]) + provider2 = MockToolProvider([mock_agent_tool("tool2")]) + + registry = ToolRegistry() + + mock_run_async.side_effect = [ + [mock_agent_tool("tool1")], + [mock_agent_tool("tool2")], + ] + + # Process first provider + registry.process_tools([provider1]) + assert len(registry._tool_providers) == 1 + assert provider1 in registry._tool_providers + + # Process second provider + registry.process_tools([provider2]) + assert len(registry._tool_providers) == 2 + assert provider1 in registry._tool_providers + assert provider2 in registry._tool_providers + + def test_process_tools_provider_async_optimization(self, mock_agent_tool): + """Test that load_tools and add_consumer are called in same async context.""" + mock_tool = mock_agent_tool("test_tool") + + class TestProvider(ToolProvider): + def __init__(self): + self.load_tools_called = False + self.add_consumer_called = False + self.add_consumer_id = None + + async def load_tools(self): + self.load_tools_called = True + return [mock_tool] + + def add_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + def remove_consumer(self, consumer_id): + pass + + provider = TestProvider() + registry = ToolRegistry() + + # Process the provider - this should call both methods + tool_names = registry.process_tools([provider]) + + # Verify both methods were called + assert provider.load_tools_called + assert provider.add_consumer_called + assert provider.add_consumer_id == registry._registry_id + + # Verify tool was registered + assert "test_tool" in tool_names + assert provider in registry._tool_providers + + def test_registry_cleanup(self): + """Test that registry cleanup calls remove_consumer on all providers.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + registry.cleanup() + + # Verify both providers had remove_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called + + def test_registry_cleanup_with_provider_consumer_removal(self): + """Test that cleanup removes provider consumers correctly.""" + + class TestProvider(ToolProvider): + def __init__(self): + self.remove_consumer_called = False + self.remove_consumer_id = None + + async def load_tools(self): + return [] + + def add_consumer(self, consumer_id): + pass + + def remove_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + + provider = TestProvider() + registry = ToolRegistry() + registry._tool_providers = [provider] + + # Call cleanup + registry.cleanup() + + # Verify remove_consumer was called with correct ID + assert provider.remove_consumer_called + assert provider.remove_consumer_id == registry._registry_id + + def test_registry_cleanup_raises_exception_on_provider_error(self): + """Test that cleanup raises exception when provider removal fails.""" + provider1 = MockToolProvider(cleanup_error=RuntimeError("Provider cleanup failed")) + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + # Cleanup should raise the exception from first provider but still attempt cleanup of all + with pytest.raises(RuntimeError, match="Provider cleanup failed"): + registry.cleanup() + + # Both providers should have had remove_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called + + def test_registry_cleanup_raises_first_exception_on_multiple_provider_errors(self): + """Test that cleanup raises first exception when multiple providers fail but attempts all.""" + provider1 = MockToolProvider(cleanup_error=RuntimeError("Provider 1 failed")) + provider2 = MockToolProvider(cleanup_error=ValueError("Provider 2 failed")) + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + # Cleanup should raise first exception but still attempt cleanup of all + with pytest.raises(RuntimeError, match="Provider 1 failed"): + registry.cleanup() + + # Both providers should have had remove_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called diff --git a/tests_integ/mcp/test_mcp_tool_provider.py b/tests_integ/mcp/test_mcp_tool_provider.py new file mode 100644 index 000000000..7914bb326 --- /dev/null +++ b/tests_integ/mcp/test_mcp_tool_provider.py @@ -0,0 +1,160 @@ +"""Integration tests for MCPClient ToolProvider functionality with real MCP server.""" + +import logging +import re + +import pytest +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_client import ToolFilters + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger(__name__) + + +def test_mcp_client_tool_provider_filters(): + """Test MCPClient with various filter combinations.""" + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 20 + + filters: ToolFilters = { + "allowed": ["echo", re.compile(r"echo_with_.*"), short_names_only], + "rejected": ["echo_with_delay"], + } + + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="test", + ) + + agent = Agent(tools=[client]) + tool_names = agent.tool_names + + assert "test_echo_with_delay" not in [name for name in tool_names] + assert all(name.startswith("test_") for name in tool_names) + + agent.cleanup() + + +def test_mcp_client_tool_provider_execution(): + """Test that MCPClient works with agent execution.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="filtered", + ) + + agent = Agent(tools=[client]) + + assert "filtered_echo" in agent.tool_names + + tool_result = agent.tool.filtered_echo(to_echo="Hello World") + assert "Hello World" in str(tool_result) + + result = agent("Use the filtered_echo tool to echo whats inside the tags <>Integration Test") + assert "Integration Test" in str(result) + + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].call_count == 1 + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].success_count == 1 + + agent.cleanup() + + +def test_mcp_client_tool_provider_reuse(): + """Test that a single MCPClient can be used across multiple agents.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="shared", + ) + + agent1 = Agent(tools=[client]) + assert "shared_echo" in agent1.tool_names + + result1 = agent1.tool.shared_echo(to_echo="Agent 1") + assert "Agent 1" in str(result1) + + agent2 = Agent(tools=[client]) + assert "shared_echo" in agent2.tool_names + + result2 = agent2.tool.shared_echo(to_echo="Agent 2") + assert "Agent 2" in str(result2) + + assert len(agent1.tool_names) == len(agent2.tool_names) + assert agent1.tool_names == agent2.tool_names + + agent1.cleanup() + + # Agent 1 cleans up - client should still be active for agent 2 + agent1.cleanup() + + # Agent 2 should still be able to use the tool + result2 = agent2.tool.shared_echo(to_echo="Agent 2 Test") + assert "Agent 2 Test" in str(result2) + + agent2.cleanup() + + +def test_mcp_client_multiple_servers(): + """Test MCPClient with multiple MCP servers simultaneously.""" + client1 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo"]}, + prefix="server1", + ) + client2 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo_with_structured_content"]}, + prefix="server2", + ) + + agent = Agent(tools=[client1, client2]) + + assert "server1_echo" in agent.tool_names + assert "server2_echo_with_structured_content" in agent.tool_names + assert len(agent.tool_names) == 2 + + result1 = agent.tool.server1_echo(to_echo="From Server 1") + assert "From Server 1" in str(result1) + + result2 = agent.tool.server2_echo_with_structured_content(to_echo="From Server 2") + assert "From Server 2" in str(result2) + + agent.cleanup() + + +def test_mcp_client_server_startup_failure(): + """Test that MCPClient handles server startup failure gracefully without hanging.""" + from strands.types.exceptions import ToolProviderException + + failing_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="nonexistent_command", args=["--invalid"])), + startup_timeout=2, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[failing_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException) + + +def test_mcp_client_server_connection_timeout(): + """Test that MCPClient times out gracefully when server hangs during startup.""" + from strands.types.exceptions import ToolProviderException + + hanging_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="sleep", args=["10"])), + startup_timeout=1, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[hanging_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException)