generated from amazon-archives/__template_Apache-2.0
Enable multi agent session persistence #900
Open
JackYPCOnline wants to merge 29 commits into strands-agents:main from JackYPCOnline:multi-agent-session
+1,587
−149
Changes from 24 commits
29 commits
c2d49f6 feat: multiagent session interface (JackYPCOnline)
49253bc feat: enable multiagent session persistence (JackYPCOnline)
41d34a9 fix: add write file fallback for window permission handling (JackYPCOnline)
2e854f8 fix: remove persistence_hooks, use session_manager to subscribe multi… (JackYPCOnline)
0380567 Update src/strands/multiagent/base.py (JackYPCOnline)
93909d0 Update src/strands/multiagent/base.py (JackYPCOnline)
023cfba Update src/strands/session/session_manager.py (JackYPCOnline)
d064038 Update src/strands/session/session_manager.py (JackYPCOnline)
b00f58d feat: add restricted type check to serialization/deserialization func… (JackYPCOnline)
d67c848 fix: remove persistence_hooks, use session_manager to subscribe multi… (JackYPCOnline)
e6b4af2 fix: remove optional from invoke_callbacks (JackYPCOnline)
1fcfdc0 fix: fix from_dict consistency (JackYPCOnline)
3fe3978 fix: fix from_dic consistency (JackYPCOnline)
683a14f fix: fix file session creation issue (JackYPCOnline)
7411f2c fix: remove completed_nodes, rename execution_order to node_history i… (JackYPCOnline)
e9c2d57 fix: address comments, adding more tests and integration tests in ne… (JackYPCOnline)
8233c62 feat: add more unit tests and integration tests to validate Graph/Swa… (JackYPCOnline)
1e9851a Merge branch 'main' into multi-agent-session (JackYPCOnline)
f1aac16 fix: fix bad rebase (JackYPCOnline)
7735ed3 fix: revert single agent session_manager validator (JackYPCOnline)
7b3aabb fix: refine code structures, address related comments (JackYPCOnline)
58424ba Merge branch 'main' into multi-agent-session (JackYPCOnline)
8ed2e21 feat: add BeforeNodeCallEvent to swarm & graph (JackYPCOnline)
d3adef3 fix: fix bad rebase (JackYPCOnline)
80c8169 fix: address comments, move from_dict() to AgentResult, fix docstring… (JackYPCOnline)
2ff4035 fix: fix typo and pattern (JackYPCOnline)
191f5e0 fix: rename multiagent dictory (JackYPCOnline)
a1e10ed fix: address PR comments (JackYPCOnline)
734c59b fix: address comment (JackYPCOnline)
21 changes: 21 additions & 0 deletions
src/strands/experimental/hooks/multiagent_hooks/__init__.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Multi-agent session management for persistent execution. | ||
|
||
This package provides session persistence capabilities for multi-agent orchestrators, | ||
enabling resumable execution after interruptions or failures. | ||
""" | ||
|
||
from .multiagent_events import ( | ||
AfterMultiAgentInvocationEvent, | ||
AfterNodeCallEvent, | ||
BeforeMultiAgentInvocationEvent, | ||
BeforeNodeCallEvent, | ||
MultiAgentInitializedEvent, | ||
) | ||
|
||
__all__ = [ | ||
"AfterMultiAgentInvocationEvent", | ||
"AfterNodeCallEvent", | ||
"BeforeMultiAgentInvocationEvent", | ||
"BeforeNodeCallEvent", | ||
"MultiAgentInitializedEvent", | ||
] |
88 changes: 88 additions & 0 deletions
src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""Multi-agent execution lifecycle events for hook system integration. | ||
These events are fired by orchestrators (Graph/Swarm) at key points so | ||
hooks can persist, monitor, or debug execution. No intermediate state model | ||
is used—hooks read from the orchestrator directly. | ||
""" | ||
|
||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Any | ||
|
||
from ....hooks import BaseHookEvent | ||
|
||
if TYPE_CHECKING: | ||
from ....multiagent.base import MultiAgentBase | ||
|
||
|
||
@dataclass | ||
class MultiAgentInitializedEvent(BaseHookEvent): | ||
"""Event triggered when the multi-agent orchestrator is initialized. | ||
Attributes: | ||
source: The multi-agent orchestrator instance | ||
invocation_state: Configuration passed in by the user | ||
""" | ||
|
||
source: "MultiAgentBase" | ||
invocation_state: dict[str, Any] | None = None | ||
|
||
|
||
@dataclass | ||
class BeforeNodeCallEvent(BaseHookEvent): | ||
"""Event triggered before individual node execution starts. This event is the counterpart of AfterNodeCallEvent. | ||
Attributes: | ||
source: The multi-agent orchestrator instance | ||
node_id: ID of the node about to be executed | ||
invocation_state: Configuration passed in by the user | ||
|
||
""" | ||
|
||
source: "MultiAgentBase" | ||
node_id: str | ||
invocation_state: dict[str, Any] | None = None | ||
|
||
|
||
@dataclass | ||
class AfterNodeCallEvent(BaseHookEvent): | ||
"""Event triggered after individual node execution completes. | ||
Attributes: | ||
source: The multi-agent orchestrator instance | ||
node_id: ID of the node that just completed execution | ||
invocation_state: Configuration passed in by the user | ||
|
||
""" | ||
|
||
source: "MultiAgentBase" | ||
node_id: str | ||
invocation_state: dict[str, Any] | None = None | ||
|
||
@property | ||
def should_reverse_callbacks(self) -> bool: | ||
"""True to invoke callbacks in reverse order.""" | ||
return True | ||
|
||
|
||
@dataclass | ||
class BeforeMultiAgentInvocationEvent(BaseHookEvent): | ||
"""Event triggered before orchestrator execution starts. This event is the counterpart of AfterMultiAgentInvocationEvent. | ||
|
||
Attributes: | ||
source: The multi-agent orchestrator instance | ||
invocation_state: Configuration passed in by the user | ||
|
||
""" | ||
|
||
source: "MultiAgentBase" | ||
invocation_state: dict[str, Any] | None = None | ||
|
||
|
||
@dataclass | ||
class AfterMultiAgentInvocationEvent(BaseHookEvent): | ||
"""Event triggered after orchestrator execution completes. | ||
Attributes: | ||
source: The multi-agent orchestrator instance | ||
invocation_state: Configuration passed in by the user | ||
|
||
""" | ||
|
||
source: "MultiAgentBase" | ||
invocation_state: dict[str, Any] | None = None |
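The `should_reverse_callbacks` property on `AfterNodeCallEvent` suggests that "after" callbacks are meant to unwind in the opposite order from "before" callbacks. The following is a minimal, self-contained sketch of that ordering. `BaseEvent`, `BeforeNodeCall`, `AfterNodeCall`, and `TinyRegistry` are hypothetical stand-ins written for illustration; they are not the strands hook classes, and the real registry may behave differently.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class BaseEvent:
    """Stand-in for BaseHookEvent (hypothetical, not the strands class)."""

    source: Any

    @property
    def should_reverse_callbacks(self) -> bool:
        return False


@dataclass
class BeforeNodeCall(BaseEvent):
    node_id: str = ""


@dataclass
class AfterNodeCall(BaseEvent):
    node_id: str = ""

    @property
    def should_reverse_callbacks(self) -> bool:
        # "After" events ask the registry to unwind in reverse order.
        return True


class TinyRegistry:
    """Hypothetical registry that honors should_reverse_callbacks."""

    def __init__(self) -> None:
        self._callbacks: dict[type, list[Callable[[Any], None]]] = {}

    def add_callback(self, event_type: type, callback: Callable[[Any], None]) -> None:
        self._callbacks.setdefault(event_type, []).append(callback)

    def invoke_callbacks(self, event: BaseEvent) -> None:
        callbacks = self._callbacks.get(type(event), [])
        ordered = reversed(callbacks) if event.should_reverse_callbacks else callbacks
        for callback in ordered:
            callback(event)


def run_demo() -> list[str]:
    log: list[str] = []
    registry = TinyRegistry()
    for name in ("outer", "inner"):
        registry.add_callback(BeforeNodeCall, lambda e, n=name: log.append(f"{n}:before:{e.node_id}"))
        registry.add_callback(AfterNodeCall, lambda e, n=name: log.append(f"{n}:after:{e.node_id}"))
    registry.invoke_callbacks(BeforeNodeCall(source=None, node_id="research"))
    registry.invoke_callbacks(AfterNodeCall(source=None, node_id="research"))
    return log
```

In this sketch the hook registered first sees the "before" event first and the "after" event last, so paired hooks nest like context managers.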
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,16 +4,20 @@ | |
""" | ||
|
||
import asyncio | ||
import logging | ||
import warnings | ||
from abc import ABC, abstractmethod | ||
from concurrent.futures import ThreadPoolExecutor | ||
from dataclasses import dataclass, field | ||
from enum import Enum | ||
- from typing import Any, Union | ||
from typing import Any, Union, cast | ||
|
||
from ..agent import AgentResult | ||
- from ..types.content import ContentBlock | ||
- from ..types.event_loop import Metrics, Usage | ||
from ..telemetry.metrics import EventLoopMetrics | ||
from ..types.content import ContentBlock, Message | ||
from ..types.event_loop import Metrics, StopReason, Usage | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Status(Enum): | ||
|
@@ -59,6 +63,93 @@ def get_agent_results(self) -> list[AgentResult]: | |
flattened.extend(nested_node_result.get_agent_results()) | ||
return flattened | ||
|
||
def to_dict(self) -> dict[str, Any]: | ||
|
||
"""Convert NodeResult to JSON-serializable dict, ignoring state field.""" | ||
if isinstance(self.result, Exception): | ||
result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)} | ||
elif isinstance(self.result, AgentResult): | ||
# Serialize AgentResult without state field | ||
|
||
result_data = { | ||
"type": "agent_result", | ||
"stop_reason": self.result.stop_reason, | ||
"message": self.result.message, | ||
} | ||
elif isinstance(self.result, MultiAgentResult): | ||
result_data = self.result.to_dict() | ||
|
||
else: | ||
raise TypeError(f"Unsupported NodeResult.result type for serialization: {type(self.result).__name__}") | ||
|
||
return { | ||
"result": result_data, | ||
"execution_time": self.execution_time, | ||
"status": self.status.value, | ||
"accumulated_usage": self.accumulated_usage, | ||
"accumulated_metrics": self.accumulated_metrics, | ||
"execution_count": self.execution_count, | ||
} | ||
|
||
@classmethod | ||
def from_dict(cls, data: dict[str, Any]) -> "NodeResult": | ||
"""Rehydrate a NodeResult from persisted JSON.""" | ||
if "result" not in data: | ||
raise TypeError("NodeResult.from_dict: missing 'result'") | ||
raw = data["result"] | ||
|
||
result: Union[AgentResult, "MultiAgentResult", Exception] | ||
if isinstance(raw, dict) and raw.get("type") == "agent_result": | ||
result = NodeResult.agent_result_from_persisted(raw) | ||
|
||
elif isinstance(raw, dict) and raw.get("type") == "exception": | ||
result = Exception(str(raw.get("message", "node failed"))) | ||
elif isinstance(raw, dict) and ("results" in raw): | ||
result = MultiAgentResult.from_dict(raw) | ||
else: | ||
raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}") | ||
|
||
usage_data = data.get("accumulated_usage", {}) | ||
usage = Usage( | ||
inputTokens=usage_data.get("inputTokens", 0), | ||
outputTokens=usage_data.get("outputTokens", 0), | ||
totalTokens=usage_data.get("totalTokens", 0), | ||
) | ||
# Add optional fields if they exist | ||
if "cacheReadInputTokens" in usage_data: | ||
usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"] | ||
if "cacheWriteInputTokens" in usage_data: | ||
usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] | ||
|
||
metrics = Metrics(latencyMs=data.get("accumulated_metrics", {}).get("latencyMs", 0)) | ||
|
||
return cls( | ||
result=result, | ||
execution_time=int(data.get("execution_time", 0)), | ||
status=Status(data.get("status", "pending")), | ||
|
||
accumulated_usage=usage, | ||
accumulated_metrics=metrics, | ||
execution_count=int(data.get("execution_count", 0)), | ||
|
||
) | ||
|
||
@classmethod | ||
def agent_result_from_persisted(cls, data: dict[str, Any]) -> AgentResult: | ||
"""Rehydrate a minimal AgentResult from persisted JSON. | ||
|
||
Expected shape: | ||
{"type": "agent_result", "message": <Message>, "stop_reason": <str|None>} | ||
""" | ||
if data.get("type") != "agent_result": | ||
raise TypeError(f"agent_result_from_persisted: unexpected type {data.get('type')!r}") | ||
|
||
message = cast(Message, data.get("message")) | ||
stop_reason = cast( | ||
StopReason, | ||
data.get("stop_reason"), | ||
) | ||
|
||
try: | ||
return AgentResult(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) | ||
except Exception: | ||
logger.debug("AgentResult constructor failed during rehydration") | ||
raise | ||
|
||
|
||
@dataclass | ||
class MultiAgentResult: | ||
|
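The `to_dict` / `from_dict` pair above tags each payload with a `"type"` key, rebuilds usage counters with zero defaults, and copies the optional cache fields only when they were persisted. A simplified sketch of that round trip follows. `TinyNodeResult` and its `"text"` payload kind are hypothetical stand-ins used so the example runs without strands; only the exception branch and the usage handling mirror the diff.

```python
import json
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union


class Status(Enum):
    PENDING = "pending"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class TinyNodeResult:
    """Hypothetical stand-in for NodeResult; only two payload kinds are modeled."""

    result: Union[str, Exception]
    status: Status = Status.PENDING
    accumulated_usage: dict[str, int] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        if isinstance(self.result, Exception):
            payload = {"type": "exception", "message": str(self.result)}
        elif isinstance(self.result, str):
            payload = {"type": "text", "message": self.result}  # illustrative kind
        else:
            raise TypeError(f"Unsupported result type: {type(self.result).__name__}")
        return {
            "result": payload,
            "status": self.status.value,
            "accumulated_usage": dict(self.accumulated_usage),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TinyNodeResult":
        if "result" not in data:
            raise TypeError("from_dict: missing 'result'")
        raw = data["result"]
        kind = raw.get("type") if isinstance(raw, dict) else None
        result: Union[str, Exception]
        if kind == "exception":
            # Only the message survives; the original exception class is not stored.
            result = Exception(str(raw.get("message", "node failed")))
        elif kind == "text":
            result = str(raw.get("message", ""))
        else:
            raise TypeError(f"from_dict: unsupported result payload: {raw!r}")

        usage_data = data.get("accumulated_usage", {})
        usage = {key: usage_data.get(key, 0) for key in ("inputTokens", "outputTokens", "totalTokens")}
        for optional_key in ("cacheReadInputTokens", "cacheWriteInputTokens"):
            if optional_key in usage_data:
                usage[optional_key] = usage_data[optional_key]

        return cls(result=result, status=Status(data.get("status", "pending")), accumulated_usage=usage)


failed = TinyNodeResult(
    result=ValueError("boom"),
    status=Status.FAILED,
    accumulated_usage={"inputTokens": 3, "outputTokens": 4, "totalTokens": 7, "cacheReadInputTokens": 1},
)
restored = TinyNodeResult.from_dict(json.loads(json.dumps(failed.to_dict())))
```

One consequence worth noting: because only `str(exception)` is persisted, a `ValueError` comes back as a plain `Exception` with the same message.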
@@ -76,6 +167,46 @@ class MultiAgentResult: | |
execution_count: int = 0 | ||
execution_time: int = 0 | ||
|
||
def to_dict(self) -> dict[str, Any]: | ||
"""Convert MultiAgentResult to JSON-serializable dict.""" | ||
return { | ||
"type": "multiagent_result", | ||
"status": self.status.value, | ||
"results": {k: v.to_dict() for k, v in self.results.items()}, | ||
"accumulated_usage": dict(self.accumulated_usage), | ||
"accumulated_metrics": dict(self.accumulated_metrics), | ||
"execution_count": self.execution_count, | ||
"execution_time": self.execution_time, | ||
} | ||
|
||
@classmethod | ||
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": | ||
"""Rehydrate a MultiAgentResult from persisted JSON.""" | ||
results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()} | ||
usage_data = data.get("accumulated_usage", {}) | ||
usage = Usage( | ||
inputTokens=usage_data.get("inputTokens", 0), | ||
outputTokens=usage_data.get("outputTokens", 0), | ||
totalTokens=usage_data.get("totalTokens", 0), | ||
) | ||
# Add optional fields if they exist | ||
if "cacheReadInputTokens" in usage_data: | ||
usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"] | ||
if "cacheWriteInputTokens" in usage_data: | ||
usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] | ||
|
||
metrics = Metrics(latencyMs=data.get("accumulated_metrics", {}).get("latencyMs", 0)) | ||
|
||
multiagent_result = cls( | ||
status=Status(data.get("status", Status.PENDING.value)), | ||
results=results, | ||
accumulated_usage=usage, | ||
accumulated_metrics=metrics, | ||
execution_count=int(data.get("execution_count", 0)), | ||
execution_time=int(data.get("execution_time", 0)), | ||
) | ||
return multiagent_result | ||
|
||
|
||
class MultiAgentBase(ABC): | ||
"""Base class for multi-agent helpers. | ||
|
@@ -122,3 +253,22 @@ def execute() -> MultiAgentResult: | |
with ThreadPoolExecutor() as executor: | ||
future = executor.submit(execute) | ||
return future.result() | ||
|
||
def serialize_state(self) -> dict[str, Any]: | ||
"""Return a JSON-serializable snapshot of the orchestrator state.""" | ||
raise NotImplementedError | ||
|
||
def deserialize_state(self, payload: dict[str, Any]) -> None: | ||
I still don't see deserialize as being a mutative action. If this is actually doing
why not call restore_from_state or session |
||
"""Restore orchestrator state from a session dict.""" | ||
raise NotImplementedError | ||
|
||
def serialize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]: | ||
|
||
"""Serialize node result for persistence. | ||
|
||
Args: | ||
raw: Raw node result to serialize | ||
|
||
Returns: | ||
JSON-serializable dict representation | ||
""" | ||
return raw.to_dict() |
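`serialize_state` and `deserialize_state` are left as `NotImplementedError` stubs in the base class, so each orchestrator decides what goes into the snapshot. The sketch below shows one way a subclass might fill them in and resume from a JSON file. `TinyOrchestrator`, its `nodes` / `node_history` / `status` fields, and the status strings are assumptions made for illustration; they are not taken from the Graph or Swarm implementations.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


class TinyOrchestrator:
    """Hypothetical orchestrator; field names are illustrative only."""

    def __init__(self, nodes: list[str]) -> None:
        self.nodes = nodes
        self.node_history: list[str] = []
        self.status = "pending"

    def run_next(self) -> None:
        """Run the next node that has not been executed yet."""
        remaining = [n for n in self.nodes if n not in self.node_history]
        if remaining:
            self.node_history.append(remaining[0])
        done = len(self.node_history) == len(self.nodes)
        self.status = "completed" if done else "executing"

    def serialize_state(self) -> dict[str, Any]:
        """Return a JSON-serializable snapshot of the orchestrator state."""
        return {"status": self.status, "node_history": list(self.node_history)}

    def deserialize_state(self, payload: dict[str, Any]) -> None:
        """Restore state in place; note that this mutates the instance."""
        self.status = payload.get("status", "pending")
        self.node_history = list(payload.get("node_history", []))


def resume_demo() -> tuple[list[str], str]:
    nodes = ["plan", "research", "write"]
    first = TinyOrchestrator(nodes)
    first.run_next()  # simulate an interruption after one node

    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "state.json"
        path.write_text(json.dumps(first.serialize_state()))
        second = TinyOrchestrator(nodes)
        second.deserialize_state(json.loads(path.read_text()))

    second.run_next()
    second.run_next()
    return second.node_history, second.status
```

Because `deserialize_state` returns `None` and changes the object it is called on, the sketch also illustrates the naming concern raised in the review comment about the method being a mutative action.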