diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index eb9bc4dd9..12c1f8376 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -4,7 +4,7 @@ """ from dataclasses import dataclass -from typing import Any, Sequence +from typing import Any, Sequence, cast from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics @@ -46,3 +46,34 @@ def __str__(self) -> str: if isinstance(item, dict) and "text" in item: result += item.get("text", "") + "\n" return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "AgentResult": + """Rehydrate an AgentResult from persisted JSON. + + Args: + data: Dictionary containing the serialized AgentResult data + Returns: + AgentResult instance + Raises: + TypeError: If the data format is invalid@ + """ + if data.get("type") != "agent_result": + raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}") + + message = cast(Message, data.get("message")) + stop_reason = cast(StopReason, data.get("stop_reason")) + + return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) + + def to_dict(self) -> dict[str, Any]: + """Convert this AgentResult to JSON-serializable dictionary. + + Returns: + Dictionary containing serialized AgentResult data + """ + return { + "type": "agent_result", + "message": self.message, + "stop_reason": self.stop_reason, + } diff --git a/src/strands/experimental/hooks/multiagent/__init__.py b/src/strands/experimental/hooks/multiagent/__init__.py new file mode 100644 index 000000000..d059d0da5 --- /dev/null +++ b/src/strands/experimental/hooks/multiagent/__init__.py @@ -0,0 +1,20 @@ +"""Multi-agent hook events and utilities. + +Provides event classes for hooking into multi-agent orchestrator lifecycle. +""" + +from .events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) + +__all__ = [ + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", +] diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py new file mode 100644 index 000000000..9e54296a4 --- /dev/null +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -0,0 +1,93 @@ +"""Multi-agent execution lifecycle events for hook system integration. + +These events are fired by orchestrators (Graph/Swarm) at key points so +hooks can persist, monitor, or debug execution. No intermediate state model +is used—hooks read from the orchestrator directly. +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from ....hooks import BaseHookEvent + +if TYPE_CHECKING: + from ....multiagent.base import MultiAgentBase + + +@dataclass +class MultiAgentInitializedEvent(BaseHookEvent): + """Event triggered when multi-agent orchestrator initialized. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class BeforeNodeCallEvent(BaseHookEvent): + """Event triggered before individual node execution starts. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node about to execute + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterNodeCallEvent(BaseHookEvent): + """Event triggered after individual node execution completes. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node that just completed execution + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BeforeMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered before orchestrator execution starts. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered after orchestrator execution completes. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 0dbd85d81..07e63577d 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -4,6 +4,7 @@ """ import asyncio +import logging import warnings from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -15,6 +16,8 @@ from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage +logger = logging.getLogger(__name__) + class Status(Enum): """Execution status for both graphs and nodes.""" @@ -59,6 +62,54 @@ def get_agent_results(self) -> list[AgentResult]: flattened.extend(nested_node_result.get_agent_results()) return flattened + def to_dict(self) -> dict[str, Any]: + """Convert NodeResult to JSON-serializable dict, ignoring state field.""" + if isinstance(self.result, Exception): + result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)} + elif isinstance(self.result, AgentResult): + result_data = self.result.to_dict() + else: + # MultiAgentResult case + result_data = self.result.to_dict() + + return { + "result": result_data, + "execution_time": self.execution_time, + "status": self.status.value, + "accumulated_usage": self.accumulated_usage, + "accumulated_metrics": self.accumulated_metrics, + "execution_count": self.execution_count, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeResult": + """Rehydrate a NodeResult from persisted JSON.""" + if "result" not in data: + raise TypeError("NodeResult.from_dict: missing 'result'") + raw = data["result"] + + result: Union[AgentResult, "MultiAgentResult", Exception] + if isinstance(raw, dict) and raw.get("type") == "agent_result": + result = AgentResult.from_dict(raw) + elif isinstance(raw, dict) and raw.get("type") == "exception": + result = Exception(str(raw.get("message", "node failed"))) + elif isinstance(raw, dict) and raw.get("type") == "multiagent_result": + result = MultiAgentResult.from_dict(raw) + else: + raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}") + + usage = _parse_usage(data.get("accumulated_usage", {})) + metrics = _parse_metrics(data.get("accumulated_metrics", {})) + + return cls( + result=result, + execution_time=int(data.get("execution_time", 0)), + status=Status(data.get("status", "pending")), + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), + ) + @dataclass class MultiAgentResult: @@ -76,6 +127,38 @@ class MultiAgentResult: execution_count: int = 0 execution_time: int = 0 + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": + """Rehydrate a MultiAgentResult from persisted JSON.""" + if data.get("type") != "multiagent_result": + raise TypeError(f"MultiAgentResult.from_dict: unexpected type {data.get('type')!r}") + + results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()} + usage = _parse_usage(data.get("accumulated_usage", {})) + metrics = _parse_metrics(data.get("accumulated_metrics", {})) + + multiagent_result = cls( + status=Status(data.get("status", Status.PENDING.value)), + results=results, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), + execution_time=int(data.get("execution_time", 0)), + ) + return multiagent_result + + def to_dict(self) -> dict[str, Any]: + """Convert MultiAgentResult to JSON-serializable dict.""" + return { + "type": "multiagent_result", + "status": self.status.value, + "results": {k: v.to_dict() for k, v in self.results.items()}, + "accumulated_usage": self.accumulated_usage, + "accumulated_metrics": self.accumulated_metrics, + "execution_count": self.execution_count, + "execution_time": self.execution_time, + } + class MultiAgentBase(ABC): """Base class for multi-agent helpers. @@ -122,3 +205,34 @@ def execute() -> MultiAgentResult: with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() + + def serialize_state(self) -> dict[str, Any]: + """Return a JSON-serializable snapshot of the orchestrator state.""" + raise NotImplementedError + + def deserialize_state(self, payload: dict[str, Any]) -> None: + """Restore orchestrator state from a session dict.""" + raise NotImplementedError + + +# Private helper function to avoid duplicate code + + +def _parse_usage(usage_data: dict[str, Any]) -> Usage: + """Parse Usage from dict data.""" + usage = Usage( + inputTokens=usage_data.get("inputTokens", 0), + outputTokens=usage_data.get("outputTokens", 0), + totalTokens=usage_data.get("totalTokens", 0), + ) + # Add optional fields if they exist + if "cacheReadInputTokens" in usage_data: + usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"] + if "cacheWriteInputTokens" in usage_data: + usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] + return usage + + +def _parse_metrics(metrics_data: dict[str, Any]) -> Metrics: + """Parse Metrics from dict data.""" + return Metrics(latencyMs=metrics_data.get("latencyMs", 0)) diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py new file mode 100644 index 000000000..727d28a48 --- /dev/null +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -0,0 +1,41 @@ +from typing import Iterator, Literal, Tuple, Type + +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import ( + HookEvent, + HookProvider, + HookRegistry, +) + + +class MockMultiAgentHookProvider(HookProvider): + def __init__(self, event_types: list[Type] | Literal["all"]): + if event_types == "all": + event_types = [ + MultiAgentInitializedEvent, + BeforeNodeCallEvent, + AfterNodeCallEvent, + AfterMultiAgentInvocationEvent, + ] + + self.events_received = [] + self.events_types = event_types + + @property + def event_types_received(self): + return [type(event) for event in self.events_received] + + def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + return len(self.events_received), iter(self.events_received) + + def register_hooks(self, registry: HookRegistry) -> None: + for event_type in self.events_types: + registry.add_callback(event_type, self.add_event) + + def add_event(self, event: HookEvent) -> None: + self.events_received.append(event) diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 409b08a2d..67a7f2458 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -95,3 +95,48 @@ def test__str__non_dict_content(mock_metrics): message_string = str(result) assert message_string == "Valid text\nMore valid text\n" + + +def test_to_dict(mock_metrics, simple_message: Message): + """Test that to_dict serializes AgentResult correctly.""" + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={"key": "value"}) + + data = result.to_dict() + + assert data == { + "type": "agent_result", + "message": simple_message, + "stop_reason": "end_turn", + } + + +def test_from_dict(): + """Test that from_dict works with valid data.""" + data = { + "type": "agent_result", + "message": {"role": "assistant", "content": [{"text": "Test response"}]}, + "stop_reason": "end_turn", + } + + result = AgentResult.from_dict(data) + + assert result.message == data["message"] + assert result.stop_reason == data["stop_reason"] + assert isinstance(result.metrics, EventLoopMetrics) + assert result.state == {} + + +def test_roundtrip_serialization(mock_metrics, complex_message: Message): + """Test that to_dict() and from_dict() work together correctly.""" + original = AgentResult( + stop_reason="max_tokens", message=complex_message, metrics=mock_metrics, state={"test": "data"} + ) + + # Serialize and deserialize + data = original.to_dict() + restored = AgentResult.from_dict(data) + + assert restored.message == original.message + assert restored.stop_reason == original.stop_reason + assert isinstance(restored.metrics, EventLoopMetrics) + assert restored.state == {} # State is not serialized diff --git a/tests/strands/experimental/hooks/multiagent/__init__.py b/tests/strands/experimental/hooks/multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/hooks/multiagent/test_events.py b/tests/strands/experimental/hooks/multiagent/test_events.py new file mode 100644 index 000000000..6c4d7c4e7 --- /dev/null +++ b/tests/strands/experimental/hooks/multiagent/test_events.py @@ -0,0 +1,107 @@ +"""Tests for multi-agent execution lifecycle events.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import BaseHookEvent + + +@pytest.fixture +def orchestrator(): + """Mock orchestrator for testing.""" + return Mock() + + +def test_multi_agent_initialization_event_with_orchestrator_only(orchestrator): + """Test MultiAgentInitializedEvent creation with orchestrator only.""" + event = MultiAgentInitializedEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_multi_agent_initialization_event_with_invocation_state(orchestrator): + """Test MultiAgentInitializedEvent creation with invocation state.""" + invocation_state = {"key": "value"} + event = MultiAgentInitializedEvent(source=orchestrator, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.invocation_state == invocation_state + + +def test_after_node_invocation_event_with_required_fields(orchestrator): + """Test AfterNodeCallEvent creation with required fields.""" + node_id = "node_1" + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_node_invocation_event_with_invocation_state(orchestrator): + """Test AfterNodeCallEvent creation with invocation state.""" + node_id = "node_2" + invocation_state = {"result": "success"} + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state == invocation_state + + +def test_after_multi_agent_invocation_event_with_orchestrator_only(orchestrator): + """Test AfterMultiAgentInvocationEvent creation with orchestrator only.""" + event = AfterMultiAgentInvocationEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_multi_agent_invocation_event_with_invocation_state(orchestrator): + """Test AfterMultiAgentInvocationEvent creation with invocation state.""" + invocation_state = {"final_state": "completed"} + event = AfterMultiAgentInvocationEvent(source=orchestrator, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.invocation_state == invocation_state + + +def test_before_node_call_event(orchestrator): + """Test BeforeNodeCallEvent creation.""" + node_id = "node_1" + event = BeforeNodeCallEvent(source=orchestrator, node_id=node_id) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_before_multi_agent_invocation_event(orchestrator): + """Test BeforeMultiAgentInvocationEvent creation.""" + event = BeforeMultiAgentInvocationEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_events_should_reverse_callbacks(orchestrator): + """Test that After events have should_reverse_callbacks property set to True.""" + after_node_event = AfterNodeCallEvent(source=orchestrator, node_id="test") + after_invocation_event = AfterMultiAgentInvocationEvent(source=orchestrator) + + assert after_node_event.should_reverse_callbacks is True + assert after_invocation_event.should_reverse_callbacks is True diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index ab55b2c84..4e8a5dd06 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -28,6 +28,9 @@ def test_node_result_initialization_and_properties(agent_result): assert node_result.accumulated_metrics == {"latencyMs": 0.0} assert node_result.execution_count == 0 + default_node = NodeResult(result=agent_result) + assert default_node.status == Status.PENDING + # With custom metrics custom_usage = {"inputTokens": 100, "outputTokens": 200, "totalTokens": 300} custom_metrics = {"latencyMs": 250.0} @@ -95,6 +98,7 @@ def test_multi_agent_result_initialization(agent_result): assert result.accumulated_metrics == {"latencyMs": 0.0} assert result.execution_count == 0 assert result.execution_time == 0 + assert result.status == Status.PENDING # Custom values`` node_result = NodeResult(result=agent_result) @@ -141,6 +145,12 @@ class CompleteMultiAgent(MultiAgentBase): async def invoke_async(self, task: str) -> MultiAgentResult: return MultiAgentResult(results={}) + def serialize_state(self) -> dict: + return {} + + def deserialize_state(self, payload: dict) -> None: + pass + # Should not raise an exception - __call__ is provided by base class agent = CompleteMultiAgent() assert isinstance(agent, MultiAgentBase) @@ -164,6 +174,12 @@ async def invoke_async(self, task, invocation_state, **kwargs): status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)} ) + def serialize_state(self) -> dict: + return {} + + def deserialize_state(self, payload: dict) -> None: + pass + agent = TestMultiAgent() # Test with string task @@ -174,3 +190,52 @@ async def invoke_async(self, task, invocation_state, **kwargs): assert agent.received_invocation_state == {"param1": "value1", "param2": "value2", "value3": "value4"} assert isinstance(result, MultiAgentResult) assert result.status == Status.COMPLETED + + +def test_node_result_to_dict(agent_result): + """Test NodeResult to_dict method.""" + node_result = NodeResult(result=agent_result, execution_time=100, status=Status.COMPLETED) + result_dict = node_result.to_dict() + + assert result_dict["execution_time"] == 100 + assert result_dict["status"] == "completed" + assert result_dict["result"]["type"] == "agent_result" + assert result_dict["result"]["stop_reason"] == agent_result.stop_reason + assert result_dict["result"]["message"] == agent_result.message + + exception_result = NodeResult(result=Exception("Test error"), status=Status.FAILED) + result_dict = exception_result.to_dict() + + assert result_dict["result"]["type"] == "exception" + assert result_dict["result"]["message"] == "Test error" + assert result_dict["status"] == "failed" + + +def test_multi_agent_result_to_dict(agent_result): + """Test MultiAgentResult to_dict method.""" + node_result = NodeResult(result=agent_result) + multi_result = MultiAgentResult(status=Status.COMPLETED, results={"test_node": node_result}, execution_time=200) + + result_dict = multi_result.to_dict() + + assert result_dict["status"] == "completed" + assert result_dict["execution_time"] == 200 + assert "test_node" in result_dict["results"] + assert result_dict["results"]["test_node"]["result"]["type"] == "agent_result" + + +def test_serialize_node_result_for_persist(agent_result): + """Test serialize_node_result_for_persist method.""" + + node_result = NodeResult(result=agent_result) + serialized = node_result.to_dict() + + assert "result" in serialized + assert "execution_time" in serialized + assert "status" in serialized + + exception_node_result = NodeResult(result=Exception("Test error"), status=Status.FAILED) + serialized_exception = exception_node_result.to_dict() + assert "result" in serialized_exception + assert serialized_exception["result"]["type"] == "exception" + assert serialized_exception["result"]["message"] == "Test error"