Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/strands/agent/agent_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from dataclasses import dataclass
from typing import Any, Sequence
from typing import Any, Sequence, cast

from ..interrupt import Interrupt
from ..telemetry.metrics import EventLoopMetrics
Expand Down Expand Up @@ -46,3 +46,34 @@ def __str__(self) -> str:
if isinstance(item, dict) and "text" in item:
result += item.get("text", "") + "\n"
return result

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "AgentResult":
"""Rehydrate an AgentResult from persisted JSON.

Args:
data: Dictionary containing the serialized AgentResult data
Returns:
AgentResult instance
Raises:
TypeError: If the data format is invalid@
"""
if data.get("type") != "agent_result":
raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}")

message = cast(Message, data.get("message"))
stop_reason = cast(StopReason, data.get("stop_reason"))

return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={})

def to_dict(self) -> dict[str, Any]:
"""Convert this AgentResult to JSON-serializable dictionary.

Returns:
Dictionary containing serialized AgentResult data
"""
return {
"type": "agent_result",
"message": self.message,
"stop_reason": self.stop_reason,
}
20 changes: 20 additions & 0 deletions src/strands/experimental/hooks/multiagent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Multi-agent hook events and utilities.

Provides event classes for hooking into multi-agent orchestrator lifecycle.
"""

from .events import (
AfterMultiAgentInvocationEvent,
AfterNodeCallEvent,
BeforeMultiAgentInvocationEvent,
BeforeNodeCallEvent,
MultiAgentInitializedEvent,
)

__all__ = [
"AfterMultiAgentInvocationEvent",
"AfterNodeCallEvent",
"BeforeMultiAgentInvocationEvent",
"BeforeNodeCallEvent",
"MultiAgentInitializedEvent",
]
93 changes: 93 additions & 0 deletions src/strands/experimental/hooks/multiagent/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Multi-agent execution lifecycle events for hook system integration.

These events are fired by orchestrators (Graph/Swarm) at key points so
hooks can persist, monitor, or debug execution. No intermediate state model
is used—hooks read from the orchestrator directly.
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from ....hooks import BaseHookEvent

if TYPE_CHECKING:
from ....multiagent.base import MultiAgentBase


@dataclass
class MultiAgentInitializedEvent(BaseHookEvent):
"""Event triggered when multi-agent orchestrator initialized.

Attributes:
source: The multi-agent orchestrator instance
invocation_state: Configuration that user passes in
"""

source: "MultiAgentBase"
invocation_state: dict[str, Any] | None = None


@dataclass
class BeforeNodeCallEvent(BaseHookEvent):
"""Event triggered before individual node execution starts.

Attributes:
source: The multi-agent orchestrator instance
node_id: ID of the node about to execute
invocation_state: Configuration that user passes in
"""

source: "MultiAgentBase"
node_id: str
invocation_state: dict[str, Any] | None = None


@dataclass
class AfterNodeCallEvent(BaseHookEvent):
"""Event triggered after individual node execution completes.

Attributes:
source: The multi-agent orchestrator instance
node_id: ID of the node that just completed execution
invocation_state: Configuration that user passes in
"""

source: "MultiAgentBase"
node_id: str
invocation_state: dict[str, Any] | None = None

@property
def should_reverse_callbacks(self) -> bool:
"""True to invoke callbacks in reverse order."""
return True


@dataclass
class BeforeMultiAgentInvocationEvent(BaseHookEvent):
"""Event triggered before orchestrator execution starts.

Attributes:
source: The multi-agent orchestrator instance
invocation_state: Configuration that user passes in
"""

source: "MultiAgentBase"
invocation_state: dict[str, Any] | None = None


@dataclass
class AfterMultiAgentInvocationEvent(BaseHookEvent):
"""Event triggered after orchestrator execution completes.

Attributes:
source: The multi-agent orchestrator instance
invocation_state: Configuration that user passes in
"""

source: "MultiAgentBase"
invocation_state: dict[str, Any] | None = None

@property
def should_reverse_callbacks(self) -> bool:
"""True to invoke callbacks in reverse order."""
return True
114 changes: 114 additions & 0 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import asyncio
import logging
import warnings
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -15,6 +16,8 @@
from ..types.content import ContentBlock
from ..types.event_loop import Metrics, Usage

logger = logging.getLogger(__name__)


class Status(Enum):
"""Execution status for both graphs and nodes."""
Expand Down Expand Up @@ -59,6 +62,54 @@ def get_agent_results(self) -> list[AgentResult]:
flattened.extend(nested_node_result.get_agent_results())
return flattened

def to_dict(self) -> dict[str, Any]:
"""Convert NodeResult to JSON-serializable dict, ignoring state field."""
if isinstance(self.result, Exception):
result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)}
elif isinstance(self.result, AgentResult):
result_data = self.result.to_dict()
else:
# MultiAgentResult case
result_data = self.result.to_dict()

return {
"result": result_data,
"execution_time": self.execution_time,
"status": self.status.value,
"accumulated_usage": self.accumulated_usage,
"accumulated_metrics": self.accumulated_metrics,
"execution_count": self.execution_count,
}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
"""Rehydrate a NodeResult from persisted JSON."""
if "result" not in data:
raise TypeError("NodeResult.from_dict: missing 'result'")
raw = data["result"]

result: Union[AgentResult, "MultiAgentResult", Exception]
if isinstance(raw, dict) and raw.get("type") == "agent_result":
result = AgentResult.from_dict(raw)
elif isinstance(raw, dict) and raw.get("type") == "exception":
result = Exception(str(raw.get("message", "node failed")))
elif isinstance(raw, dict) and raw.get("type") == "multiagent_result":
result = MultiAgentResult.from_dict(raw)
else:
raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}")

usage = _parse_usage(data.get("accumulated_usage", {}))
metrics = _parse_metrics(data.get("accumulated_metrics", {}))

return cls(
result=result,
execution_time=int(data.get("execution_time", 0)),
status=Status(data.get("status", "pending")),
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=int(data.get("execution_count", 0)),
)


@dataclass
class MultiAgentResult:
Expand All @@ -76,6 +127,38 @@ class MultiAgentResult:
execution_count: int = 0
execution_time: int = 0

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
"""Rehydrate a MultiAgentResult from persisted JSON."""
if data.get("type") != "multiagent_result":
raise TypeError(f"MultiAgentResult.from_dict: unexpected type {data.get('type')!r}")

results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()}
usage = _parse_usage(data.get("accumulated_usage", {}))
metrics = _parse_metrics(data.get("accumulated_metrics", {}))

multiagent_result = cls(
status=Status(data.get("status", Status.PENDING.value)),
results=results,
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=int(data.get("execution_count", 0)),
execution_time=int(data.get("execution_time", 0)),
)
return multiagent_result

def to_dict(self) -> dict[str, Any]:
"""Convert MultiAgentResult to JSON-serializable dict."""
return {
"type": "multiagent_result",
"status": self.status.value,
"results": {k: v.to_dict() for k, v in self.results.items()},
"accumulated_usage": self.accumulated_usage,
"accumulated_metrics": self.accumulated_metrics,
"execution_count": self.execution_count,
"execution_time": self.execution_time,
}


class MultiAgentBase(ABC):
"""Base class for multi-agent helpers.
Expand Down Expand Up @@ -122,3 +205,34 @@ def execute() -> MultiAgentResult:
with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()

def serialize_state(self) -> dict[str, Any]:
"""Return a JSON-serializable snapshot of the orchestrator state."""
raise NotImplementedError

def deserialize_state(self, payload: dict[str, Any]) -> None:
"""Restore orchestrator state from a session dict."""
raise NotImplementedError


# Private helper function to avoid duplicate code


def _parse_usage(usage_data: dict[str, Any]) -> Usage:
"""Parse Usage from dict data."""
usage = Usage(
inputTokens=usage_data.get("inputTokens", 0),
outputTokens=usage_data.get("outputTokens", 0),
totalTokens=usage_data.get("totalTokens", 0),
)
# Add optional fields if they exist
if "cacheReadInputTokens" in usage_data:
usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"]
if "cacheWriteInputTokens" in usage_data:
usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"]
return usage


def _parse_metrics(metrics_data: dict[str, Any]) -> Metrics:
"""Parse Metrics from dict data."""
return Metrics(latencyMs=metrics_data.get("latencyMs", 0))
41 changes: 41 additions & 0 deletions tests/fixtures/mock_multiagent_hook_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Iterator, Literal, Tuple, Type

from strands.experimental.hooks.multiagent.events import (
AfterMultiAgentInvocationEvent,
AfterNodeCallEvent,
BeforeNodeCallEvent,
MultiAgentInitializedEvent,
)
from strands.hooks import (
HookEvent,
HookProvider,
HookRegistry,
)


class MockMultiAgentHookProvider(HookProvider):
def __init__(self, event_types: list[Type] | Literal["all"]):
if event_types == "all":
event_types = [
MultiAgentInitializedEvent,
BeforeNodeCallEvent,
AfterNodeCallEvent,
AfterMultiAgentInvocationEvent,
]

self.events_received = []
self.events_types = event_types

@property
def event_types_received(self):
return [type(event) for event in self.events_received]

def get_events(self) -> Tuple[int, Iterator[HookEvent]]:
return len(self.events_received), iter(self.events_received)

def register_hooks(self, registry: HookRegistry) -> None:
for event_type in self.events_types:
registry.add_callback(event_type, self.add_event)

def add_event(self, event: HookEvent) -> None:
self.events_received.append(event)
45 changes: 45 additions & 0 deletions tests/strands/agent/test_agent_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,48 @@ def test__str__non_dict_content(mock_metrics):

message_string = str(result)
assert message_string == "Valid text\nMore valid text\n"


def test_to_dict(mock_metrics, simple_message: Message):
"""Test that to_dict serializes AgentResult correctly."""
result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={"key": "value"})

data = result.to_dict()

assert data == {
"type": "agent_result",
"message": simple_message,
"stop_reason": "end_turn",
}


def test_from_dict():
"""Test that from_dict works with valid data."""
data = {
"type": "agent_result",
"message": {"role": "assistant", "content": [{"text": "Test response"}]},
"stop_reason": "end_turn",
}

result = AgentResult.from_dict(data)

assert result.message == data["message"]
assert result.stop_reason == data["stop_reason"]
assert isinstance(result.metrics, EventLoopMetrics)
assert result.state == {}


def test_roundtrip_serialization(mock_metrics, complex_message: Message):
"""Test that to_dict() and from_dict() work together correctly."""
original = AgentResult(
stop_reason="max_tokens", message=complex_message, metrics=mock_metrics, state={"test": "data"}
)

# Serialize and deserialize
data = original.to_dict()
restored = AgentResult.from_dict(data)

assert restored.message == original.message
assert restored.stop_reason == original.stop_reason
assert isinstance(restored.metrics, EventLoopMetrics)
assert restored.state == {} # State is not serialized
Empty file.
Loading
Loading