-
Notifications
You must be signed in to change notification settings - Fork 450
Enable multi agent session persistence #900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 20 commits
c2d49f6
49253bc
41d34a9
2e854f8
0380567
93909d0
023cfba
d064038
b00f58d
d67c848
e6b4af2
1fcfdc0
3fe3978
683a14f
7411f2c
e9c2d57
8233c62
1e9851a
f1aac16
7735ed3
7b3aabb
58424ba
8ed2e21
d3adef3
80c8169
2ff4035
191f5e0
a1e10ed
734c59b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
JackYPCOnline marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| """Multi-agent session management for persistent execution. | ||
| This package provides session persistence capabilities for multi-agent orchestrators, | ||
| enabling resumable execution after interruptions or failures. | ||
| """ | ||
|
|
||
# Re-export every lifecycle event defined in ``multiagent_events`` so callers can
# import them from the package root. ``BeforeMultiAgentInvocationEvent`` was
# previously defined but not exported, unlike its ``After`` counterpart.
from .multiagent_events import (
    AfterMultiAgentInvocationEvent,
    AfterNodeInvocationEvent,
    BeforeMultiAgentInvocationEvent,
    BeforeNodeInvocationEvent,
    MultiagentInitializedEvent,
)

__all__ = [
    "AfterMultiAgentInvocationEvent",
    "AfterNodeInvocationEvent",
    "BeforeMultiAgentInvocationEvent",
    "BeforeNodeInvocationEvent",
    "MultiagentInitializedEvent",
]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| """Multi-agent execution lifecycle events for hook system integration. | ||
| These events are fired by orchestrators (Graph/Swarm) at key points so | ||
| hooks can persist, monitor, or debug execution. No intermediate state model | ||
| is used—hooks read from the orchestrator directly. | ||
| """ | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from ...hooks.registry import BaseHookEvent | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ...multiagent.base import MultiAgentBase | ||
|
|
||
|
|
||
@dataclass
class MultiagentInitializedEvent(BaseHookEvent):
    """Hook event triggered when a multi-agent orchestrator has been initialized.

    Attributes:
        source: The multi-agent orchestrator instance.
        invocation_state: Configuration that the user passed in, or None.
    """

    source: "MultiAgentBase"
    invocation_state: dict[str, Any] | None = None
JackYPCOnline marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
@dataclass
class BeforeNodeInvocationEvent(BaseHookEvent):
    """Event triggered before an individual node is invoked.

    This is the counterpart of ``AfterNodeInvocationEvent``.

    Attributes:
        source: The multi-agent orchestrator instance.
        invocation_state: Configuration that the user passed in, or None.
    """

    source: "MultiAgentBase"
    invocation_state: dict[str, Any] | None = None
|
|
||
|
|
||
@dataclass
class AfterNodeInvocationEvent(BaseHookEvent):
    """Hook event triggered once an individual node has finished executing.

    Attributes:
        source: The multi-agent orchestrator instance.
        executed_node: ID of the node that just completed execution.
        invocation_state: Configuration that the user passed in, or None.
    """

    source: "MultiAgentBase"
    executed_node: str
    invocation_state: dict[str, Any] | None = None

    @property
    def should_reverse_callbacks(self) -> bool:
        """Always True: callbacks for this event are invoked in reverse order."""
        return True
|
|
||
|
|
||
@dataclass
class BeforeMultiAgentInvocationEvent(BaseHookEvent):
    """Event triggered before orchestrator execution starts.

    This is the counterpart of ``AfterMultiAgentInvocationEvent``.

    Attributes:
        source: The multi-agent orchestrator instance.
        invocation_state: Configuration that the user passed in, or None.
    """

    source: "MultiAgentBase"
    invocation_state: dict[str, Any] | None = None
|
|
||
|
|
||
@dataclass
class AfterMultiAgentInvocationEvent(BaseHookEvent):
    """Hook event triggered once orchestrator execution has completed.

    Attributes:
        source: The multi-agent orchestrator instance.
        invocation_state: Configuration that the user passed in, or None.
    """

    source: "MultiAgentBase"
    invocation_state: dict[str, Any] | None = None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,17 +4,21 @@ | |
| """ | ||
|
|
||
| import asyncio | ||
| import logging | ||
| import warnings | ||
| from abc import ABC, abstractmethod | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from dataclasses import dataclass, field | ||
| from enum import Enum | ||
| from typing import Any, Union | ||
| from typing import Any, Literal, Union, cast | ||
|
|
||
| from ..agent import AgentResult | ||
| from ..types.content import ContentBlock | ||
| from ..telemetry.metrics import EventLoopMetrics | ||
| from ..types.content import ContentBlock, Message | ||
| from ..types.event_loop import Metrics, Usage | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class Status(Enum): | ||
| """Execution status for both graphs and nodes.""" | ||
|
|
@@ -52,13 +56,101 @@ def get_agent_results(self) -> list[AgentResult]: | |
| return [] # No agent results for exceptions | ||
| elif isinstance(self.result, AgentResult): | ||
| return [self.result] | ||
| # If this is a nested MultiAgentResult, flatten children | ||
| else: | ||
| # Flatten nested results from MultiAgentResult | ||
JackYPCOnline marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| flattened = [] | ||
| for nested_node_result in self.result.results.values(): | ||
| flattened.extend(nested_node_result.get_agent_results()) | ||
| return flattened | ||
|
|
||
def to_dict(self) -> dict[str, Any]:
    """Convert NodeResult to a JSON-serializable dict, ignoring the state field.

    Returns:
        A dict with the keys "result", "execution_time", "status",
        "accumulated_usage", "accumulated_metrics" and "execution_count".

    Raises:
        TypeError: If ``self.result`` is not an Exception, an AgentResult or a
            MultiAgentResult.
    """
    result_data: dict[str, Any]
    if isinstance(self.result, Exception):
        result_data = {"type": "exception", "message": str(self.result)}
    elif isinstance(self.result, AgentResult):
        # Only stop_reason and message are persisted; the state field is dropped.
        result_data = {
            "type": "agent_result",
            "stop_reason": self.result.stop_reason,
            "message": self.result.message,
        }
    elif isinstance(self.result, MultiAgentResult):
        # Nested results carry no "type" tag; from_dict recognises them by "results".
        result_data = self.result.to_dict()
    else:
        raise TypeError(f"Unsupported NodeResult.result type for serialization: {type(self.result).__name__}")

    return {
        "result": result_data,
        "execution_time": self.execution_time,
        "status": self.status.value,
        # Copy the mappings (as MultiAgentResult.to_dict does) so that the snapshot
        # does not alias state that may still be mutated after serialization.
        "accumulated_usage": dict(self.accumulated_usage),
        "accumulated_metrics": dict(self.accumulated_metrics),
        "execution_count": self.execution_count,
    }
|
|
||
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
    """Rehydrate a NodeResult from persisted JSON.

    Args:
        data: Dict in the shape produced by ``to_dict``.

    Returns:
        A NodeResult rebuilt from ``data``; missing optional fields fall back to
        zero values and a pending status.

    Raises:
        TypeError: If "result" is missing or its payload is not a recognised shape.
    """
    if "result" not in data:
        raise TypeError("NodeResult.from_dict: missing 'result'")
    raw = data["result"]

    result: Union[AgentResult, "MultiAgentResult", Exception]
    if isinstance(raw, dict) and raw.get("type") == "agent_result":
        result = _agent_result_from_persisted(raw)
    elif isinstance(raw, dict) and raw.get("type") == "exception":
        # The original exception class is not persisted; only its message survives.
        result = Exception(str(raw.get("message", "node failed")))
    elif isinstance(raw, dict) and ("results" in raw):
        result = MultiAgentResult.from_dict(raw)
    else:
        raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}")

    # ``or {}`` tolerates an explicit null in the persisted payload, not just a missing key.
    usage_data = data.get("accumulated_usage") or {}
    usage = Usage(
        inputTokens=usage_data.get("inputTokens", 0),
        outputTokens=usage_data.get("outputTokens", 0),
        totalTokens=usage_data.get("totalTokens", 0),
    )
    # Add optional fields if they exist
    if "cacheReadInputTokens" in usage_data:
        usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"]
    if "cacheWriteInputTokens" in usage_data:
        usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"]

    metrics_data = data.get("accumulated_metrics") or {}
    metrics = Metrics(latencyMs=metrics_data.get("latencyMs", 0))

    return cls(
        result=result,
        execution_time=int(data.get("execution_time", 0)),
        # Same default as MultiAgentResult.from_dict, instead of a hard-coded string.
        status=Status(data.get("status", Status.PENDING.value)),
        accumulated_usage=usage,
        accumulated_metrics=metrics,
        execution_count=int(data.get("execution_count", 0)),
    )
|
|
||
|
|
||
def _agent_result_from_persisted(data: dict[str, Any]) -> AgentResult:
    """Rebuild a minimal AgentResult from its persisted JSON form.

    Expected shape:
        {"type": "agent_result", "message": <Message>, "stop_reason": <str|None>}

    The result is constructed with fresh ``EventLoopMetrics()`` and an empty
    ``state`` dict; only the message and stop reason come from ``data``.

    Raises:
        TypeError: If ``data["type"]`` is not "agent_result".
    """
    is_agent_result = data.get("type") == "agent_result"
    if not is_agent_result:
        raise TypeError(f"_agent_result_from_persisted: unexpected type {data.get('type')!r}")

    # The casts only inform the type checker; no runtime validation happens here.
    persisted_message = cast(Message, data.get("message"))
    persisted_stop_reason = cast(
        Literal["content_filtered", "end_turn", "guardrail_intervened", "max_tokens", "stop_sequence", "tool_use"],
        data.get("stop_reason"),
    )

    try:
        rehydrated = AgentResult(
            message=persisted_message,
            stop_reason=persisted_stop_reason,
            metrics=EventLoopMetrics(),
            state={},
        )
    except Exception:
        # Leave a debug breadcrumb, then let the original error propagate unchanged.
        logger.debug("AgentResult constructor failed during rehydrating")
        raise
    return rehydrated
|
|
||
|
|
||
| @dataclass | ||
| class MultiAgentResult: | ||
|
|
@@ -76,6 +168,45 @@ class MultiAgentResult: | |
| execution_count: int = 0 | ||
| execution_time: int = 0 | ||
|
|
||
def to_dict(self) -> dict[str, Any]:
    """Serialize this MultiAgentResult into a JSON-serializable dict.

    Each entry of ``results`` is serialized through its own ``to_dict``; the usage
    and metrics mappings are copied into plain dicts.
    """
    serialized_results = {}
    for node_id, node_result in self.results.items():
        serialized_results[node_id] = node_result.to_dict()

    snapshot: dict[str, Any] = {}
    snapshot["status"] = self.status.value
    snapshot["results"] = serialized_results
    snapshot["accumulated_usage"] = dict(self.accumulated_usage)
    snapshot["accumulated_metrics"] = dict(self.accumulated_metrics)
    snapshot["execution_count"] = self.execution_count
    snapshot["execution_time"] = self.execution_time
    return snapshot
|
|
||
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
    """Rehydrate a MultiAgentResult from persisted JSON.

    Args:
        data: Dict in the shape produced by ``to_dict``.

    Returns:
        A MultiAgentResult rebuilt from ``data``; missing fields fall back to an
        empty result map, zeroed usage/metrics/counters and a pending status.
    """
    # ``or {}`` tolerates an explicit null in the persisted payload, not just a missing key.
    results = {k: NodeResult.from_dict(v) for k, v in (data.get("results") or {}).items()}

    usage_data = data.get("accumulated_usage") or {}
    usage = Usage(
        inputTokens=usage_data.get("inputTokens", 0),
        outputTokens=usage_data.get("outputTokens", 0),
        totalTokens=usage_data.get("totalTokens", 0),
    )
    # Add optional fields if they exist
    if "cacheReadInputTokens" in usage_data:
        usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"]
    if "cacheWriteInputTokens" in usage_data:
        usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"]

    metrics_data = data.get("accumulated_metrics") or {}
    metrics = Metrics(latencyMs=metrics_data.get("latencyMs", 0))

    return cls(
        status=Status(data.get("status", Status.PENDING.value)),
        results=results,
        accumulated_usage=usage,
        accumulated_metrics=metrics,
        execution_count=int(data.get("execution_count", 0)),
        execution_time=int(data.get("execution_time", 0)),
    )
|
|
||
|
|
||
| class MultiAgentBase(ABC): | ||
| """Base class for multi-agent helpers. | ||
|
|
@@ -122,3 +253,32 @@ def execute() -> MultiAgentResult: | |
| with ThreadPoolExecutor() as executor: | ||
| future = executor.submit(execute) | ||
| return future.result() | ||
|
|
||
def serialize_state(self) -> dict[str, Any]:
    """Produce a JSON-serializable snapshot of the orchestrator state.

    This base implementation is a placeholder.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
|
|
||
def deserialize_state(self, payload: dict[str, Any]) -> None:
    """Restore the orchestrator state from a session dict.

    This base implementation is a placeholder.

    Args:
        payload: Session dict to restore the orchestrator state from.

    Raises:
        NotImplementedError: Always.
    """
    # NOTE(review): a reviewer questioned this name because "deserialize" does not
    # suggest a mutating action, and proposed something like `restore_from_state`
    # or a session-based name instead. The reviewer's comment is truncated in the
    # captured page, so the full suggestion is unknown.
    raise NotImplementedError()
|
|
||
def serialize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]:
    """Turn a node result into its persistable dict form.

    Args:
        raw: Raw node result to serialize.

    Returns:
        The JSON-serializable dict produced by ``raw.to_dict()``.

    Raises:
        TypeError: If ``raw`` is not a NodeResult.
    """
    if isinstance(raw, NodeResult):
        return raw.to_dict()
    raise TypeError(f"serialize_node_result_for_persist expects NodeResult, got {type(raw).__name__}")
|
|
||
def attempt_resume(self, payload: dict[str, Any]) -> None:
    """Try to resume the orchestrator state from a session payload.

    This base implementation is a placeholder.

    Args:
        payload: Session data to restore the orchestrator state from.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
Uh oh!
There was an error while loading. Please reload this page.