diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 1628a8a9d..9ab107bb9 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -137,7 +137,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": metrics = _parse_metrics(data.get("accumulated_metrics", {})) multiagent_result = cls( - status=Status(data.get("status", Status.PENDING.value)), + status=Status(data["status"]), results=results, accumulated_usage=usage, accumulated_metrics=metrics, @@ -164,8 +164,13 @@ class MultiAgentBase(ABC): This class integrates with existing Strands Agent instances and provides multi-agent orchestration capabilities. + + Attributes: + id: Unique MultiAgent id for session management,etc. """ + id: str + @abstractmethod async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 491f7ad60..fc80fc520 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -5,7 +5,7 @@ import os import shutil import tempfile -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from .. import _identifier from ..types.exceptions import SessionException @@ -13,11 +13,15 @@ from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + logger = logging.getLogger(__name__) SESSION_PREFIX = "session_" AGENT_PREFIX = "agent_" MESSAGE_PREFIX = "message_" +MULTI_AGENT_PREFIX = "multi_agent_" class FileSessionManager(RepositorySessionManager, SessionRepository): @@ -37,7 +41,12 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): ``` """ - def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any): + def __init__( + self, + session_id: str, + storage_dir: Optional[str] = None, + **kwargs: Any, + ): """Initialize FileSession with filesystem storage. Args: @@ -107,8 +116,11 @@ def _read_file(self, path: str) -> dict[str, Any]: def _write_file(self, path: str, data: dict[str, Any]) -> None: """Write JSON file.""" os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "w", encoding="utf-8") as f: + # This automic write ensure the completeness of session files in both single agent/ multi agents + tmp = f"{path}.tmp" + with open(tmp, "w", encoding="utf-8", newline="\n") as f: json.dump(data, f, indent=2, ensure_ascii=False) + os.replace(tmp, path) def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session.""" @@ -119,6 +131,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: # Create directory structure os.makedirs(session_dir, exist_ok=True) os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) + os.makedirs(os.path.join(session_dir, "multi_agents"), exist_ok=True) # Write session file session_file = os.path.join(session_dir, "session.json") @@ -239,3 +252,36 @@ def list_messages( messages.append(SessionMessage.from_dict(message_data)) return messages + + def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str: + """Get multi-agent state file path.""" + session_path = self._get_session_path(session_id) + multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT) + return os.path.join(session_path, "multi_agents", f"{MULTI_AGENT_PREFIX}{multi_agent_id}") + + def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: + """Create a new multiagent state in the session.""" + multi_agent_id = multi_agent.id + multi_agent_dir = self._get_multi_agent_path(session_id, multi_agent_id) + os.makedirs(multi_agent_dir, exist_ok=True) + + multi_agent_file = os.path.join(multi_agent_dir, "multi_agent.json") + session_data = multi_agent.serialize_state() + self._write_file(multi_agent_file, session_data) + + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + """Read multi-agent state from filesystem.""" + multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") + if not os.path.exists(multi_agent_file): + return None + return self._read_file(multi_agent_file) + + def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: + """Update multi-agent state from filesystem.""" + multi_agent_state = multi_agent.serialize_state() + previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent.id) + if previous_multi_agent_state is None: + raise SessionException(f"MultiAgent state {multi_agent.id} in session {session_id} does not exist") + + multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent.id), "multi_agent.json") + self._write_file(multi_agent_file, multi_agent_state) diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index e5075de93..ccdcb6934 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from ..agent.agent import Agent + from ..multiagent.base import MultiAgentBase logger = logging.getLogger(__name__) @@ -24,7 +25,12 @@ class RepositorySessionManager(SessionManager): """Session manager for persisting agents in a SessionRepository.""" - def __init__(self, session_id: str, session_repository: SessionRepository, **kwargs: Any): + def __init__( + self, + session_id: str, + session_repository: SessionRepository, + **kwargs: Any, + ): """Initialize the RepositorySessionManager. If no session with the specified session_id exists yet, it will be created @@ -152,3 +158,26 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: # Restore the agents messages array including the optional prepend messages agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages] + + def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: + """Serialize and update the multi-agent state into the session repository. + + Args: + source: Multi-agent source object to sync to the session. + **kwargs: Additional keyword arguments for future extensibility. + """ + self.session_repository.update_multi_agent(self.session_id, source) + + def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: + """Initialize multi-agent state from the session repository. + + Args: + source: Multi-agent source object to restore state into + **kwargs: Additional keyword arguments for future extensibility. + """ + state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs) + if state is None: + self.session_repository.create_multi_agent(self.session_id, source) + else: + logger.debug("session_id=<%s> | restoring multi-agent state", self.session_id) + source.deserialize_state(state) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index c6ce28d80..7d081cf09 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -14,11 +14,15 @@ from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + logger = logging.getLogger(__name__) SESSION_PREFIX = "session_" AGENT_PREFIX = "agent_" MESSAGE_PREFIX = "message_" +MULTI_AGENT_PREFIX = "multi_agent_" class S3SessionManager(RepositorySessionManager, SessionRepository): @@ -294,3 +298,31 @@ def list_messages( except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e + + def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str: + """Get multi-agent S3 prefix.""" + session_path = self._get_session_path(session_id) + multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT) + return f"{session_path}multi_agents/{MULTI_AGENT_PREFIX}{multi_agent_id}/" + + def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: + """Create a new multiagent state in S3.""" + multi_agent_id = multi_agent.id + multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" + session_data = multi_agent.serialize_state() + self._write_s3_object(multi_agent_key, session_data) + + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + """Read multi-agent state from S3.""" + multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" + return self._read_s3_object(multi_agent_key) + + def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: + """Update multi-agent state in S3.""" + multi_agent_state = multi_agent.serialize_state() + previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent.id) + if previous_multi_agent_state is None: + raise SessionException(f"MultiAgent state {multi_agent.id} in session {session_id} does not exist") + + multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent.id)}multi_agent.json" + self._write_s3_object(multi_agent_key, multi_agent_state) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 66a07ea43..fb9132828 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -1,14 +1,23 @@ """Session manager interface for agent session management.""" +import logging from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from ..experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, +) from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry from ..types.content import Message if TYPE_CHECKING: from ..agent.agent import Agent + from ..multiagent.base import MultiAgentBase + +logger = logging.getLogger(__name__) class SessionManager(HookProvider, ABC): @@ -34,6 +43,10 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: # After an agent was invoked, sync it with the session to capture any conversation manager state updates registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) + registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source)) + registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) + registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) + @abstractmethod def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None: """Redact the message most recently appended to the agent in the session. @@ -71,3 +84,33 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: agent: Agent to initialize **kwargs: Additional keyword arguments for future extensibility. """ + + def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: + """Serialize and sync multi-agent with the session storage. + + Args: + source: Multi-agent source object to persist + **kwargs: Additional keyword arguments for future extensibility. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support multi-agent persistence " + "(sync_multi_agent). Provide an implementation or use a " + "SessionManager with session_type=SessionType.MULTI_AGENT." + ) + + def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: + """Read multi-agent state from persistent storage. + + Args: + **kwargs: Additional keyword arguments for future extensibility. + source: Multi-agent state to initialize. + + Returns: + Multi-agent state dictionary or empty dict if not found. + + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support multi-agent persistence " + "(initialize_multi_agent). Provide an implementation or use a " + "SessionManager with session_type=SessionType.MULTI_AGENT." + ) diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py index 6b0fded7a..3f5476bdf 100644 --- a/src/strands/session/session_repository.py +++ b/src/strands/session/session_repository.py @@ -1,10 +1,13 @@ """Session repository interface for agent session management.""" from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from ..types.session import Session, SessionAgent, SessionMessage +if TYPE_CHECKING: + from ..multiagent import MultiAgentBase + class SessionRepository(ABC): """Abstract repository for creating, reading, and updating Sessions, AgentSessions, and AgentMessages.""" @@ -49,3 +52,15 @@ def list_messages( self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List Messages from an Agent with pagination.""" + + def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: + """Create a new MultiAgent state for the Session.""" + raise NotImplementedError("MultiAgent is not implemented for this repository") + + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + """Read the MultiAgent state for the Session.""" + raise NotImplementedError("MultiAgent is not implemented for this repository") + + def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None: + """Update the MultiAgent state for the Session.""" + raise NotImplementedError("MultiAgent is not implemented for this repository") diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 926480f2c..4e72a1468 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -17,7 +17,7 @@ class SessionType(str, Enum): """Enumeration of session types. - As sessions are expanded to support new usecases like multi-agent patterns, + As sessions are expanded to support new use cases like multi-agent patterns, new types will be added here. """ diff --git a/tests/fixtures/mock_session_repository.py b/tests/fixtures/mock_session_repository.py index f3923f68b..af369ba1c 100644 --- a/tests/fixtures/mock_session_repository.py +++ b/tests/fixtures/mock_session_repository.py @@ -11,6 +11,7 @@ def __init__(self): self.sessions = {} self.agents = {} self.messages = {} + self.multi_agents = {} def create_session(self, session) -> None: """Create a session.""" @@ -20,6 +21,7 @@ def create_session(self, session) -> None: self.sessions[session_id] = session self.agents[session_id] = {} self.messages[session_id] = {} + self.multi_agents[session_id] = {} def read_session(self, session_id) -> SessionAgent: """Read a session.""" @@ -95,3 +97,27 @@ def list_messages(self, session_id, agent_id, limit=None, offset=0) -> list[Sess if limit is not None: return sorted_messages[offset : offset + limit] return sorted_messages[offset:] + + def create_multi_agent(self, session_id, multi_agent, **kwargs) -> None: + """Create multi-agent state.""" + multi_agent_id = multi_agent.id + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + state = multi_agent.serialize_state() + self.multi_agents.setdefault(session_id, {})[multi_agent_id] = state + + def read_multi_agent(self, session_id, multi_agent_id, **kwargs): + """Read multi-agent state.""" + if session_id not in self.sessions: + return None + return self.multi_agents.get(session_id, {}).get(multi_agent_id) + + def update_multi_agent(self, session_id, multi_agent, **kwargs) -> None: + """Update multi-agent state.""" + multi_agent_id = multi_agent.id + if session_id not in self.sessions: + raise SessionException(f"Session {session_id} does not exist") + if multi_agent_id not in self.multi_agents.get(session_id, {}): + raise SessionException(f"MultiAgent {multi_agent} does not exist in session {session_id}") + state = multi_agent.serialize_state() + self.multi_agents[session_id][multi_agent_id] = state diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index f124ddf58..7e28be998 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -3,7 +3,7 @@ import json import os import tempfile -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest @@ -53,6 +53,22 @@ def sample_message(): ) +@pytest.fixture +def mock_multi_agent(): + """Create mock multi-agent for testing.""" + mock = Mock() + mock.id = "test-multi-agent" + mock.state = {"key": "value"} + mock.serialize_state.return_value = {"id": "test-multi-agent", "state": {"key": "value"}} + return mock + + +@pytest.fixture +def multi_agent_manager(temp_dir): + """Create FileSessionManager.""" + return FileSessionManager(session_id="test", storage_dir=temp_dir) + + def test_create_session(file_manager, sample_session): """Test creating a session.""" file_manager.create_session(sample_session) @@ -408,3 +424,80 @@ def test__get_message_path_invalid_message_id(message_id, file_manager): """Test that message_id that is not an integer raises ValueError.""" with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): file_manager._get_message_path("session1", "agent1", message_id) + + +def test_create_multi_agent(multi_agent_manager, sample_session, mock_multi_agent): + """Test creating multi-agent state.""" + multi_agent_manager.create_session(sample_session) + multi_agent_manager.create_multi_agent(sample_session.session_id, mock_multi_agent) + + # Verify file created + multi_agent_file = os.path.join( + multi_agent_manager._get_multi_agent_path(sample_session.session_id, mock_multi_agent.id), + "multi_agent.json", + ) + assert os.path.exists(multi_agent_file) + + # Verify content + with open(multi_agent_file, "r") as f: + data = json.load(f) + assert data["id"] == mock_multi_agent.id + assert data["state"] == mock_multi_agent.state + + +def test_read_multi_agent(multi_agent_manager, sample_session, mock_multi_agent): + """Test reading multi-agent state.""" + # Create session and multi-agent + multi_agent_manager.create_session(sample_session) + multi_agent_manager.create_multi_agent(sample_session.session_id, mock_multi_agent) + + # Read multi-agent + result = multi_agent_manager.read_multi_agent(sample_session.session_id, mock_multi_agent.id) + + assert result["id"] == mock_multi_agent.id + assert result["state"] == mock_multi_agent.state + + +def test_read_nonexistent_multi_agent(multi_agent_manager, sample_session): + """Test reading multi-agent state that doesn't exist.""" + result = multi_agent_manager.read_multi_agent(sample_session.session_id, "nonexistent") + assert result is None + + +def test_update_multi_agent(multi_agent_manager, sample_session, mock_multi_agent): + """Test updating multi-agent state.""" + # Create session and multi-agent + multi_agent_manager.create_session(sample_session) + multi_agent_manager.create_multi_agent(sample_session.session_id, mock_multi_agent) + + updated_mock = Mock() + updated_mock.id = mock_multi_agent.id + updated_mock.serialize_state.return_value = {"id": mock_multi_agent.id, "state": {"updated": "value"}} + multi_agent_manager.update_multi_agent(sample_session.session_id, updated_mock) + + # Verify update + result = multi_agent_manager.read_multi_agent(sample_session.session_id, mock_multi_agent.id) + assert result["state"] == {"updated": "value"} + + +def test_update_nonexistent_multi_agent(multi_agent_manager, sample_session): + """Test updating multi-agent state that doesn't exist.""" + # Create session + multi_agent_manager.create_session(sample_session) + + nonexistent_mock = Mock() + nonexistent_mock.id = "nonexistent" + with pytest.raises(SessionException): + multi_agent_manager.update_multi_agent(sample_session.session_id, nonexistent_mock) + + +def test_create_session_multi_agent_directory_structure(multi_agent_manager, sample_session): + """Test multi-agent session creates correct directory structure.""" + multi_agent_manager.create_session(sample_session) + + # Verify directory structure + session_dir = multi_agent_manager._get_session_path(sample_session.session_id) + multi_agents_dir = os.path.join(session_dir, "multi_agents") + + assert os.path.exists(session_dir) + assert os.path.exists(multi_agents_dir) diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 923b13daa..e346f01e0 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -1,5 +1,7 @@ """Tests for AgentSessionManager.""" +from unittest.mock import Mock + import pytest from strands.agent.agent import Agent @@ -31,6 +33,17 @@ def agent(): return Agent(messages=[{"role": "user", "content": [{"text": "Hello!"}]}]) +@pytest.fixture +def mock_multi_agent(): + """Create mock multi-agent for testing.""" + + mock = Mock() + mock.id = "test-multi-agent" + mock.serialize_state.return_value = {"id": "test-multi-agent", "state": {"key": "value"}} + mock.deserialize_state = Mock() + return mock + + def test_init_creates_session_if_not_exists(mock_repository): """Test that init creates a session if it doesn't exist.""" # Session doesn't exist yet @@ -177,3 +190,46 @@ def test_append_message(session_manager): assert len(messages) == 1 assert messages[0].message["role"] == "user" assert messages[0].message["content"][0]["text"] == "Hello" + + +def test_sync_multi_agent(session_manager, mock_multi_agent): + """Test syncing multi-agent state.""" + # Create multi-agent first + session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent) + + # Sync multi-agent + session_manager.sync_multi_agent(mock_multi_agent) + + # Verify repository update_multi_agent was called + state = session_manager.session_repository.read_multi_agent("test-session", mock_multi_agent.id) + assert state["id"] == "test-multi-agent" + assert state["state"] == {"key": "value"} + + +def test_initialize_multi_agent_new(session_manager, mock_multi_agent): + """Test initializing new multi-agent state.""" + session_manager.initialize_multi_agent(mock_multi_agent) + + # Verify multi-agent was created + state = session_manager.session_repository.read_multi_agent("test-session", mock_multi_agent.id) + assert state["id"] == "test-multi-agent" + assert state["state"] == {"key": "value"} + + +def test_initialize_multi_agent_existing(session_manager, mock_multi_agent): + """Test initializing existing multi-agent state.""" + # Create existing state first + session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent) + + # Create a mock with updated state for the update call + updated_mock = Mock() + updated_mock.id = "test-multi-agent" + existing_state = {"id": "test-multi-agent", "state": {"restored": "data"}} + updated_mock.serialize_state.return_value = existing_state + session_manager.session_repository.update_multi_agent("test-session", updated_mock) + + # Initialize multi-agent + session_manager.initialize_multi_agent(mock_multi_agent) + + # Verify deserialize_state was called with existing state + mock_multi_agent.deserialize_state.assert_called_once_with(existing_state) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index c4d6a0154..719fbc2c9 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -1,6 +1,7 @@ """Tests for S3SessionManager.""" import json +from unittest.mock import Mock import boto3 import pytest @@ -374,3 +375,75 @@ def test__get_message_path_invalid_message_id(message_id, s3_manager): """Test that message_id that is not an integer raises ValueError.""" with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): s3_manager._get_message_path("session1", "agent1", message_id) + + +@pytest.fixture +def mock_multi_agent(): + """Create mock multi-agent for testing.""" + + mock = Mock() + mock.id = "test-multi-agent" + mock.state = {"key": "value"} + mock.serialize_state.return_value = {"id": "test-multi-agent", "state": {"key": "value"}} + return mock + + +def test_create_multi_agent(s3_manager, sample_session, mock_multi_agent): + """Test creating multi-agent state in S3.""" + s3_manager.create_session(sample_session) + s3_manager.create_multi_agent(sample_session.session_id, mock_multi_agent) + + # Verify S3 object created + key = f"{s3_manager._get_multi_agent_path(sample_session.session_id, mock_multi_agent.id)}multi_agent.json" + response = s3_manager.client.get_object(Bucket=s3_manager.bucket, Key=key) + data = json.loads(response["Body"].read().decode("utf-8")) + + assert data["id"] == mock_multi_agent.id + assert data["state"] == mock_multi_agent.state + + +def test_read_multi_agent(s3_manager, sample_session, mock_multi_agent): + """Test reading multi-agent state from S3.""" + # Create session and multi-agent + s3_manager.create_session(sample_session) + s3_manager.create_multi_agent(sample_session.session_id, mock_multi_agent) + + # Read multi-agent + result = s3_manager.read_multi_agent(sample_session.session_id, mock_multi_agent.id) + + assert result["id"] == mock_multi_agent.id + assert result["state"] == mock_multi_agent.state + + +def test_read_nonexistent_multi_agent(s3_manager, sample_session): + """Test reading multi-agent state that doesn't exist.""" + s3_manager.create_session(sample_session) + result = s3_manager.read_multi_agent(sample_session.session_id, "nonexistent") + assert result is None + + +def test_update_multi_agent(s3_manager, sample_session, mock_multi_agent): + """Test updating multi-agent state in S3.""" + # Create session and multi-agent + s3_manager.create_session(sample_session) + s3_manager.create_multi_agent(sample_session.session_id, mock_multi_agent) + + updated_mock = Mock() + updated_mock.id = mock_multi_agent.id + updated_mock.serialize_state.return_value = {"id": mock_multi_agent.id, "state": {"updated": "value"}} + s3_manager.update_multi_agent(sample_session.session_id, updated_mock) + + # Verify update + result = s3_manager.read_multi_agent(sample_session.session_id, mock_multi_agent.id) + assert result["state"] == {"updated": "value"} + + +def test_update_nonexistent_multi_agent(s3_manager, sample_session): + """Test updating multi-agent state that doesn't exist.""" + # Create session + s3_manager.create_session(sample_session) + + nonexistent_mock = Mock() + nonexistent_mock.id = "nonexistent" + with pytest.raises(SessionException): + s3_manager.update_multi_agent(sample_session.session_id, nonexistent_mock)