diff --git a/.gitignore b/.gitignore index bd7c9a4..1a431b0 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,6 @@ venv notifications repl_state slack_events -tasks \ No newline at end of file +tasks +.kiro +.strands* \ No newline at end of file diff --git a/README.md b/README.md index 19174d7..d86784d 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ strands --kb YOUR_KB_ID "Load my previous calculator tool and enhance it with sc - 🪄 Nested agent capabilities with tool delegation - 🔧 Dynamic tool loading for extending functionality - 🖥️ Environment variable management and customization +- 💾 **Session Management**: Optional state persistence and conversation resumption ## Integrated Tools @@ -98,6 +99,35 @@ Strands comes with a comprehensive set of built-in tools: - **welcome**: Manage the Strands Agent Builder welcome text - **workflow**: Orchestrate sequenced workflows +## Session Management + +Strands Agent Builder includes optional session management that automatically saves your conversation state and allows you to resume conversations exactly where you left off. + +```bash +# Enable session management with a custom path +strands --session-path /path/to/sessions "Create a complex agent" +# Output: Created new session: strands-1234567890-abc123 + +# Or use environment variable +export STRANDS_SESSION_PATH=/path/to/sessions +strands "Create a complex agent" + +# Resume any session exactly where you left off +strands --session-path /path/to/sessions --session-id strands-1234567890-abc123 + +# List all your saved sessions +strands --session-path /path/to/sessions --list-sessions + +# Interactive session commands (when session management is enabled) +!session info # Show current session details +!session list # List all sessions +``` + +**Configuration:** +- Use `--session-path ` to specify where sessions are stored +- Or set `STRANDS_SESSION_PATH` environment variable +- Sessions are only enabled when a path is provided + ## Knowledge Base Integration Strands Agent Builder leverages Amazon Bedrock Knowledge Bases to store and retrieve custom tools, agent configurations, and development history. diff --git a/src/strands_agents_builder/strands.py b/src/strands_agents_builder/strands.py index ab64c2e..c02970b 100644 --- a/src/strands_agents_builder/strands.py +++ b/src/strands_agents_builder/strands.py @@ -5,6 +5,7 @@ import argparse import os +from typing import Optional # Strands from strands import Agent @@ -14,11 +15,115 @@ from strands_agents_builder.tools import get_tools from strands_agents_builder.utils import model_utils from strands_agents_builder.utils.kb_utils import load_system_prompt, store_conversation_in_kb +from strands_agents_builder.utils.session_utils import ( + console, + display_agent_history, + handle_session_commands, + list_sessions_command, + setup_session_management, +) from strands_agents_builder.utils.welcome_utils import render_goodbye_message, render_welcome_message os.environ["STRANDS_TOOL_CONSOLE_MODE"] = "enabled" +def execute_command_mode(agent: Agent, query: str, knowledge_base_id: Optional[str]) -> None: + """Execute a single query in command mode.""" + # Use retrieve if knowledge_base_id is defined + if knowledge_base_id: + agent.tool.retrieve(text=query, knowledgeBaseId=knowledge_base_id) + + agent(query) + + if knowledge_base_id: + # Store conversation in knowledge base + store_conversation_in_kb(agent, query, knowledge_base_id) + + +def handle_shell_command(agent: Agent, command: str, user_input: str) -> None: + """Handle shell command execution.""" + print(f"$ {command}") + try: + # Execute shell command directly using the shell tool + agent.tool.shell( + command=command, + user_message_override=user_input, + non_interactive_mode=True, + ) + except Exception as e: + console.print(f"[red]Error: {str(e)}[/red]") + + +def execute_interactive_mode( + agent: Agent, + knowledge_base_id: Optional[str], + session_id: Optional[str], + session_base_path: Optional[str], + session_manager=None, + is_resuming: bool = False, +) -> None: + """Execute interactive mode with conversation loop.""" + # Display welcome text at startup + welcome_result = agent.tool.welcome(action="view", record_direct_tool_call=False) + welcome_text = "" + if welcome_result["status"] == "success": + welcome_text = welcome_result["content"][0]["text"] + render_welcome_message(welcome_text) + + # Display session history if resuming (after welcome message) + if is_resuming and session_manager and session_id: + display_agent_history(agent, session_id) + + while True: + try: + user_input = get_user_input("\n~ ", default="", keyboard_interrupt_return_default=False) + + if user_input.lower() in ["exit", "quit"]: + render_goodbye_message() + break + + if user_input.startswith("!"): + command = user_input[1:].strip() # Remove the ! prefix + + # Handle session management commands + if handle_session_commands(command, session_id, session_base_path): + continue + + # Handle regular shell commands + handle_shell_command(agent, command, user_input) + continue + + if user_input.strip(): + # Use retrieve if knowledge_base_id is defined + if knowledge_base_id: + agent.tool.retrieve(text=user_input, knowledgeBaseId=knowledge_base_id) + + # Read welcome text and add it to the system prompt + welcome_result = agent.tool.welcome(action="view", record_direct_tool_call=False) + base_system_prompt = load_system_prompt() + welcome_text = "" + + if welcome_result["status"] == "success": + # Combine welcome text with base system prompt + welcome_text = welcome_result["content"][0]["text"] + agent.system_prompt = f"{base_system_prompt}\n\nWelcome Text Reference:\n{welcome_text}" + else: + agent.system_prompt = base_system_prompt + + response = agent(user_input) + + if knowledge_base_id: + # Store conversation in knowledge base + store_conversation_in_kb(agent, user_input, response, knowledge_base_id) + + except (KeyboardInterrupt, EOFError): + render_goodbye_message() + break + except Exception as e: + callback_handler(force_stop=True) # Stop spinners + console.print(f"[red]Error: {str(e)}[/red]") + + def main(): # Parse command line arguments parser = argparse.ArgumentParser(description="Strands - A minimal CLI interface for Strands") @@ -41,95 +146,57 @@ def main(): default="{}", help="Model config as JSON string or path", ) + # Session management arguments + parser.add_argument( + "--session-id", + help="Session ID to use or resume", + ) + parser.add_argument( + "--session-path", + help="Base path for session storage (enables session management)", + ) + parser.add_argument( + "--list-sessions", + action="store_true", + help="List all available sessions and exit", + ) args = parser.parse_args() + # Get session base path from args or environment variable + session_base_path = args.session_path or os.getenv("STRANDS_SESSION_PATH") + + # Handle session listing command + if args.list_sessions: + list_sessions_command(session_base_path) + return + # Get knowledge_base_id from args or environment variable knowledge_base_id = args.knowledge_base_id or os.getenv("STRANDS_KNOWLEDGE_BASE_ID") + # Load model and tools model = model_utils.load_model(args.model_provider, args.model_config) - - # Load system prompt system_prompt = load_system_prompt() - tools = get_tools().values() + # Set up session management + session_manager, session_id, is_resuming = setup_session_management(args.session_id, session_base_path) + + # Create agent agent = Agent( model=model, tools=tools, system_prompt=system_prompt, callback_handler=callback_handler, load_tools_from_directory=True, + session_manager=session_manager, ) - # Process query or enter interactive mode + # Execute command mode or interactive mode if args.query: query = " ".join(args.query) - # Use retrieve if knowledge_base_id is defined - if knowledge_base_id: - agent.tool.retrieve(text=query, knowledgeBaseId=knowledge_base_id) - - agent(query) - - if knowledge_base_id: - # Store conversation in knowledge base - store_conversation_in_kb(agent, query, knowledge_base_id) + execute_command_mode(agent, query, knowledge_base_id) else: - # Display welcome text at startup - welcome_result = agent.tool.welcome(action="view", record_direct_tool_call=False) - welcome_text = "" - if welcome_result["status"] == "success": - welcome_text = welcome_result["content"][0]["text"] - render_welcome_message(welcome_text) - while True: - try: - user_input = get_user_input("\n~ ", default="", keyboard_interrupt_return_default=False) - if user_input.lower() in ["exit", "quit"]: - render_goodbye_message() - break - if user_input.startswith("!"): - shell_command = user_input[1:] # Remove the ! prefix - print(f"$ {shell_command}") - - try: - # Execute shell command directly using the shell tool - agent.tool.shell( - command=shell_command, - user_message_override=user_input, - non_interactive_mode=True, - ) - - print() # new line after shell command execution - except Exception as e: - print(f"Shell command execution error: {str(e)}") - continue - - if user_input.strip(): - # Use retrieve if knowledge_base_id is defined - if knowledge_base_id: - agent.tool.retrieve(text=user_input, knowledgeBaseId=knowledge_base_id) - # Read welcome text and add it to the system prompt - welcome_result = agent.tool.welcome(action="view", record_direct_tool_call=False) - base_system_prompt = load_system_prompt() - welcome_text = "" - - if welcome_result["status"] == "success": - # Combine welcome text with base system prompt - welcome_text = welcome_result["content"][0]["text"] - agent.system_prompt = f"{base_system_prompt}\n\nWelcome Text Reference:\n{welcome_text}" - else: - agent.system_prompt = base_system_prompt - - response = agent(user_input) - - if knowledge_base_id: - # Store conversation in knowledge base - store_conversation_in_kb(agent, user_input, response, knowledge_base_id) - except (KeyboardInterrupt, EOFError): - render_goodbye_message() - break - except Exception as e: - callback_handler(force_stop=True) # Stop spinners - print(f"\nError: {str(e)}") + execute_interactive_mode(agent, knowledge_base_id, session_id, session_base_path, session_manager, is_resuming) if __name__ == "__main__": diff --git a/src/strands_agents_builder/tools.py b/src/strands_agents_builder/tools.py index 52e783e..2548ecd 100644 --- a/src/strands_agents_builder/tools.py +++ b/src/strands_agents_builder/tools.py @@ -16,6 +16,7 @@ image_reader, journal, load_tool, + mcp_client, memory, nova_reels, retrieve, @@ -59,6 +60,7 @@ def get_tools() -> dict[str, Any]: "image_reader": image_reader, "journal": journal, "load_tool": load_tool, + "mcp_client": mcp_client, "memory": memory, "nova_reels": nova_reels, "retrieve": retrieve, diff --git a/src/strands_agents_builder/utils/session_utils.py b/src/strands_agents_builder/utils/session_utils.py new file mode 100644 index 0000000..875fcfa --- /dev/null +++ b/src/strands_agents_builder/utils/session_utils.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 +""" +Session management utilities for Strands Agent Builder. +""" + +import datetime +import logging +import time +import uuid +from pathlib import Path +from typing import Optional, Tuple + +from colorama import Fore, Style +from rich.align import Align +from rich.box import ROUNDED +from rich.console import Console +from rich.panel import Panel +from strands.session.file_session_manager import FileSessionManager + +# Create console for rich formatting +console = Console() + +# Constants +SESSION_PREFIX = "session_" +DEFAULT_DISPLAY_LIMIT = 10 + +# Set up logging +logger = logging.getLogger(__name__) + + +def validate_session_id(session_id: str) -> bool: + """Validate that a session ID is safe to use as a directory name.""" + if not session_id: + return False + + # Check for basic safety - no path separators, no hidden files, reasonable length + if any(char in session_id for char in ["/", "\\", "..", "\0"]): + return False + + if session_id.startswith(".") or len(session_id) > 255: + return False + + return True + + +def validate_session_path(path: str) -> bool: + """Validate that a session path is safe to use.""" + if not path: + return False + + try: + # Try to create a Path object and check if it's absolute or relative + path_obj = Path(path) + # Basic validation - path should be reasonable + return len(str(path_obj)) < 4096 # Reasonable path length limit + except (ValueError, OSError): + return False + + +def generate_session_id() -> str: + """Generate a unique session ID based on timestamp and UUID.""" + timestamp = int(time.time()) + short_uuid = str(uuid.uuid4())[:8] + return f"strands-{timestamp}-{short_uuid}" + + +def get_sessions_directory(base_path: Optional[str] = None, create: bool = False) -> Optional[Path]: + """Get the sessions directory path. Returns None if no base_path provided.""" + if not base_path: + return None + + sessions_dir = Path(base_path) + + if create: + sessions_dir.mkdir(parents=True, exist_ok=True) + + return sessions_dir + + +def create_session_manager( + session_id: Optional[str] = None, base_path: Optional[str] = None +) -> Optional[FileSessionManager]: + """Create a FileSessionManager with the given or generated session ID. Returns None if no base_path.""" + if not base_path or not validate_session_path(base_path): + return None + + if session_id is None: + session_id = generate_session_id() + elif not validate_session_id(session_id): + logger.warning(f"Invalid session ID provided: {session_id}") + return None + + # Create the sessions directory since we're actually creating a session manager + sessions_dir = get_sessions_directory(base_path, create=True) + return FileSessionManager(session_id=session_id, storage_dir=str(sessions_dir)) + + +def list_available_sessions(base_path: Optional[str] = None) -> list[str]: + """List all available session IDs in the sessions directory.""" + if not base_path or not validate_session_path(base_path): + return [] + + # Don't create directory, just check if it exists + sessions_dir = get_sessions_directory(base_path, create=False) + session_ids = [] + + if sessions_dir and sessions_dir.exists(): + try: + for session_dir in sessions_dir.iterdir(): + if session_dir.is_dir() and session_dir.name.startswith(SESSION_PREFIX): + # Extract session ID from directory name (remove "session_" prefix) + session_id = session_dir.name[len(SESSION_PREFIX) :] + if validate_session_id(session_id): + session_ids.append(session_id) + except (OSError, PermissionError) as e: + logger.warning(f"Failed to list sessions in {base_path}: {e}") + + return sorted(session_ids) + + +def session_exists(session_id: str, base_path: Optional[str] = None) -> bool: + """Check if a session exists.""" + if not base_path or not validate_session_path(base_path) or not validate_session_id(session_id): + return False + + # Don't create directory, just check if session exists + sessions_dir = get_sessions_directory(base_path, create=False) + if not sessions_dir: + return False + + session_dir = sessions_dir / f"{SESSION_PREFIX}{session_id}" + return session_dir.exists() and (session_dir / "session.json").exists() + + +def get_session_info(session_id: str, base_path: Optional[str] = None) -> Optional[dict]: + """Get basic information about a session.""" + if not base_path or not validate_session_path(base_path) or not validate_session_id(session_id): + return None + + if not session_exists(session_id, base_path): + return None + + # Don't create directory, just get the path + sessions_dir = get_sessions_directory(base_path, create=False) + if not sessions_dir: + return None + + session_dir = sessions_dir / f"{SESSION_PREFIX}{session_id}" + + try: + # Get creation time from directory + created_at = session_dir.stat().st_ctime + + # Count messages across all agents + total_messages = 0 + agents_dir = session_dir / "agents" + if agents_dir.exists(): + for agent_dir in agents_dir.iterdir(): + if agent_dir.is_dir(): + messages_dir = agent_dir / "messages" + if messages_dir.exists(): + total_messages += len( + [f for f in messages_dir.iterdir() if f.is_file() and f.suffix == ".json"] + ) + + return { + "session_id": session_id, + "created_at": created_at, + "total_messages": total_messages, + "path": str(session_dir), + } + except (OSError, PermissionError) as e: + logger.warning(f"Failed to get session info for {session_id}: {e}") + return None + + +def list_sessions_command(session_base_path: Optional[str]) -> None: + """Handle the --list-sessions command.""" + if not session_base_path: + console.print( + "[red]Error: Session management not enabled. Use --session-path or " + "set STRANDS_SESSION_PATH environment variable.[/red]" + ) + return + + sessions = list_available_sessions(session_base_path) + if not sessions: + console.print("[yellow]No sessions found.[/yellow]") + else: + console.print("[bold cyan]Available sessions:[/bold cyan]") + for session_id in sessions: + info = get_session_info(session_id, session_base_path) + if info: + created = datetime.datetime.fromtimestamp(info["created_at"]).strftime("%Y-%m-%d %H:%M:%S") + console.print(f" [green]{session_id}[/green] (created: {created}, messages: {info['total_messages']})") + + +def display_agent_history(agent, session_id: str) -> None: + """Display conversation history from an agent's loaded messages.""" + try: + if agent.messages and len(agent.messages) > 0: + # Display last messages completely + display_limit = DEFAULT_DISPLAY_LIMIT + + # Create header message + header_text = f"Resuming session: {session_id}" + + # Show indicator if there are more messages + if len(agent.messages) > display_limit: + hidden_count = len(agent.messages) - display_limit + subtitle_text = f"{hidden_count} previous messages not shown" + else: + subtitle_text = f"Showing all {len(agent.messages)} messages" + + # Display header with rich formatting + header_panel = Panel( + Align.center(f"[bold cyan]{header_text}[/bold cyan]"), + subtitle=f"[dim]{subtitle_text}[/dim]", + border_style="blue", + box=ROUNDED, + expand=False, + padding=(1, 3), + ) + console.print() # Empty line before + console.print(header_panel) + console.print() # Empty line after + + # Get the messages to display + recent_messages = agent.messages[-display_limit:] if len(agent.messages) > display_limit else agent.messages + + for msg in recent_messages: + role = msg.get("role", "unknown") + content_blocks = msg.get("content", []) + + # Extract text from content blocks + content = "" + for block in content_blocks: + if isinstance(block, dict) and "text" in block: + content += block["text"] + + if role == "user": + print(f"{Fore.GREEN}~ {Style.RESET_ALL}{content}") + print() # Empty line after user message + elif role == "assistant": + print(f"{Fore.WHITE}{content}{Style.RESET_ALL}") + print() # Empty line after assistant message + + except Exception as e: + # If we can't load history, log the error but continue + logger.warning(f"Failed to display agent history for session {session_id}: {e}") + console.print("[yellow]Warning: Could not load session history[/yellow]") + + +def setup_session_management( + session_id: Optional[str], session_base_path: Optional[str] +) -> Tuple[Optional[FileSessionManager], Optional[str], bool]: + """Set up session management if enabled. Returns (session_manager, session_id, is_resuming).""" + session_manager = None + resolved_session_id = None + is_resuming = False + + if session_base_path: + # Check if resuming existing session + if session_id: + if session_exists(session_id, session_base_path): + resolved_session_id = session_id + is_resuming = True + else: + resolved_session_id = session_id + + # Create session manager + session_manager = create_session_manager(resolved_session_id, session_base_path) + if session_manager: + resolved_session_id = session_manager.session_id + + return session_manager, resolved_session_id, is_resuming + + +def handle_session_commands(command: str, session_id: Optional[str], session_base_path: Optional[str]) -> bool: + """Handle session-related interactive commands. Returns True if command was handled.""" + if command == "session info" and session_id: + info = get_session_info(session_id, session_base_path) + if info: + created = datetime.datetime.fromtimestamp(info["created_at"]).strftime("%Y-%m-%d %H:%M:%S") + console.print(f"[bold cyan]Session ID:[/bold cyan] {info['session_id']}") + console.print(f"[bold cyan]Created:[/bold cyan] {created}") + console.print(f"[bold cyan]Total messages:[/bold cyan] {info['total_messages']}") + return True + + elif command == "session list" and session_base_path: + sessions = list_available_sessions(session_base_path) + if not sessions: + console.print("[yellow]No sessions found.[/yellow]") + else: + console.print("[bold cyan]Available sessions:[/bold cyan]") + for sid in sessions: + info = get_session_info(sid, session_base_path) + if info: + created = datetime.datetime.fromtimestamp(info["created_at"]).strftime("%Y-%m-%d %H:%M:%S") + current = " [dim](current)[/dim]" if sid == session_id else "" + console.print( + f" [green]{sid}[/green] (created: {created}, messages: {info['total_messages']}){current}" + ) + return True + + elif command.startswith("session "): + if session_base_path: + console.print("[bold cyan]Available session commands:[/bold cyan]") + console.print(" [green]!session info[/green] - Show current session details") + console.print(" [green]!session list[/green] - List all available sessions") + else: + console.print("[red]Error: Session management not enabled.[/red]") + return True + + return False diff --git a/tests/conftest.py b/tests/conftest.py index 5075399..cc3f26b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -105,3 +105,39 @@ def _create_file(content, name=".prompt"): return file_path return _create_file + + +@pytest.fixture +def mock_session_manager(): + """ + Fixture to mock the session manager + """ + with mock.patch("strands_agents_builder.utils.session_utils.FileSessionManager") as mock_manager_class: + mock_manager_instance = mock.MagicMock() + mock_manager_class.return_value = mock_manager_instance + + # Set up default return values + mock_manager_instance.list_sessions.return_value = [] + mock_manager_instance.save_session.return_value = True + mock_manager_instance.load_session.return_value = {"messages": []} + + yield mock_manager_instance + + +@pytest.fixture +def temp_session_dir(tmp_path): + """ + Fixture to create a temporary session directory + """ + session_dir = tmp_path / "test_sessions" + session_dir.mkdir() + return session_dir + + +@pytest.fixture(autouse=True) +def suppress_console_output(): + """ + Fixture to suppress console output during tests + """ + # No longer needed since we removed console utils abstraction + yield diff --git a/tests/test_strands.py b/tests/test_strands.py index 8402303..eb5fa90 100644 --- a/tests/test_strands.py +++ b/tests/test_strands.py @@ -10,6 +10,11 @@ import pytest from strands_agents_builder import strands +from strands_agents_builder.utils.session_utils import ( + handle_session_commands, + list_sessions_command, + setup_session_management, +) class TestInteractiveMode: @@ -167,11 +172,11 @@ def test_eof_error_exception(self, mock_goodbye, mock_agent, mock_input): # Verify goodbye message was called mock_goodbye.assert_called_once() + @mock.patch("strands_agents_builder.utils.session_utils.console.print") @mock.patch.object(strands, "get_user_input") @mock.patch.object(strands, "Agent") - @mock.patch.object(strands, "print") @mock.patch.object(strands, "callback_handler") - def test_general_exception_handling(self, mock_callback_handler, mock_print, mock_agent, mock_input): + def test_general_exception_handling(self, mock_callback_handler, mock_agent, mock_input, mock_console_print): """Test handling of general exceptions in interactive mode""" # Setup mocks mock_agent_instance = mock.MagicMock() @@ -188,8 +193,8 @@ def test_general_exception_handling(self, mock_callback_handler, mock_print, moc with mock.patch.object(sys, "argv", ["strands"]), mock.patch.object(strands, "render_goodbye_message"): strands.main() - # Verify error was printed - mock_print.assert_any_call("\nError: Test error") + # Verify error was called + mock_console_print.assert_any_call("[red]Error: Test error[/red]") # Verify callback_handler was called to stop spinners mock_callback_handler.assert_called_once_with(force_stop=True) @@ -309,9 +314,16 @@ def test_general_exception(self, mock_agent, mock_bedrock, mock_load_prompt, mon class TestShellCommandError: """Test shell command error handling""" - @mock.patch("builtins.print") + @mock.patch("strands_agents_builder.utils.session_utils.console.print") def test_shell_command_exception( - self, mock_print, mock_agent, mock_bedrock, mock_load_prompt, mock_user_input, mock_welcome_message, monkeypatch + self, + mock_console_print, + mock_agent, + mock_bedrock, + mock_load_prompt, + mock_user_input, + mock_welcome_message, + monkeypatch, ): """Test handling exceptions when executing shell commands""" # Setup mocks @@ -328,8 +340,8 @@ def test_shell_command_exception( with mock.patch.object(strands, "render_goodbye_message"): strands.main() - # Verify error was printed - mock_print.assert_any_call("Shell command execution error: Shell command failed") + # Verify error was called + mock_console_print.assert_any_call("[red]Error: Shell command failed[/red]") class TestKnowledgeBaseIntegration: @@ -428,3 +440,330 @@ def test_welcome_message_failure( # Verify agent was called with system prompt that excludes welcome text reference assert mock_agent.system_prompt == base_system_prompt + + +class TestSessionManagement: + """Test cases for session management functionality""" + + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + @mock.patch("strands_agents_builder.utils.session_utils.list_available_sessions") + def test_list_sessions_command_no_sessions(self, mock_list_sessions, mock_console_print): + """Test list-sessions command when no sessions exist""" + # Setup mocks + mock_list_sessions.return_value = [] + + # Mock sys.argv with session path + with mock.patch.object(sys, "argv", ["strands", "--list-sessions", "--session-path", "/tmp/sessions"]): + strands.main() + + # Verify appropriate message was called + mock_console_print.assert_any_call("[yellow]No sessions found.[/yellow]") + + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + @mock.patch("strands_agents_builder.utils.session_utils.get_session_info") + @mock.patch("strands_agents_builder.utils.session_utils.list_available_sessions") + def test_list_sessions_command_with_sessions(self, mock_list_sessions, mock_get_info, mock_console_print): + """Test list-sessions command when sessions exist""" + # Setup mocks + mock_list_sessions.return_value = ["session1", "session2"] + mock_get_info.side_effect = [ + {"session_id": "session1", "created_at": 1234567890, "total_messages": 5}, + {"session_id": "session2", "created_at": 1234567891, "total_messages": 3}, + ] + + # Mock sys.argv with session path + with mock.patch.object(sys, "argv", ["strands", "--list-sessions", "--session-path", "/tmp/sessions"]): + strands.main() + + # Verify sessions were listed + mock_console_print.assert_any_call("[bold cyan]Available sessions:[/bold cyan]") + # Check that session info was called for each session + mock_get_info.assert_any_call("session1", "/tmp/sessions") + mock_get_info.assert_any_call("session2", "/tmp/sessions") + + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + def test_list_sessions_command_no_base_path(self, mock_console_print): + """Test list-sessions command when no session path is configured""" + # Mock sys.argv without session path + with mock.patch.object(sys, "argv", ["strands", "--list-sessions"]): + strands.main() + + # Verify appropriate error message was called + mock_console_print.assert_called_with( + "[red]Error: Session management not enabled. Use --session-path or " + "set STRANDS_SESSION_PATH environment variable.[/red]" + ) + + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + def test_session_management_setup_with_path(self, mock_create_manager, mock_agent, mock_bedrock, mock_load_prompt): + """Test session management setup when session path is provided""" + # Setup mocks + mock_manager = mock.MagicMock() + mock_manager.session_id = "test-session-123" + mock_create_manager.return_value = mock_manager + + # Mock sys.argv with session path + with mock.patch.object(sys, "argv", ["strands", "--session-path", "/tmp/test_sessions", "test", "query"]): + strands.main() + + # Verify session manager was created + mock_create_manager.assert_called_once_with(None, "/tmp/test_sessions") + + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + def test_session_management_setup_with_env_var( + self, mock_create_manager, mock_agent, mock_bedrock, mock_load_prompt + ): + """Test session management setup when environment variable is set""" + # Setup mocks + mock_manager = mock.MagicMock() + mock_manager.session_id = "test-session-456" + mock_create_manager.return_value = mock_manager + + # Mock environment variable and sys.argv + with mock.patch.dict(os.environ, {"STRANDS_SESSION_PATH": "/tmp/env_sessions"}): + with mock.patch.object(sys, "argv", ["strands", "test", "query"]): + strands.main() + + # Verify session manager was created + mock_create_manager.assert_called_once_with(None, "/tmp/env_sessions") + + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + def test_session_management_no_setup_when_no_path( + self, mock_create_manager, mock_agent, mock_bedrock, mock_load_prompt + ): + """Test that session management is not set up when no path is provided""" + # Mock sys.argv without session path + with mock.patch.object(sys, "argv", ["strands", "test", "query"]): + strands.main() + + # Verify session manager was not created + mock_create_manager.assert_not_called() + + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + def test_agent_creation_with_session_manager(self, mock_create_manager, mock_bedrock, mock_load_prompt): + """Test that agent is created with session manager when available""" + # Setup mocks + mock_manager = mock.MagicMock() + mock_manager.session_id = "test-session-789" + mock_create_manager.return_value = mock_manager + + with mock.patch.object(strands, "Agent") as mock_agent_class: + mock_agent_instance = mock.MagicMock() + mock_agent_class.return_value = mock_agent_instance + + # Mock sys.argv with session path + with mock.patch.object(sys, "argv", ["strands", "--session-path", "/tmp/test_sessions", "test", "query"]): + strands.main() + + # Verify agent was created with session manager + mock_agent_class.assert_called_once() + call_kwargs = mock_agent_class.call_args[1] + assert "session_manager" in call_kwargs + assert call_kwargs["session_manager"] == mock_manager + + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + def test_agent_creation_without_session_manager(self, mock_create_manager, mock_bedrock, mock_load_prompt): + """Test that agent is created without session manager when not available""" + # Setup mocks - no session manager created + mock_create_manager.return_value = None + + with mock.patch.object(strands, "Agent") as mock_agent_class: + mock_agent_instance = mock.MagicMock() + mock_agent_class.return_value = mock_agent_instance + + # Mock sys.argv without session path + with mock.patch.object(sys, "argv", ["strands", "test", "query"]): + strands.main() + + # Verify agent was created without session manager + mock_agent_class.assert_called_once() + call_kwargs = mock_agent_class.call_args[1] + assert "session_manager" not in call_kwargs or call_kwargs.get("session_manager") is None + + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + @mock.patch("strands_agents_builder.utils.session_utils.get_session_info") + def test_session_commands_in_interactive_mode( + self, + mock_get_info, + mock_create_manager, + mock_console_print, + mock_agent, + mock_bedrock, + mock_load_prompt, + mock_user_input, + mock_welcome_message, + mock_goodbye_message, + ): + """Test session-related commands in interactive mode""" + # Setup mocks for session commands + mock_user_input.side_effect = ["!session info", "exit"] + + # Mock session manager and info + mock_manager = mock.MagicMock() + mock_manager.session_id = "test-session-123" + mock_create_manager.return_value = mock_manager + mock_get_info.return_value = { + "session_id": "test-session-123", + "created_at": 1234567890, + "total_messages": 5, + "path": "/tmp/sessions/session_test-session-123", + } + + # Run with session path + with mock.patch.object(sys, "argv", ["strands", "--session-path", "/tmp/sessions"]): + strands.main() + + # Verify session info was retrieved + mock_get_info.assert_called_once_with("test-session-123", "/tmp/sessions") + + @mock.patch("strands_agents_builder.utils.session_utils.session_exists") + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + def test_resume_session_command(self, mock_create_manager, mock_session_exists, mock_bedrock, mock_load_prompt): + """Test --session-id command line argument for resuming sessions""" + # Setup mocks + mock_session_exists.return_value = True + mock_manager = mock.MagicMock() + mock_manager.session_id = "test_session" + mock_create_manager.return_value = mock_manager + + with mock.patch.object(strands, "Agent") as mock_agent_class: + mock_agent_instance = mock.MagicMock() + mock_agent_class.return_value = mock_agent_instance + + # Mock sys.argv with session ID + with mock.patch.object( + sys, + "argv", + ["strands", "--session-path", "/tmp/sessions", "--session-id", "test_session", "new", "query"], + ): + strands.main() + + # Verify session existence was checked + mock_session_exists.assert_called_once_with("test_session", "/tmp/sessions") + + # Verify session manager was created with the specified ID + mock_create_manager.assert_called_once_with("test_session", "/tmp/sessions") + + # Session resuming is now silent, no message printed + + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + def test_session_path_argument_priority(self, mock_create_manager, mock_agent, mock_bedrock, mock_load_prompt): + """Test that --session-path argument takes priority over environment variable""" + # Setup environment variable + with mock.patch.dict(os.environ, {"STRANDS_SESSION_PATH": "/tmp/env_sessions"}): + # Mock sys.argv with different session path + with mock.patch.object(sys, "argv", ["strands", "--session-path", "/tmp/arg_sessions", "test", "query"]): + strands.main() + + # Verify command line argument was used, not environment variable + mock_create_manager.assert_called_once_with(None, "/tmp/arg_sessions") + + +class TestHelperFunctions: + """Test cases for helper functions extracted during refactoring""" + + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + @mock.patch("strands_agents_builder.utils.session_utils.get_session_info") + @mock.patch("strands_agents_builder.utils.session_utils.list_available_sessions") + def test_list_sessions_command_function(self, mock_list_sessions, mock_get_info, mock_console_print): + """Test the list_sessions_command function directly""" + # Test with sessions available + mock_list_sessions.return_value = ["session1", "session2"] + mock_get_info.side_effect = [ + {"session_id": "session1", "created_at": 1234567890, "total_messages": 5}, + {"session_id": "session2", "created_at": 1234567891, "total_messages": 3}, + ] + + list_sessions_command("/tmp/sessions") + + mock_console_print.assert_any_call("[bold cyan]Available sessions:[/bold cyan]") + # Verify get_session_info was called for each session + mock_get_info.assert_any_call("session1", "/tmp/sessions") + mock_get_info.assert_any_call("session2", "/tmp/sessions") + + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + @mock.patch("strands_agents_builder.utils.session_utils.session_exists") + def test_setup_session_management_function(self, mock_session_exists, mock_create_manager): + """Test the setup_session_management function directly""" + + # Test creating new session (no session_id provided) + mock_manager = mock.MagicMock() + mock_manager.session_id = "generated-session-123" + mock_create_manager.return_value = mock_manager + + result_manager, result_id, is_resuming = setup_session_management(None, "/tmp/test_sessions") + + assert result_manager == mock_manager + assert result_id == "generated-session-123" + assert is_resuming is False + mock_create_manager.assert_called_with(None, "/tmp/test_sessions") + + # Test resuming existing session + mock_session_exists.return_value = True + mock_manager.session_id = "existing-session" + + result_manager, result_id, is_resuming = setup_session_management("existing-session", "/tmp/test_sessions") + + assert result_manager == mock_manager + assert result_id == "existing-session" + assert is_resuming is True + + # Test creating session with provided ID (session doesn't exist) + mock_session_exists.return_value = False + + result_manager, result_id, is_resuming = setup_session_management("existing-session", "/tmp/test_sessions") + + assert result_manager == mock_manager + assert result_id == "existing-session" + assert is_resuming is False + + def test_execute_command_mode_function(self): + """Test the execute_command_mode function directly""" + # Create mock agent + mock_agent = mock.MagicMock() + + # Test command execution + strands.execute_command_mode(agent=mock_agent, query="test query", knowledge_base_id=None) + + # Verify agent was called with the query + mock_agent.assert_called_with("test query") + + @mock.patch("builtins.print") + @mock.patch("strands_agents_builder.utils.session_utils.get_session_info") + def test_handle_session_commands_function(self, mock_get_info, mock_print): + """Test the handle_session_commands function directly""" + + # Mock session info + mock_get_info.return_value = { + "session_id": "test-session", + "created_at": 1234567890, + "total_messages": 5, + "path": "/tmp/sessions/session_test-session", + } + + # Test session info command + result = handle_session_commands("session info", "test-session", "/tmp/sessions") + assert result is True + mock_get_info.assert_called_once_with("test-session", "/tmp/sessions") + + # Test non-session command + result = handle_session_commands("regular command", "test-session", "/tmp/sessions") + assert result is False + + @mock.patch("builtins.print") + def test_handle_shell_command_function(self, mock_print): + """Test the handle_shell_command function directly""" + # Create mock agent + mock_agent = mock.MagicMock() + + # Test shell command handling + strands.handle_shell_command(mock_agent, "ls -la", "!ls -la") + + # Verify shell command was executed + mock_agent.tool.shell.assert_called_with( + command="ls -la", user_message_override="!ls -la", non_interactive_mode=True + ) + + # Verify print was called with shell command + mock_print.assert_called_with("$ ls -la") diff --git a/tests/utils/test_session_utils.py b/tests/utils/test_session_utils.py new file mode 100644 index 0000000..9194793 --- /dev/null +++ b/tests/utils/test_session_utils.py @@ -0,0 +1,527 @@ +#!/usr/bin/env python3 +""" +Unit tests for the session_utils module using pytest +""" + +import tempfile +from pathlib import Path +from unittest import mock + +from strands_agents_builder.utils.session_utils import ( + create_session_manager, + display_agent_history, + generate_session_id, + get_session_info, + get_sessions_directory, + handle_session_commands, + list_available_sessions, + list_sessions_command, + session_exists, + setup_session_management, + validate_session_id, + validate_session_path, +) + + +class TestGetSessionsDirectory: + """Test cases for get_sessions_directory function""" + + def test_returns_none_when_no_path_provided(self): + """Test that function returns None when no session path is provided""" + result = get_sessions_directory(None) + assert result is None + + def test_returns_none_when_empty_path_provided(self): + """Test that function returns None when empty session path is provided""" + result = get_sessions_directory("") + assert result is None + + def test_returns_path_object_when_valid_path_provided(self): + """Test that function returns Path object when valid path is provided""" + test_path = "/tmp/test_sessions" + result = get_sessions_directory(test_path) + assert isinstance(result, Path) + # Use Path comparison instead of string comparison for cross-platform compatibility + assert result == Path(test_path) + + def test_creates_directory_when_create_is_true(self): + """Test that function creates directory when create=True""" + with tempfile.TemporaryDirectory() as temp_dir: + test_path = str(Path(temp_dir) / "new_sessions") + + # Directory shouldn't exist initially + assert not Path(test_path).exists() + + result = get_sessions_directory(test_path, create=True) + + # Directory should be created + assert result is not None + assert result.exists() + assert result.is_dir() + + def test_does_not_create_directory_when_create_is_false(self): + """Test that function doesn't create directory when create=False""" + with tempfile.TemporaryDirectory() as temp_dir: + test_path = str(Path(temp_dir) / "new_sessions") + + # Directory shouldn't exist initially + assert not Path(test_path).exists() + + result = get_sessions_directory(test_path, create=False) + + # Directory should not be created, but Path object returned + assert result is not None + assert not result.exists() + + +class TestListAvailableSessions: + """Test cases for list_available_sessions function""" + + def test_returns_empty_list_when_no_base_path(self): + """Test that function returns empty list when no base path is provided""" + result = list_available_sessions(None) + assert result == [] + + def test_returns_empty_list_when_directory_does_not_exist(self): + """Test that function returns empty list when directory doesn't exist""" + with tempfile.TemporaryDirectory() as temp_dir: + non_existent_path = str(Path(temp_dir) / "non_existent") + result = list_available_sessions(non_existent_path) + assert result == [] + + def test_returns_empty_list_when_directory_is_empty(self): + """Test that function returns empty list when directory is empty""" + with tempfile.TemporaryDirectory() as temp_dir: + result = list_available_sessions(temp_dir) + assert result == [] + + def test_returns_session_ids_when_present(self): + """Test that function returns session IDs when session directories exist""" + with tempfile.TemporaryDirectory() as temp_dir: + base_path = temp_dir + + # Create some test session directories (as expected by the function) + (Path(temp_dir) / "session_test-123-abc").mkdir() + (Path(temp_dir) / "session_test-456-def").mkdir() + (Path(temp_dir) / "not_a_session").mkdir() # Should be ignored + + result = list_available_sessions(base_path) + + # Should only return session IDs (without "session_" prefix) + assert len(result) == 2 + assert "test-123-abc" in result + assert "test-456-def" in result + assert "not_a_session" not in result + + def test_sorts_session_ids_alphabetically(self): + """Test that function returns session IDs sorted alphabetically""" + with tempfile.TemporaryDirectory() as temp_dir: + base_path = temp_dir + + # Create session directories in non-alphabetical order + (Path(temp_dir) / "session_zebra-session").mkdir() + (Path(temp_dir) / "session_alpha-session").mkdir() + (Path(temp_dir) / "session_beta-session").mkdir() + + result = list_available_sessions(base_path) + + assert result == ["alpha-session", "beta-session", "zebra-session"] + + +class TestCreateSessionManager: + """Test cases for create_session_manager function""" + + def test_returns_none_when_no_base_path(self): + """Test that function returns None when no base path is provided""" + result = create_session_manager(base_path=None) + assert result is None + + @mock.patch("strands_agents_builder.utils.session_utils.FileSessionManager") + def test_creates_session_manager_when_base_path_provided(self, mock_file_session_manager): + """Test that function creates FileSessionManager when base path is provided""" + mock_manager_instance = mock.MagicMock() + mock_file_session_manager.return_value = mock_manager_instance + + with tempfile.TemporaryDirectory() as temp_dir: + result = create_session_manager(session_id="test-session", base_path=temp_dir) + + # Verify FileSessionManager was called with correct parameters + mock_file_session_manager.assert_called_once_with(session_id="test-session", storage_dir=temp_dir) + assert result == mock_manager_instance + + @mock.patch("strands_agents_builder.utils.session_utils.FileSessionManager") + def test_generates_session_id_when_none_provided(self, mock_file_session_manager): + """Test that function generates session ID when none is provided""" + mock_manager_instance = mock.MagicMock() + mock_file_session_manager.return_value = mock_manager_instance + + with tempfile.TemporaryDirectory() as temp_dir: + create_session_manager(base_path=temp_dir) + + # Verify FileSessionManager was called with a generated session ID + mock_file_session_manager.assert_called_once() + call_args = mock_file_session_manager.call_args[1] + assert "session_id" in call_args + assert call_args["session_id"].startswith("strands-") + assert call_args["storage_dir"] == temp_dir + + def test_creates_directory_when_it_does_not_exist(self): + """Test that function creates directory when it doesn't exist""" + with tempfile.TemporaryDirectory() as temp_dir: + new_sessions_dir = str(Path(temp_dir) / "new_sessions_dir") + + # Ensure directory doesn't exist initially + assert not Path(new_sessions_dir).exists() + + with mock.patch("strands_agents_builder.utils.session_utils.FileSessionManager"): + create_session_manager(session_id="test", base_path=new_sessions_dir) + + # Verify directory was created + assert Path(new_sessions_dir).exists() + assert Path(new_sessions_dir).is_dir() + + +class TestGenerateSessionId: + """Test cases for generate_session_id function""" + + def test_generates_unique_session_ids(self): + """Test that function generates unique session IDs""" + id1 = generate_session_id() + id2 = generate_session_id() + + assert id1 != id2 + assert id1.startswith("strands-") + assert id2.startswith("strands-") + + def test_session_id_format(self): + """Test that session ID has expected format""" + session_id = generate_session_id() + + # Should be in format: strands-{timestamp}-{uuid} + parts = session_id.split("-") + assert len(parts) == 3 + assert parts[0] == "strands" + assert parts[1].isdigit() # timestamp + assert len(parts[2]) == 8 # short UUID + + +class TestSessionExists: + """Test cases for session_exists function""" + + def test_returns_false_when_no_base_path(self): + """Test that function returns False when no base path is provided""" + result = session_exists("test-session", None) + assert result is False + + def test_returns_false_when_session_does_not_exist(self): + """Test that function returns False when session doesn't exist""" + with tempfile.TemporaryDirectory() as temp_dir: + result = session_exists("non-existent", temp_dir) + assert result is False + + def test_returns_true_when_session_exists(self): + """Test that function returns True when session exists""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create session directory and file + session_dir = Path(temp_dir) / "session_test-session" + session_dir.mkdir() + (session_dir / "session.json").touch() + + result = session_exists("test-session", temp_dir) + assert result is True + + def test_returns_false_when_directory_exists_but_no_session_file(self): + """Test that function returns False when directory exists but no session.json""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create session directory but no session.json + session_dir = Path(temp_dir) / "session_test-session" + session_dir.mkdir() + + result = session_exists("test-session", temp_dir) + assert result is False + + +class TestGetSessionInfo: + """Test cases for get_session_info function""" + + def test_returns_none_when_no_base_path(self): + """Test that function returns None when no base path is provided""" + result = get_session_info("test-session", None) + assert result is None + + def test_returns_none_when_session_does_not_exist(self): + """Test that function returns None when session doesn't exist""" + with tempfile.TemporaryDirectory() as temp_dir: + result = get_session_info("non-existent", temp_dir) + assert result is None + + def test_returns_session_info_when_session_exists(self): + """Test that function returns session info when session exists""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create session directory and file + session_dir = Path(temp_dir) / "session_test-session" + session_dir.mkdir() + (session_dir / "session.json").touch() + + # Create some test messages + agents_dir = session_dir / "agents" / "agent1" / "messages" + agents_dir.mkdir(parents=True) + (agents_dir / "msg1.json").touch() + (agents_dir / "msg2.json").touch() + + result = get_session_info("test-session", temp_dir) + + assert result is not None + assert result["session_id"] == "test-session" + assert "created_at" in result + assert result["total_messages"] == 2 + assert result["path"] == str(session_dir) + + +class TestIntegration: + """Integration tests for session utilities""" + + def test_full_session_workflow(self): + """Test complete session workflow from creation to listing""" + with tempfile.TemporaryDirectory() as temp_dir: + base_path = temp_dir + + # Test create_session_manager creates directory and manager + with mock.patch("strands_agents_builder.utils.session_utils.FileSessionManager") as mock_fsm: + create_session_manager(session_id="test-session", base_path=base_path) + mock_fsm.assert_called_once() + + # Create some test session directories + (Path(temp_dir) / "session_test-session-1").mkdir() + (Path(temp_dir) / "session_test-session-2").mkdir() + + # Test list_available_sessions + sessions = list_available_sessions(base_path) + assert len(sessions) == 2 + assert "test-session-1" in sessions + assert "test-session-2" in sessions + + def test_session_lifecycle(self): + """Test complete session lifecycle""" + with tempfile.TemporaryDirectory() as temp_dir: + session_id = "test-lifecycle-session" + + # Initially session should not exist + assert not session_exists(session_id, temp_dir) + assert get_session_info(session_id, temp_dir) is None + + # Create session directory and file + session_dir = Path(temp_dir) / f"session_{session_id}" + session_dir.mkdir() + (session_dir / "session.json").touch() + + # Now session should exist + assert session_exists(session_id, temp_dir) + + # Should be able to get session info + info = get_session_info(session_id, temp_dir) + assert info is not None + assert info["session_id"] == session_id + + # Should appear in session list + sessions = list_available_sessions(temp_dir) + assert session_id in sessions + + +class TestSessionCommands: + """Test cases for session command functions""" + + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + def test_list_sessions_command_no_sessions(self, mock_console_print): + """Test list-sessions command when no sessions exist""" + # Use a real temporary directory with no sessions + with tempfile.TemporaryDirectory() as temp_dir: + list_sessions_command(temp_dir) + + # Verify appropriate message was called + mock_console_print.assert_any_call("[yellow]No sessions found.[/yellow]") + + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + def test_list_sessions_command_with_sessions(self, mock_console_print): + """Test list-sessions command when sessions exist""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create some test session directories + (Path(temp_dir) / "session_test-session-1").mkdir() + (Path(temp_dir) / "session_test-session-1" / "session.json").touch() + (Path(temp_dir) / "session_test-session-2").mkdir() + (Path(temp_dir) / "session_test-session-2" / "session.json").touch() + + list_sessions_command(temp_dir) + + # Verify sessions were listed + mock_console_print.assert_any_call("[bold cyan]Available sessions:[/bold cyan]") + + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + def test_list_sessions_command_no_base_path(self, mock_console_print): + """Test list-sessions command when no session path is configured""" + list_sessions_command(None) + + # Verify appropriate error message was called + mock_console_print.assert_called_with( + "[red]Error: Session management not enabled. Use --session-path or " + "set STRANDS_SESSION_PATH environment variable.[/red]" + ) + + @mock.patch("builtins.print") + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + def test_display_agent_history_function(self, mock_console_print, mock_print): + """Test the display_agent_history function directly""" + + # Test with few messages (no truncation) + mock_agent = mock.MagicMock() + mock_agent.messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there! How can I help you?"}]}, + ] + + display_agent_history(mock_agent, "test-session-123") + + # Verify console.print was called for the header panel + mock_console_print.assert_called() + # Verify regular print was called for messages (with colorama formatting) + mock_print.assert_any_call("\x1b[32m~ \x1b[0mHello") + mock_print.assert_any_call() # Empty line after user message + + # Test with many messages (should show truncation indicator) + mock_print.reset_mock() + mock_console_print.reset_mock() + mock_agent.messages = [{"role": "user", "content": [{"text": f"Message {i}"}]} for i in range(12)] + + display_agent_history(mock_agent, "test-session-456") + + # Should show console panel with truncation info in subtitle + mock_console_print.assert_called() + + # Test with no messages + mock_print.reset_mock() + mock_agent.messages = [] + display_agent_history(mock_agent, "empty-session") + + # Should not print anything for empty session + mock_print.assert_not_called() + + # Test with None messages + mock_print.reset_mock() + mock_agent.messages = None + display_agent_history(mock_agent, "none-session") + + # Should not print anything for None messages + mock_print.assert_not_called() + + @mock.patch("strands_agents_builder.utils.session_utils.create_session_manager") + @mock.patch("strands_agents_builder.utils.session_utils.session_exists") + def test_setup_session_management_function(self, mock_session_exists, mock_create_manager): + """Test the setup_session_management function directly""" + + # Test creating new session (no session_id provided) + mock_manager = mock.MagicMock() + mock_manager.session_id = "generated-session-123" + mock_create_manager.return_value = mock_manager + + result_manager, result_id, is_resuming = setup_session_management(None, "/tmp/test_sessions") + + assert result_manager == mock_manager + assert result_id == "generated-session-123" + assert is_resuming is False + mock_create_manager.assert_called_with(None, "/tmp/test_sessions") + + # Test resuming existing session + mock_session_exists.return_value = True + mock_manager.session_id = "existing-session" + + result_manager, result_id, is_resuming = setup_session_management("existing-session", "/tmp/test_sessions") + + assert result_manager == mock_manager + assert result_id == "existing-session" + assert is_resuming is True + + # Test creating session with provided ID (session doesn't exist) + mock_session_exists.return_value = False + + result_manager, result_id, is_resuming = setup_session_management("existing-session", "/tmp/test_sessions") + + assert result_manager == mock_manager + assert result_id == "existing-session" + assert is_resuming is False + + @mock.patch("strands_agents_builder.utils.session_utils.console.print") + @mock.patch("strands_agents_builder.utils.session_utils.get_session_info") + def test_handle_session_commands_function(self, mock_get_info, mock_console_print): + """Test the handle_session_commands function directly""" + + # Mock session info + mock_get_info.return_value = { + "session_id": "test-session", + "created_at": 1234567890, + "total_messages": 5, + "path": "/tmp/sessions/session_test-session", + } + + # Test session info command + result = handle_session_commands("session info", "test-session", "/tmp/sessions") + assert result is True + mock_get_info.assert_called_once_with("test-session", "/tmp/sessions") + + # Test non-session command + result = handle_session_commands("regular command", "test-session", "/tmp/sessions") + assert result is False + + +class TestValidationFunctions: + """Test cases for validation functions""" + + def test_validate_session_id_valid_ids(self): + """Test that valid session IDs pass validation""" + valid_ids = [ + "strands-1234567890-abcd1234", + "test-session", + "session123", + "my_session", + "session-with-dashes", + ] + + for session_id in valid_ids: + assert validate_session_id(session_id), f"Should be valid: {session_id}" + + def test_validate_session_id_invalid_ids(self): + """Test that invalid session IDs fail validation""" + invalid_ids = [ + "", # Empty + "session/with/slash", # Path separator + "session\\with\\backslash", # Windows path separator + "session..with..dots", # Path traversal + ".hidden-session", # Hidden file + "session\0with\0null", # Null character + "x" * 256, # Too long + ] + + for session_id in invalid_ids: + assert not validate_session_id(session_id), f"Should be invalid: {session_id}" + + def test_validate_session_path_valid_paths(self): + """Test that valid session paths pass validation""" + valid_paths = [ + "/tmp/sessions", + "./sessions", + "../sessions", + "sessions", + "/home/user/strands/sessions", + ] + + for path in valid_paths: + assert validate_session_path(path), f"Should be valid: {path}" + + def test_validate_session_path_invalid_paths(self): + """Test that invalid session paths fail validation""" + invalid_paths = [ + "", # Empty + "x" * 5000, # Too long + ] + + for path in invalid_paths: + assert not validate_session_path(path), f"Should be invalid: {path}"