Skip to content
Open
11 changes: 7 additions & 4 deletions src/strands/experimental/agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@

import json
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import jsonschema
from jsonschema import ValidationError

from ..agent import Agent
if TYPE_CHECKING:
from ..agent.agent import Agent

# JSON Schema for agent configuration
AGENT_CONFIG_SCHEMA = {
Expand Down Expand Up @@ -53,7 +54,7 @@
_VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA)


def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Agent:
def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> "Agent":
"""Create an Agent from a configuration file or dictionary.

This function supports tools that can be loaded declaratively (file paths, module names,
Expand Down Expand Up @@ -134,5 +135,7 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A
# Override with any additional kwargs provided
agent_kwargs.update(kwargs)

# Create and return Agent
# Create and return Agent (import at runtime to avoid circular import)
from ..agent.agent import Agent

return Agent(**agent_kwargs)
4 changes: 2 additions & 2 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
metrics = _parse_metrics(data.get("accumulated_metrics", {}))

multiagent_result = cls(
status=Status(data.get("status", Status.PENDING.value)),
status=Status(data.get("status")),
results=results,
accumulated_usage=usage,
accumulated_metrics=metrics,
Expand Down Expand Up @@ -210,7 +210,7 @@ def serialize_state(self) -> dict[str, Any]:
"""Return a JSON-serializable snapshot of the orchestrator state."""
raise NotImplementedError

def deserialize_state(self, payload: dict[str, Any]) -> None:
def restore_from_session(self, payload: dict[str, Any]) -> None:
"""Restore orchestrator state from a session dict."""
raise NotImplementedError

Expand Down
56 changes: 50 additions & 6 deletions src/strands/session/file_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
import os
import shutil
import tempfile
from typing import Any, Optional, cast
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Optional, cast

from .. import _identifier
from ..types.exceptions import SessionException
from ..types.session import Session, SessionAgent, SessionMessage
from ..types.session import Session, SessionAgent, SessionMessage, SessionType
from .repository_session_manager import RepositorySessionManager
from .session_repository import SessionRepository

if TYPE_CHECKING:
from ..multiagent.base import MultiAgentBase

logger = logging.getLogger(__name__)

SESSION_PREFIX = "session_"
Expand All @@ -37,19 +41,26 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
```
"""

def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any):
def __init__(
self,
session_id: str,
storage_dir: Optional[str] = None,
session_type: SessionType = SessionType.AGENT,
**kwargs: Any,
):
"""Initialize FileSession with filesystem storage.

Args:
session_id: ID for the session.
ID is not allowed to contain path separators (e.g., a/b).
storage_dir: Directory for local filesystem storage (defaults to temp dir).
session_type: single agent or multiagent.
**kwargs: Additional keyword arguments for future extensibility.
"""
self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions")
os.makedirs(self.storage_dir, exist_ok=True)

super().__init__(session_id=session_id, session_repository=self)
super().__init__(session_id=session_id, session_repository=self, session_type=session_type)

def _get_session_path(self, session_id: str) -> str:
"""Get session directory path.
Expand Down Expand Up @@ -107,8 +118,10 @@ def _read_file(self, path: str) -> dict[str, Any]:
def _write_file(self, path: str, data: dict[str, Any]) -> None:
"""Write JSON file."""
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
tmp = f"{path}.tmp"
with open(tmp, "w", encoding="utf-8", newline="\n") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
os.replace(tmp, path)

def create_session(self, session: Session, **kwargs: Any) -> Session:
"""Create a new session."""
Expand All @@ -118,7 +131,8 @@ def create_session(self, session: Session, **kwargs: Any) -> Session:

# Create directory structure
os.makedirs(session_dir, exist_ok=True)
os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True)
if self.session_type == SessionType.AGENT:
os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True)

# Write session file
session_file = os.path.join(session_dir, "session.json")
Expand Down Expand Up @@ -239,3 +253,33 @@ def list_messages(
messages.append(SessionMessage.from_dict(message_data))

return messages

def write_multi_agent_json(self, source: "MultiAgentBase") -> None:
"""Write multi-agent state to filesystem.

Args:
source: Multi-agent source object to persist
"""
state = source.serialize_state()
state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json")
self._write_file(state_path, state)

# Update session metadata
session_dir = self._get_session_path(self.session.session_id)
session_file = os.path.join(session_dir, "session.json")
with open(session_file, "r", encoding="utf-8") as f:
metadata = json.load(f)
metadata["updated_at"] = datetime.now(timezone.utc).isoformat()
self._write_file(session_file, metadata)

def read_multi_agent_json(self) -> dict[str, Any]:
"""Read multi-agent state from filesystem.

Returns:
Multi-agent state dictionary or empty dict if not found
"""
state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json")
if not os.path.exists(state_path):
return {}
state_data = self._read_file(state_path)
return state_data
17 changes: 14 additions & 3 deletions src/strands/session/repository_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
class RepositorySessionManager(SessionManager):
"""Session manager for persisting agents in a SessionRepository."""

def __init__(self, session_id: str, session_repository: SessionRepository, **kwargs: Any):
def __init__(
self,
session_id: str,
session_repository: SessionRepository,
session_type: SessionType = SessionType.AGENT,
**kwargs: Any,
):
"""Initialize the RepositorySessionManager.

If no session with the specified session_id exists yet, it will be created
Expand All @@ -34,22 +40,27 @@ def __init__(self, session_id: str, session_repository: SessionRepository, **kwa
session_id: ID to use for the session. A new session with this id will be created if it does
not exist in the repository yet
session_repository: Underlying session repository to use to store the sessions state.
session_type: single agent or multiagent.
**kwargs: Additional keyword arguments for future extensibility.

"""
super().__init__(session_type=session_type)

self.session_repository = session_repository
self.session_id = session_id
session = session_repository.read_session(session_id)
# Create a session if it does not exist yet
if session is None:
logger.debug("session_id=<%s> | session not found, creating new session", self.session_id)
session = Session(session_id=session_id, session_type=SessionType.AGENT)
session = Session(session_id=session_id, session_type=session_type)
session_repository.create_session(session)

self.session = session
self.session_type = session.session_type

# Keep track of the latest message of each agent in case we need to redact it.
self._latest_agent_message: dict[str, Optional[SessionMessage]] = {}
if self.session_type == SessionType.AGENT:
self._latest_agent_message: dict[str, Optional[SessionMessage]] = {}

def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
"""Append a message to the agent's session.
Expand Down
32 changes: 29 additions & 3 deletions src/strands/session/s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@

import json
import logging
from typing import Any, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

import boto3
from botocore.config import Config as BotocoreConfig
from botocore.exceptions import ClientError

from .. import _identifier
from ..types.exceptions import SessionException
from ..types.session import Session, SessionAgent, SessionMessage
from ..types.session import Session, SessionAgent, SessionMessage, SessionType
from .repository_session_manager import RepositorySessionManager
from .session_repository import SessionRepository

if TYPE_CHECKING:
from ..multiagent.base import MultiAgentBase

logger = logging.getLogger(__name__)

SESSION_PREFIX = "session_"
Expand Down Expand Up @@ -46,6 +49,7 @@ def __init__(
boto_session: Optional[boto3.Session] = None,
boto_client_config: Optional[BotocoreConfig] = None,
region_name: Optional[str] = None,
session_type: SessionType = SessionType.AGENT,
**kwargs: Any,
):
"""Initialize S3SessionManager with S3 storage.
Expand All @@ -58,6 +62,7 @@ def __init__(
boto_session: Optional boto3 session
boto_client_config: Optional boto3 client configuration
region_name: AWS region for S3 storage
session_type: single agent or multiagent.
**kwargs: Additional keyword arguments for future extensibility.
"""
self.bucket = bucket
Expand All @@ -78,7 +83,7 @@ def __init__(
client_config = BotocoreConfig(user_agent_extra="strands-agents")

self.client = session.client(service_name="s3", config=client_config)
super().__init__(session_id=session_id, session_repository=self)
super().__init__(session_id=session_id, session_type=session_type, session_repository=self)

def _get_session_path(self, session_id: str) -> str:
"""Get session S3 prefix.
Expand Down Expand Up @@ -294,3 +299,24 @@ def list_messages(

except ClientError as e:
raise SessionException(f"S3 error reading messages: {e}") from e

def write_multi_agent_json(self, source: "MultiAgentBase") -> None:
"""Write multi-agent state to S3.

Args:
source: Multi-agent source object to persist
"""
session_prefix = self._get_session_path(self.session_id)
state_key = f"{session_prefix}multi_agent_state.json"
state = source.serialize_state()
self._write_s3_object(state_key, state)

def read_multi_agent_json(self) -> dict[str, Any]:
"""Read multi-agent state from S3.

Returns:
Multi-agent state dictionary or empty dict if not found
"""
session_prefix = self._get_session_path(self.session_id)
state_key = f"{session_prefix}multi_agent_state.json"
return self._read_s3_object(state_key) or {}
84 changes: 76 additions & 8 deletions src/strands/session/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
"""Session manager interface for agent session management."""

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from ..experimental.hooks.multiagent.events import (
AfterMultiAgentInvocationEvent,
AfterNodeCallEvent,
MultiAgentInitializedEvent,
)
from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
from ..hooks.registry import HookProvider, HookRegistry
from ..types.content import Message
from ..types.session import SessionType

if TYPE_CHECKING:
from ..agent.agent import Agent
from ..multiagent.base import MultiAgentBase

logger = logging.getLogger(__name__)


class SessionManager(HookProvider, ABC):
Expand All @@ -20,19 +30,39 @@ class SessionManager(HookProvider, ABC):
for an agent, and should be persisted in the session.
"""

def __init__(self, session_type: SessionType = SessionType.AGENT) -> None:
"""Initialize SessionManager with session type.

Args:
session_type: Type of session (AGENT or MULTI_AGENT)
"""
self.session_type: SessionType = session_type

def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
"""Register hooks for persisting the agent to the session."""
# After the normal Agent initialization behavior, call the session initialize function to restore the agent
registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent))
if not hasattr(self, "session_type"):
self.session_type = SessionType.AGENT
logger.debug("Session type not set, defaulting to AGENT")

# For each message appended to the Agents messages, store that message in the session
registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent))
if self.session_type == SessionType.MULTI_AGENT:
registry.add_callback(MultiAgentInitializedEvent, self._on_multiagent_initialized)
registry.add_callback(AfterNodeCallEvent, lambda event: self.write_multi_agent_json(event.source))
registry.add_callback(
AfterMultiAgentInvocationEvent, lambda event: self.write_multi_agent_json(event.source)
)

# Sync the agent into the session for each message in case the agent state was updated
registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent))
else:
# After the normal Agent initialization behavior, call the session initialize function to restore the agent
registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent))

# After an agent was invoked, sync it with the session to capture any conversation manager state updates
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))
# For each message appended to the Agents messages, store that message in the session
registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent))

# Sync the agent into the session for each message in case the agent state was updated
registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent))

# After an agent was invoked, sync it with the session to capture any conversation manager state updates
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))

@abstractmethod
def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None:
Expand Down Expand Up @@ -71,3 +101,41 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
agent: Agent to initialize
**kwargs: Additional keyword arguments for future extensibility.
"""

def write_multi_agent_json(self, source: "MultiAgentBase") -> None:
"""Write multi-agent state to persistent storage.

Args:
source: Multi-agent source object to persist
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support multi-agent persistence "
"(write_multi_agent_json). Provide an implementation or use a "
"SessionManager with session_type=SessionType.MULTI_AGENT."
)

def read_multi_agent_json(self) -> dict[str, Any]:
"""Read multi-agent state from persistent storage.

Returns:
Multi-agent state dictionary or empty dict if not found
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support multi-agent persistence "
"(read_multi_agent_json). Provide an implementation or use a "
"SessionManager with session_type=SessionType.MULTI_AGENT."
)

def _on_multiagent_initialized(self, event: MultiAgentInitializedEvent) -> None:
"""Handle multi-agent initialization: restore from storage or create initial snapshot.

If existing state is found, deserializes it into the source. Otherwise,
persists the current state as the initial snapshot.
"""
source: MultiAgentBase = event.source
payload = self.read_multi_agent_json()
# payload can be {} or Graph/Swarm state json
if payload:
source.restore_from_session(payload)
else:
self.write_multi_agent_json(source)
1 change: 1 addition & 0 deletions src/strands/types/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class SessionType(str, Enum):
"""

AGENT = "AGENT"
MULTI_AGENT = "MULTI_AGENT"


def encode_bytes_values(obj: Any) -> Any:
Expand Down
Loading
Loading