|
| 1 | +"""Multi-agent state data structures for session persistence. |
| 2 | +
|
| 3 | +This module defines the core data structures used to represent the state |
| 4 | +of multi-agent orchestrators in a serializable format for session persistence. |
| 5 | +
|
| 6 | +Key Components: |
| 7 | +- MultiAgentType: Enum for orchestrator types (Graph/Swarm) |
| 8 | +- MultiAgentState: Serializable state container with conversion methods |
| 9 | +""" |
| 10 | + |
| 11 | +from dataclasses import dataclass, field |
| 12 | +from enum import Enum |
| 13 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set |
| 14 | + |
| 15 | +from ...types.content import ContentBlock |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from ...multiagent.base import Status |
| 19 | + |
| 20 | + |
| 21 | +# TODO: Move to Base after experimental |
| 22 | +class MultiAgentType(Enum): |
| 23 | + """Enumeration of supported multi-agent orchestrator types. |
| 24 | + |
| 25 | + Attributes: |
| 26 | + SWARM: Collaborative agent swarm orchestrator |
| 27 | + GRAPH: Directed graph-based agent orchestrator |
| 28 | + """ |
| 29 | + SWARM = "swarm" |
| 30 | + GRAPH = "graph" |
| 31 | + |
| 32 | + |
| 33 | +@dataclass |
| 34 | +class MultiAgentState: |
| 35 | + """Serializable state container for multi-agent orchestrators. |
| 36 | + |
| 37 | + This class represents the complete execution state of a multi-agent |
| 38 | + orchestrator (Graph or Swarm) in a format suitable for persistence |
| 39 | + and restoration across sessions. |
| 40 | + |
| 41 | + Attributes: |
| 42 | + completed_nodes: Set of node IDs that have completed execution |
| 43 | + node_results: Dictionary mapping node IDs to their execution results |
| 44 | + status: Current execution status of the orchestrator |
| 45 | + next_node_to_execute: List of node IDs ready for execution |
| 46 | + current_task: The original task being executed |
| 47 | + execution_order: Ordered list of executed node IDs |
| 48 | + error_message: Optional error message if execution failed |
| 49 | + type: Type of orchestrator (Graph or Swarm) |
| 50 | + context: Additional context data (primarily for Swarm) |
| 51 | + """ |
| 52 | + # Mutual |
| 53 | + completed_nodes: Set[str] = field(default_factory=set) |
| 54 | + node_results: Dict[str, Any] = field(default_factory=dict) |
| 55 | + status: "Status" = "pending" |
| 56 | + next_node_to_execute: Optional[List[str]] = None |
| 57 | + current_task: Optional[str | List[ContentBlock]] = None |
| 58 | + execution_order: list[str] = field(default_factory=list) |
| 59 | + error_message: Optional[str] = None |
| 60 | + type: Optional[MultiAgentType] = field(default=MultiAgentType.GRAPH) |
| 61 | + # Swarm |
| 62 | + context: Optional[dict] = field(default_factory=dict) |
| 63 | + |
| 64 | + def to_dict(self) -> dict[str, Any]: |
| 65 | + """Convert MultiAgentState to JSON-serializable dictionary. |
| 66 | + |
| 67 | + Returns: |
| 68 | + Dictionary representation suitable for JSON serialization |
| 69 | + """ |
| 70 | + def _serialize(v: Any) -> Any: |
| 71 | + if isinstance(v, (str, int, float, bool)) or v is None: |
| 72 | + return v |
| 73 | + if isinstance(v, set): |
| 74 | + return list(v) |
| 75 | + if isinstance(v, dict): |
| 76 | + return {str(k): _serialize(val) for k, val in v.items()} |
| 77 | + if isinstance(v, (list, tuple)): |
| 78 | + return [_serialize(x) for x in v] |
| 79 | + if hasattr(v, "to_dict"): |
| 80 | + return v.to_dict() |
| 81 | + # last resort: stringize anything non-serializable (locks, objects, etc.) |
| 82 | + return str(v) |
| 83 | + |
| 84 | + return { |
| 85 | + "status": self.status, |
| 86 | + "completed_nodes": list(self.completed_nodes), |
| 87 | + "next_node_to_execute": list(self.next_node_to_execute) if self.next_node_to_execute else [], |
| 88 | + "node_results": _serialize(self.node_results), |
| 89 | + "current_task": self.current_task, |
| 90 | + "error_message": self.error_message, |
| 91 | + "execution_order": self.execution_order, |
| 92 | + "type": self.type, |
| 93 | + "context": _serialize(self.context), |
| 94 | + } |
| 95 | + |
| 96 | + @classmethod |
| 97 | + def from_dict(cls, data: dict): |
| 98 | + """Create MultiAgentState from dictionary data. |
| 99 | + |
| 100 | + Args: |
| 101 | + data: Dictionary containing state data |
| 102 | + |
| 103 | + Returns: |
| 104 | + MultiAgentState instance |
| 105 | + """ |
| 106 | + data["completed_nodes"] = set(data.get("completed_nodes", [])) |
| 107 | + return cls(**data) |
0 commit comments