diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 5710d28037c77..e524d384063fd 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -1,9 +1,16 @@ """Middleware plugins for agents.""" +from .anthropic_tools import ( + FilesystemClaudeMemoryMiddleware, + FilesystemClaudeTextEditorMiddleware, + StateClaudeMemoryMiddleware, + StateClaudeTextEditorMiddleware, +) from .context_editing import ( ClearToolUsesEdit, ContextEditingMiddleware, ) +from .file_search import FilesystemFileSearchMiddleware, StateFileSearchMiddleware from .human_in_the_loop import ( HumanInTheLoopMiddleware, InterruptOnConfig, @@ -36,6 +43,9 @@ "AgentState", "ClearToolUsesEdit", "ContextEditingMiddleware", + "FilesystemClaudeMemoryMiddleware", + "FilesystemClaudeTextEditorMiddleware", + "FilesystemFileSearchMiddleware", "HumanInTheLoopMiddleware", "InterruptOnConfig", "LLMToolEmulator", @@ -46,6 +56,9 @@ "ModelResponse", "PIIDetectionError", "PIIMiddleware", + "StateClaudeMemoryMiddleware", + "StateClaudeTextEditorMiddleware", + "StateFileSearchMiddleware", "SummarizationMiddleware", "TodoListMiddleware", "ToolCallLimitMiddleware", diff --git a/libs/langchain_v1/langchain/agents/middleware/anthropic_tools.py b/libs/langchain_v1/langchain/agents/middleware/anthropic_tools.py new file mode 100644 index 0000000000000..3ee9deca9f4d5 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/anthropic_tools.py @@ -0,0 +1,1032 @@ +"""Anthropic text editor and memory tool middleware. + +This module provides client-side implementations of Anthropic's text editor and +memory tools using schema-less tool definitions and tool call interception. 
+""" + +from __future__ import annotations + +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, cast + +from langchain_core.messages import ToolMessage +from langgraph.types import Command +from typing_extensions import NotRequired, TypedDict + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ModelRequest, + ModelResponse, +) + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from langchain.tools.tool_node import ToolCallRequest + +# Tool type constants +TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728" +TEXT_EDITOR_TOOL_NAME = "str_replace_based_edit_tool" +MEMORY_TOOL_TYPE = "memory_20250818" +MEMORY_TOOL_NAME = "memory" + +MEMORY_SYSTEM_PROMPT = """IMPORTANT: ALWAYS VIEW YOUR MEMORY DIRECTORY BEFORE \ +DOING ANYTHING ELSE. +MEMORY PROTOCOL: +1. Use the `view` command of your `memory` tool to check for earlier progress. +2. ... (work on the task) ... + - As you make progress, record status / progress / thoughts etc in your memory. +ASSUME INTERRUPTION: Your context window might be reset at any moment, so you risk \ +losing any progress that is not recorded in your memory directory.""" + + +class FileData(TypedDict): + """Data structure for storing file contents.""" + + content: list[str] + """Lines of the file.""" + + created_at: str + """ISO 8601 timestamp of file creation.""" + + modified_at: str + """ISO 8601 timestamp of last modification.""" + + +def files_reducer( + left: dict[str, FileData] | None, right: dict[str, FileData | None] +) -> dict[str, FileData]: + """Custom reducer that merges file updates. + + Args: + left: Existing files dict. + right: New files dict to merge (None values delete files). + + Returns: + Merged dict where right overwrites left for matching keys. 
+ """ + if left is None: + # Filter out None values when initializing + return {k: v for k, v in right.items() if v is not None} + + # Merge, filtering out None values (deletions) + result = {**left} + for k, v in right.items(): + if v is None: + result.pop(k, None) + else: + result[k] = v + return result + + +class AnthropicToolsState(AgentState): + """State schema for Anthropic text editor and memory tools.""" + + text_editor_files: NotRequired[Annotated[dict[str, FileData], files_reducer]] + """Virtual file system for text editor tools.""" + + memory_files: NotRequired[Annotated[dict[str, FileData], files_reducer]] + """Virtual file system for memory tools.""" + + +def _validate_path(path: str, *, allowed_prefixes: Sequence[str] | None = None) -> str: + """Validate and normalize file path for security. + + Args: + path: The path to validate. + allowed_prefixes: Optional list of allowed path prefixes. + + Returns: + Normalized canonical path. + + Raises: + ValueError: If path contains traversal sequences or violates prefix rules. + """ + # Reject paths with traversal attempts + if ".." in path or path.startswith("~"): + msg = f"Path traversal not allowed: {path}" + raise ValueError(msg) + + # Normalize path (resolve ., //, etc.) + normalized = os.path.normpath(path) + + # Convert to forward slashes for consistency + normalized = normalized.replace("\\", "/") + + # Ensure path starts with / + if not normalized.startswith("/"): + normalized = f"/{normalized}" + + # Check allowed prefixes if specified + if allowed_prefixes is not None and not any( + normalized.startswith(prefix) for prefix in allowed_prefixes + ): + msg = f"Path must start with one of {allowed_prefixes}: {path}" + raise ValueError(msg) + + return normalized + + +def _list_directory(files: dict[str, FileData], path: str) -> list[str]: + """List files in a directory. + + Args: + files: Files dict. + path: Normalized directory path. + + Returns: + Sorted list of file paths in the directory. 
+ """ + # Ensure path ends with / for directory matching + dir_path = path if path.endswith("/") else f"{path}/" + + matching_files = [] + for file_path in files: + if file_path.startswith(dir_path): + # Get relative path from directory + relative = file_path[len(dir_path) :] + # Only include direct children (no subdirectories) + if "/" not in relative: + matching_files.append(file_path) + + return sorted(matching_files) + + +class _StateClaudeFileToolMiddleware(AgentMiddleware): + """Base class for state-based file tool middleware (internal).""" + + state_schema = AnthropicToolsState + + def __init__( + self, + *, + tool_type: str, + tool_name: str, + state_key: str, + allowed_path_prefixes: Sequence[str] | None = None, + system_prompt: str | None = None, + ) -> None: + """Initialize the middleware. + + Args: + tool_type: Tool type identifier. + tool_name: Tool name. + state_key: State key for file storage. + allowed_path_prefixes: Optional list of allowed path prefixes. + system_prompt: Optional system prompt to inject. 
+ """ + self.tool_type = tool_type + self.tool_name = tool_name + self.state_key = state_key + self.allowed_prefixes = allowed_path_prefixes + self.system_prompt = system_prompt + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + """Inject tool and optional system prompt.""" + # Add tool + tools = list(request.tools or []) + tools.append( + { + "type": self.tool_type, + "name": self.tool_name, + } + ) + request.tools = tools + + # Inject system prompt if provided + if self.system_prompt: + request.system_prompt = ( + request.system_prompt + "\n\n" + self.system_prompt + if request.system_prompt + else self.system_prompt + ) + + return handler(request) + + def wrap_tool_call( + self, request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command] + ) -> ToolMessage | Command: + """Intercept tool calls.""" + tool_call = request.tool_call + tool_name = tool_call.get("name") + + if tool_name != self.tool_name: + return handler(request) + + # Handle tool call + try: + args = tool_call.get("args", {}) + command = args.get("command") + state = request.state + + if command == "view": + return self._handle_view(args, state, tool_call["id"]) + if command == "create": + return self._handle_create(args, state, tool_call["id"]) + if command == "str_replace": + return self._handle_str_replace(args, state, tool_call["id"]) + if command == "insert": + return self._handle_insert(args, state, tool_call["id"]) + if command == "delete": + return self._handle_delete(args, state, tool_call["id"]) + if command == "rename": + return self._handle_rename(args, state, tool_call["id"]) + + msg = f"Unknown command: {command}" + return ToolMessage( + content=msg, + tool_call_id=tool_call["id"], + name=tool_name, + status="error", + ) + except (ValueError, FileNotFoundError) as e: + return ToolMessage( + content=str(e), + tool_call_id=tool_call["id"], + name=tool_name, + status="error", + ) + 
+ def _handle_view( + self, args: dict, state: AnthropicToolsState, tool_call_id: str | None + ) -> Command: + """Handle view command.""" + path = args["path"] + normalized_path = _validate_path(path, allowed_prefixes=self.allowed_prefixes) + + files = cast("dict[str, Any]", state.get(self.state_key, {})) + file_data = files.get(normalized_path) + + if file_data is None: + # Try directory listing + matching = _list_directory(files, normalized_path) + + if matching: + content = "\n".join(matching) + return Command( + update={ + "messages": [ + ToolMessage( + content=content, + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ] + } + ) + + msg = f"File not found: {path}" + raise FileNotFoundError(msg) + + # Format file content with line numbers + lines_content = file_data["content"] + formatted_lines = [f"{i + 1}|{line}" for i, line in enumerate(lines_content)] + content = "\n".join(formatted_lines) + + return Command( + update={ + "messages": [ + ToolMessage( + content=content, + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ] + } + ) + + def _handle_create( + self, args: dict, state: AnthropicToolsState, tool_call_id: str | None + ) -> Command: + """Handle create command.""" + path = args["path"] + file_text = args["file_text"] + + normalized_path = _validate_path(path, allowed_prefixes=self.allowed_prefixes) + + # Get existing files + files = cast("dict[str, Any]", state.get(self.state_key, {})) + existing = files.get(normalized_path) + + # Create file data + now = datetime.now(timezone.utc).isoformat() + created_at = existing["created_at"] if existing else now + + content_lines = file_text.split("\n") + + return Command( + update={ + self.state_key: { + normalized_path: { + "content": content_lines, + "created_at": created_at, + "modified_at": now, + } + }, + "messages": [ + ToolMessage( + content=f"File created: {path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ], + } + ) + + def _handle_str_replace( + self, args: dict, state: 
AnthropicToolsState, tool_call_id: str | None + ) -> Command: + """Handle str_replace command.""" + path = args["path"] + old_str = args["old_str"] + new_str = args.get("new_str", "") + + normalized_path = _validate_path(path, allowed_prefixes=self.allowed_prefixes) + + # Read file + files = cast("dict[str, Any]", state.get(self.state_key, {})) + file_data = files.get(normalized_path) + if file_data is None: + msg = f"File not found: {path}" + raise FileNotFoundError(msg) + + lines_content = file_data["content"] + content = "\n".join(lines_content) + + # Replace string + if old_str not in content: + msg = f"String not found in file: {old_str}" + raise ValueError(msg) + + new_content = content.replace(old_str, new_str, 1) + new_lines = new_content.split("\n") + + # Update file + now = datetime.now(timezone.utc).isoformat() + + return Command( + update={ + self.state_key: { + normalized_path: { + "content": new_lines, + "created_at": file_data["created_at"], + "modified_at": now, + } + }, + "messages": [ + ToolMessage( + content=f"String replaced in {path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ], + } + ) + + def _handle_insert( + self, args: dict, state: AnthropicToolsState, tool_call_id: str | None + ) -> Command: + """Handle insert command.""" + path = args["path"] + insert_line = args["insert_line"] + text_to_insert = args["new_str"] + + normalized_path = _validate_path(path, allowed_prefixes=self.allowed_prefixes) + + # Read file + files = cast("dict[str, Any]", state.get(self.state_key, {})) + file_data = files.get(normalized_path) + if file_data is None: + msg = f"File not found: {path}" + raise FileNotFoundError(msg) + + lines_content = file_data["content"] + new_lines = text_to_insert.split("\n") + + # Insert after insert_line (0-indexed) + updated_lines = lines_content[:insert_line] + new_lines + lines_content[insert_line:] + + # Update file + now = datetime.now(timezone.utc).isoformat() + + return Command( + update={ + self.state_key: 
{ + normalized_path: { + "content": updated_lines, + "created_at": file_data["created_at"], + "modified_at": now, + } + }, + "messages": [ + ToolMessage( + content=f"Text inserted in {path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ], + } + ) + + def _handle_delete( + self, + args: dict, + state: AnthropicToolsState, # noqa: ARG002 + tool_call_id: str | None, + ) -> Command: + """Handle delete command.""" + path = args["path"] + + normalized_path = _validate_path(path, allowed_prefixes=self.allowed_prefixes) + + return Command( + update={ + self.state_key: {normalized_path: None}, + "messages": [ + ToolMessage( + content=f"File deleted: {path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ], + } + ) + + def _handle_rename( + self, args: dict, state: AnthropicToolsState, tool_call_id: str | None + ) -> Command: + """Handle rename command.""" + old_path = args["old_path"] + new_path = args["new_path"] + + normalized_old = _validate_path(old_path, allowed_prefixes=self.allowed_prefixes) + normalized_new = _validate_path(new_path, allowed_prefixes=self.allowed_prefixes) + + # Read file + files = cast("dict[str, Any]", state.get(self.state_key, {})) + file_data = files.get(normalized_old) + if file_data is None: + msg = f"File not found: {old_path}" + raise ValueError(msg) + + # Update timestamp + now = datetime.now(timezone.utc).isoformat() + file_data_copy = file_data.copy() + file_data_copy["modified_at"] = now + + return Command( + update={ + self.state_key: { + normalized_old: None, + normalized_new: file_data_copy, + }, + "messages": [ + ToolMessage( + content=f"File renamed: {old_path} -> {new_path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ], + } + ) + + +class StateClaudeTextEditorMiddleware(_StateClaudeFileToolMiddleware): + """State-based text editor tool middleware. + + Provides Anthropic's text_editor tool using LangGraph state for storage. + Files persist for the conversation thread. 
+ + Example: + ```python + from langchain.agents import create_agent + from langchain.agents.middleware import StateTextEditorToolMiddleware + + agent = create_agent( + model=model, + tools=[], + middleware=[StateTextEditorToolMiddleware()], + ) + ``` + """ + + def __init__( + self, + *, + allowed_path_prefixes: Sequence[str] | None = None, + ) -> None: + """Initialize the text editor middleware. + + Args: + allowed_path_prefixes: Optional list of allowed path prefixes. + If specified, only paths starting with these prefixes are allowed. + """ + super().__init__( + tool_type=TEXT_EDITOR_TOOL_TYPE, + tool_name=TEXT_EDITOR_TOOL_NAME, + state_key="text_editor_files", + allowed_path_prefixes=allowed_path_prefixes, + ) + + +class StateClaudeMemoryMiddleware(_StateClaudeFileToolMiddleware): + """State-based memory tool middleware. + + Provides Anthropic's memory tool using LangGraph state for storage. + Files persist for the conversation thread. Enforces /memories prefix + and injects Anthropic's recommended system prompt. + + Example: + ```python + from langchain.agents import create_agent + from langchain.agents.middleware import StateMemoryToolMiddleware + + agent = create_agent( + model=model, + tools=[], + middleware=[StateMemoryToolMiddleware()], + ) + ``` + """ + + def __init__( + self, + *, + allowed_path_prefixes: Sequence[str] | None = None, + system_prompt: str = MEMORY_SYSTEM_PROMPT, + ) -> None: + """Initialize the memory middleware. + + Args: + allowed_path_prefixes: Optional list of allowed path prefixes. + Defaults to ["/memories"]. + system_prompt: System prompt to inject. Defaults to Anthropic's + recommended memory prompt. 
+ """ + super().__init__( + tool_type=MEMORY_TOOL_TYPE, + tool_name=MEMORY_TOOL_NAME, + state_key="memory_files", + allowed_path_prefixes=allowed_path_prefixes or ["/memories"], + system_prompt=system_prompt, + ) + + +class _FilesystemClaudeFileToolMiddleware(AgentMiddleware): + """Base class for filesystem-based file tool middleware (internal).""" + + def __init__( + self, + *, + tool_type: str, + tool_name: str, + root_path: str, + allowed_prefixes: list[str] | None = None, + max_file_size_mb: int = 10, + system_prompt: str | None = None, + ) -> None: + """Initialize the middleware. + + Args: + tool_type: Tool type identifier. + tool_name: Tool name. + root_path: Root directory for file operations. + allowed_prefixes: Optional list of allowed virtual path prefixes. + max_file_size_mb: Maximum file size in MB. + system_prompt: Optional system prompt to inject. + """ + self.tool_type = tool_type + self.tool_name = tool_name + self.root_path = Path(root_path).resolve() + self.allowed_prefixes = allowed_prefixes or ["/"] + self.max_file_size_bytes = max_file_size_mb * 1024 * 1024 + self.system_prompt = system_prompt + + # Create root directory if it doesn't exist + self.root_path.mkdir(parents=True, exist_ok=True) + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + """Inject tool and optional system prompt.""" + # Add tool + tools = list(request.tools or []) + tools.append( + { + "type": self.tool_type, + "name": self.tool_name, + } + ) + request.tools = tools + + # Inject system prompt if provided + if self.system_prompt: + request.system_prompt = ( + request.system_prompt + "\n\n" + self.system_prompt + if request.system_prompt + else self.system_prompt + ) + + return handler(request) + + def wrap_tool_call( + self, request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command] + ) -> ToolMessage | Command: + """Intercept tool calls.""" + tool_call = 
request.tool_call + tool_name = tool_call.get("name") + + if tool_name != self.tool_name: + return handler(request) + + # Handle tool call + try: + args = tool_call.get("args", {}) + command = args.get("command") + + if command == "view": + return self._handle_view(args, tool_call["id"]) + if command == "create": + return self._handle_create(args, tool_call["id"]) + if command == "str_replace": + return self._handle_str_replace(args, tool_call["id"]) + if command == "insert": + return self._handle_insert(args, tool_call["id"]) + if command == "delete": + return self._handle_delete(args, tool_call["id"]) + if command == "rename": + return self._handle_rename(args, tool_call["id"]) + + msg = f"Unknown command: {command}" + return ToolMessage( + content=msg, + tool_call_id=tool_call["id"], + name=tool_name, + status="error", + ) + except (ValueError, FileNotFoundError) as e: + return ToolMessage( + content=str(e), + tool_call_id=tool_call["id"], + name=tool_name, + status="error", + ) + + def _validate_and_resolve_path(self, path: str) -> Path: + """Validate and resolve a virtual path to filesystem path. + + Args: + path: Virtual path (e.g., /file.txt or /src/main.py). + + Returns: + Resolved absolute filesystem path within root_path. + + Raises: + ValueError: If path contains traversal attempts, escapes root directory, + or violates allowed_prefixes restrictions. + """ + # Normalize path + if not path.startswith("/"): + path = "/" + path + + # Check for path traversal + if ".." 
in path or "~" in path: + msg = "Path traversal not allowed" + raise ValueError(msg) + + # Convert virtual path to filesystem path + # Remove leading / and resolve relative to root + relative = path.lstrip("/") + full_path = (self.root_path / relative).resolve() + + # Ensure path is within root + try: + full_path.relative_to(self.root_path) + except ValueError: + msg = f"Path outside root directory: {path}" + raise ValueError(msg) from None + + # Check allowed prefixes + virtual_path = "/" + str(full_path.relative_to(self.root_path)) + if self.allowed_prefixes: + allowed = any( + virtual_path.startswith(prefix) or virtual_path == prefix.rstrip("/") + for prefix in self.allowed_prefixes + ) + if not allowed: + msg = f"Path must start with one of: {self.allowed_prefixes}" + raise ValueError(msg) + + return full_path + + def _handle_view(self, args: dict, tool_call_id: str | None) -> Command: + """Handle view command.""" + path = args["path"] + full_path = self._validate_and_resolve_path(path) + + if not full_path.exists() or not full_path.is_file(): + msg = f"File not found: {path}" + raise FileNotFoundError(msg) + + # Check file size + if full_path.stat().st_size > self.max_file_size_bytes: + msg = f"File too large: {path} exceeds {self.max_file_size_bytes / 1024 / 1024}MB" + raise ValueError(msg) + + # Read file + try: + content = full_path.read_text() + except UnicodeDecodeError as e: + msg = f"Cannot decode file {path}: {e}" + raise ValueError(msg) from e + + # Format with line numbers + lines = content.split("\n") + # Remove trailing newline's empty string if present + if lines and lines[-1] == "": + lines = lines[:-1] + formatted_lines = [f"{i + 1}|{line}" for i, line in enumerate(lines)] + formatted_content = "\n".join(formatted_lines) + + return Command( + update={ + "messages": [ + ToolMessage( + content=formatted_content, + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ] + } + ) + + def _handle_create(self, args: dict, tool_call_id: str | None) -> 
Command: + """Handle create command.""" + path = args["path"] + file_text = args["file_text"] + + full_path = self._validate_and_resolve_path(path) + + # Create parent directories + full_path.parent.mkdir(parents=True, exist_ok=True) + + # Write file + full_path.write_text(file_text + "\n") + + return Command( + update={ + "messages": [ + ToolMessage( + content=f"File created: {path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ] + } + ) + + def _handle_str_replace(self, args: dict, tool_call_id: str | None) -> Command: + """Handle str_replace command.""" + path = args["path"] + old_str = args["old_str"] + new_str = args.get("new_str", "") + + full_path = self._validate_and_resolve_path(path) + + if not full_path.exists(): + msg = f"File not found: {path}" + raise FileNotFoundError(msg) + + # Read file + content = full_path.read_text() + + # Replace string + if old_str not in content: + msg = f"String not found in file: {old_str}" + raise ValueError(msg) + + new_content = content.replace(old_str, new_str, 1) + + # Write back + full_path.write_text(new_content) + + return Command( + update={ + "messages": [ + ToolMessage( + content=f"String replaced in {path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ] + } + ) + + def _handle_insert(self, args: dict, tool_call_id: str | None) -> Command: + """Handle insert command.""" + path = args["path"] + insert_line = args["insert_line"] + text_to_insert = args["new_str"] + + full_path = self._validate_and_resolve_path(path) + + if not full_path.exists(): + msg = f"File not found: {path}" + raise FileNotFoundError(msg) + + # Read file + content = full_path.read_text() + lines = content.split("\n") + # Handle trailing newline + if lines and lines[-1] == "": + lines = lines[:-1] + had_trailing_newline = True + else: + had_trailing_newline = False + + new_lines = text_to_insert.split("\n") + + # Insert after insert_line (0-indexed) + updated_lines = lines[:insert_line] + new_lines + 
lines[insert_line:] + + # Write back + new_content = "\n".join(updated_lines) + if had_trailing_newline: + new_content += "\n" + full_path.write_text(new_content) + + return Command( + update={ + "messages": [ + ToolMessage( + content=f"Text inserted in {path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ] + } + ) + + def _handle_delete(self, args: dict, tool_call_id: str | None) -> Command: + """Handle delete command.""" + import shutil + + path = args["path"] + full_path = self._validate_and_resolve_path(path) + + if full_path.is_file(): + full_path.unlink() + elif full_path.is_dir(): + shutil.rmtree(full_path) + # If doesn't exist, silently succeed + + return Command( + update={ + "messages": [ + ToolMessage( + content=f"File deleted: {path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ] + } + ) + + def _handle_rename(self, args: dict, tool_call_id: str | None) -> Command: + """Handle rename command.""" + old_path = args["old_path"] + new_path = args["new_path"] + + old_full = self._validate_and_resolve_path(old_path) + new_full = self._validate_and_resolve_path(new_path) + + if not old_full.exists(): + msg = f"File not found: {old_path}" + raise ValueError(msg) + + # Create parent directory for new path + new_full.parent.mkdir(parents=True, exist_ok=True) + + # Rename + old_full.rename(new_full) + + return Command( + update={ + "messages": [ + ToolMessage( + content=f"File renamed: {old_path} -> {new_path}", + tool_call_id=tool_call_id, + name=self.tool_name, + ) + ] + } + ) + + +class FilesystemClaudeTextEditorMiddleware(_FilesystemClaudeFileToolMiddleware): + """Filesystem-based text editor tool middleware. + + Provides Anthropic's text_editor tool using local filesystem for storage. + User handles persistence via volumes, git, or other mechanisms. 
+ + Example: + ```python + from langchain.agents import create_agent + from langchain.agents.middleware import FilesystemTextEditorToolMiddleware + + agent = create_agent( + model=model, + tools=[], + middleware=[FilesystemTextEditorToolMiddleware(root_path="/workspace")], + ) + ``` + """ + + def __init__( + self, + *, + root_path: str, + allowed_prefixes: list[str] | None = None, + max_file_size_mb: int = 10, + ) -> None: + """Initialize the text editor middleware. + + Args: + root_path: Root directory for file operations. + allowed_prefixes: Optional list of allowed virtual path prefixes (default: ["/"]). + max_file_size_mb: Maximum file size in MB (default: 10). + """ + super().__init__( + tool_type=TEXT_EDITOR_TOOL_TYPE, + tool_name=TEXT_EDITOR_TOOL_NAME, + root_path=root_path, + allowed_prefixes=allowed_prefixes, + max_file_size_mb=max_file_size_mb, + ) + + +class FilesystemClaudeMemoryMiddleware(_FilesystemClaudeFileToolMiddleware): + """Filesystem-based memory tool middleware. + + Provides Anthropic's memory tool using local filesystem for storage. + User handles persistence via volumes, git, or other mechanisms. + Enforces /memories prefix and injects Anthropic's recommended system prompt. + + Example: + ```python + from langchain.agents import create_agent + from langchain.agents.middleware import FilesystemMemoryToolMiddleware + + agent = create_agent( + model=model, + tools=[], + middleware=[FilesystemMemoryToolMiddleware(root_path="/workspace")], + ) + ``` + """ + + def __init__( + self, + *, + root_path: str, + allowed_prefixes: list[str] | None = None, + max_file_size_mb: int = 10, + system_prompt: str = MEMORY_SYSTEM_PROMPT, + ) -> None: + """Initialize the memory middleware. + + Args: + root_path: Root directory for file operations. + allowed_prefixes: Optional list of allowed virtual path prefixes. + Defaults to ["/memories"]. + max_file_size_mb: Maximum file size in MB (default: 10). + system_prompt: System prompt to inject. 
Defaults to Anthropic's + recommended memory prompt. + """ + super().__init__( + tool_type=MEMORY_TOOL_TYPE, + tool_name=MEMORY_TOOL_NAME, + root_path=root_path, + allowed_prefixes=allowed_prefixes or ["/memories"], + max_file_size_mb=max_file_size_mb, + system_prompt=system_prompt, + ) + + +__all__ = [ + "AnthropicToolsState", + "FileData", + "FilesystemClaudeMemoryMiddleware", + "FilesystemClaudeTextEditorMiddleware", + "StateClaudeMemoryMiddleware", + "StateClaudeTextEditorMiddleware", +] diff --git a/libs/langchain_v1/langchain/agents/middleware/file_search.py b/libs/langchain_v1/langchain/agents/middleware/file_search.py new file mode 100644 index 0000000000000..9698fc3bc95f2 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/file_search.py @@ -0,0 +1,588 @@ +"""File search middleware for Anthropic text editor and memory tools. + +This module provides Glob and Grep search tools that operate on files stored +in state or filesystem. +""" + +from __future__ import annotations + +import fnmatch +import json +import re +import subprocess +from contextlib import suppress +from datetime import datetime, timezone +from pathlib import Path, PurePosixPath +from typing import Annotated, Any, Literal, cast + +from langchain_core.tools import InjectedToolArg, tool + +from langchain.agents.middleware.anthropic_tools import AnthropicToolsState +from langchain.agents.middleware.types import AgentMiddleware + + +def _expand_include_patterns(pattern: str) -> list[str] | None: + """Expand brace patterns like ``*.{py,pyi}`` into a list of globs.""" + if "}" in pattern and "{" not in pattern: + return None + + expanded: list[str] = [] + + def _expand(current: str) -> None: + start = current.find("{") + if start == -1: + expanded.append(current) + return + + end = current.find("}", start) + if end == -1: + raise ValueError + + prefix = current[:start] + suffix = current[end + 1 :] + inner = current[start + 1 : end] + if not inner: + raise ValueError + + for option in 
inner.split(","): + _expand(prefix + option + suffix) + + try: + _expand(pattern) + except ValueError: + return None + + return expanded + + +def _is_valid_include_pattern(pattern: str) -> bool: + """Validate glob pattern used for include filters.""" + if not pattern: + return False + + if any(char in pattern for char in ("\x00", "\n", "\r")): + return False + + expanded = _expand_include_patterns(pattern) + if expanded is None: + return False + + try: + for candidate in expanded: + re.compile(fnmatch.translate(candidate)) + except re.error: + return False + + return True + + +def _match_include_pattern(basename: str, pattern: str) -> bool: + """Return True if the basename matches the include pattern.""" + expanded = _expand_include_patterns(pattern) + if not expanded: + return False + + return any(fnmatch.fnmatch(basename, candidate) for candidate in expanded) + + +class StateFileSearchMiddleware(AgentMiddleware): + """Provides Glob and Grep search over state-based files. + + This middleware adds two tools that search through virtual files in state: + - Glob: Fast file pattern matching by file path + - Grep: Fast content search using regular expressions + + Example: + ```python + from langchain.agents import create_agent + from langchain.agents.middleware import ( + StateTextEditorToolMiddleware, + StateFileSearchMiddleware, + ) + + agent = create_agent( + model=model, + tools=[], + middleware=[ + StateTextEditorToolMiddleware(), + StateFileSearchMiddleware(), + ], + ) + ``` + """ + + state_schema = AnthropicToolsState + + def __init__( + self, + *, + state_key: str = "text_editor_files", + ) -> None: + """Initialize the search middleware. + + Args: + state_key: State key to search (default: "text_editor_files"). + Use "memory_files" to search memory tool files. 
+ """ + self.state_key = state_key + + # Create tool instances + @tool + def glob_search( # noqa: D417 + pattern: str, + path: str = "/", + state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment] + ) -> str: + """Fast file pattern matching tool that works with any codebase size. + + Supports glob patterns like **/*.js or src/**/*.ts. + Returns matching file paths sorted by modification time. + Use this tool when you need to find files by name patterns. + + Args: + pattern: The glob pattern to match files against. + path: The directory to search in. If not specified, searches from root. + + Returns: + Newline-separated list of matching file paths, sorted by modification + time (most recently modified first). Returns "No files found" if no + matches. + """ + # Normalize base path + base_path = path if path.startswith("/") else "/" + path + + # Get files from state + files = cast("dict[str, Any]", state.get(self.state_key, {})) + + # Match files + matches = [] + for file_path, file_data in files.items(): + if file_path.startswith(base_path): + # Get relative path from base + if base_path == "/": + relative = file_path[1:] # Remove leading / + elif file_path == base_path: + relative = Path(file_path).name + elif file_path.startswith(base_path + "/"): + relative = file_path[len(base_path) + 1 :] + else: + continue + + # Match against pattern + # Handle ** pattern which requires special care + # PurePosixPath.match doesn't match single-level paths against **/pattern + is_match = PurePosixPath(relative).match(pattern) + if not is_match and pattern.startswith("**/"): + # Also try matching without the **/ prefix for files in base dir + is_match = PurePosixPath(relative).match(pattern[3:]) + + if is_match: + matches.append((file_path, file_data["modified_at"])) + + if not matches: + return "No files found" + + # Sort by modification time + matches.sort(key=lambda x: x[1], reverse=True) + file_paths = [path for path, _ in matches] + + return 
"\n".join(file_paths) + + @tool + def grep_search( # noqa: D417 + pattern: str, + path: str = "/", + include: str | None = None, + output_mode: Literal["files_with_matches", "content", "count"] = "files_with_matches", + state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment] + ) -> str: + """Fast content search tool that works with any codebase size. + + Searches file contents using regular expressions. Supports full regex + syntax and filters files by pattern with the include parameter. + + Args: + pattern: The regular expression pattern to search for in file contents. + path: The directory to search in. If not specified, searches from root. + include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}"). + output_mode: Output format: + - "files_with_matches": Only file paths containing matches (default) + - "content": Matching lines with file:line:content format + - "count": Count of matches per file + + Returns: + Search results formatted according to output_mode. Returns "No matches + found" if no results. 
+ """ + # Normalize base path + base_path = path if path.startswith("/") else "/" + path + + # Compile regex pattern (for validation) + try: + regex = re.compile(pattern) + except re.error as e: + return f"Invalid regex pattern: {e}" + + if include and not _is_valid_include_pattern(include): + return "Invalid include pattern" + + # Search files + files = cast("dict[str, Any]", state.get(self.state_key, {})) + results: dict[str, list[tuple[int, str]]] = {} + + for file_path, file_data in files.items(): + if not file_path.startswith(base_path): + continue + + # Check include filter + if include: + basename = Path(file_path).name + if not _match_include_pattern(basename, include): + continue + + # Search file content + for line_num, line in enumerate(file_data["content"], 1): + if regex.search(line): + if file_path not in results: + results[file_path] = [] + results[file_path].append((line_num, line)) + + if not results: + return "No matches found" + + # Format output based on mode + return self._format_grep_results(results, output_mode) + + self.glob_search = glob_search + self.grep_search = grep_search + self.tools = [glob_search, grep_search] + + def _format_grep_results( + self, + results: dict[str, list[tuple[int, str]]], + output_mode: str, + ) -> str: + """Format grep results based on output mode.""" + if output_mode == "files_with_matches": + # Just return file paths + return "\n".join(sorted(results.keys())) + + if output_mode == "content": + # Return file:line:content format + lines = [] + for file_path in sorted(results.keys()): + for line_num, line in results[file_path]: + lines.append(f"{file_path}:{line_num}:{line}") + return "\n".join(lines) + + if output_mode == "count": + # Return file:count format + lines = [] + for file_path in sorted(results.keys()): + count = len(results[file_path]) + lines.append(f"{file_path}:{count}") + return "\n".join(lines) + + # Default to files_with_matches + return "\n".join(sorted(results.keys())) + + +class 
FilesystemFileSearchMiddleware(AgentMiddleware): + """Provides Glob and Grep search over filesystem files. + + This middleware adds two tools that search through local filesystem: + - Glob: Fast file pattern matching by file path + - Grep: Fast content search using ripgrep or Python fallback + + Example: + ```python + from langchain.agents import create_agent + from langchain.agents.middleware import ( + FilesystemTextEditorToolMiddleware, + FilesystemFileSearchMiddleware, + ) + + agent = create_agent( + model=model, + tools=[], + middleware=[ + FilesystemTextEditorToolMiddleware(root_path="/workspace"), + FilesystemFileSearchMiddleware(root_path="/workspace"), + ], + ) + ``` + """ + + def __init__( + self, + *, + root_path: str, + use_ripgrep: bool = True, + max_file_size_mb: int = 10, + ) -> None: + """Initialize the search middleware. + + Args: + root_path: Root directory to search. + use_ripgrep: Whether to use ripgrep for search (default: True). + Falls back to Python if ripgrep unavailable. + max_file_size_mb: Maximum file size to search in MB (default: 10). + """ + self.root_path = Path(root_path).resolve() + self.use_ripgrep = use_ripgrep + self.max_file_size_bytes = max_file_size_mb * 1024 * 1024 + + # Create tool instances as closures that capture self + @tool + def glob_search(pattern: str, path: str = "/") -> str: + """Fast file pattern matching tool that works with any codebase size. + + Supports glob patterns like **/*.js or src/**/*.ts. + Returns matching file paths sorted by modification time. + Use this tool when you need to find files by name patterns. + + Args: + pattern: The glob pattern to match files against. + path: The directory to search in. If not specified, searches from root. + + Returns: + Newline-separated list of matching file paths, sorted by modification + time (most recently modified first). Returns "No files found" if no + matches. 
+ """ + try: + base_full = self._validate_and_resolve_path(path) + except ValueError: + return "No files found" + + if not base_full.exists() or not base_full.is_dir(): + return "No files found" + + # Use pathlib glob + matching: list[tuple[str, str]] = [] + for match in base_full.glob(pattern): + if match.is_file(): + # Convert to virtual path + virtual_path = "/" + str(match.relative_to(self.root_path)) + stat = match.stat() + modified_at = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat() + matching.append((virtual_path, modified_at)) + + if not matching: + return "No files found" + + file_paths = [p for p, _ in matching] + return "\n".join(file_paths) + + @tool + def grep_search( + pattern: str, + path: str = "/", + include: str | None = None, + output_mode: Literal["files_with_matches", "content", "count"] = "files_with_matches", + ) -> str: + """Fast content search tool that works with any codebase size. + + Searches file contents using regular expressions. Supports full regex + syntax and filters files by pattern with the include parameter. + + Args: + pattern: The regular expression pattern to search for in file contents. + path: The directory to search in. If not specified, searches from root. + include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}"). + output_mode: Output format: + - "files_with_matches": Only file paths containing matches (default) + - "content": Matching lines with file:line:content format + - "count": Count of matches per file + + Returns: + Search results formatted according to output_mode. Returns "No matches + found" if no results. 
+ """ + # Compile regex pattern (for validation) + try: + re.compile(pattern) + except re.error as e: + return f"Invalid regex pattern: {e}" + + if include and not _is_valid_include_pattern(include): + return "Invalid include pattern" + + # Try ripgrep first if enabled + results = None + if self.use_ripgrep: + with suppress( + FileNotFoundError, + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + ): + results = self._ripgrep_search(pattern, path, include) + + # Python fallback if ripgrep failed or is disabled + if results is None: + results = self._python_search(pattern, path, include) + + if not results: + return "No matches found" + + # Format output based on mode + return self._format_grep_results(results, output_mode) + + self.glob_search = glob_search + self.grep_search = grep_search + self.tools = [glob_search, grep_search] + + def _validate_and_resolve_path(self, path: str) -> Path: + """Validate and resolve a virtual path to filesystem path.""" + # Normalize path + if not path.startswith("/"): + path = "/" + path + + # Check for path traversal + if ".." 
in path or "~" in path: + msg = "Path traversal not allowed" + raise ValueError(msg) + + # Convert virtual path to filesystem path + relative = path.lstrip("/") + full_path = (self.root_path / relative).resolve() + + # Ensure path is within root + try: + full_path.relative_to(self.root_path) + except ValueError: + msg = f"Path outside root directory: {path}" + raise ValueError(msg) from None + + return full_path + + def _ripgrep_search( + self, pattern: str, base_path: str, include: str | None + ) -> dict[str, list[tuple[int, str]]]: + """Search using ripgrep subprocess.""" + try: + base_full = self._validate_and_resolve_path(base_path) + except ValueError: + return {} + + if not base_full.exists(): + return {} + + # Build ripgrep command + cmd = ["rg", "--json"] + + if include: + # Convert glob pattern to ripgrep glob + cmd.extend(["--glob", include]) + + cmd.extend(["--", pattern, str(base_full)]) + + try: + result = subprocess.run( # noqa: S603 + cmd, + capture_output=True, + text=True, + timeout=30, + check=False, + ) + except (subprocess.TimeoutExpired, FileNotFoundError): + # Fallback to Python search if ripgrep unavailable or times out + return self._python_search(pattern, base_path, include) + + # Parse ripgrep JSON output + results: dict[str, list[tuple[int, str]]] = {} + for line in result.stdout.splitlines(): + try: + data = json.loads(line) + if data["type"] == "match": + path = data["data"]["path"]["text"] + # Convert to virtual path + virtual_path = "/" + str(Path(path).relative_to(self.root_path)) + line_num = data["data"]["line_number"] + line_text = data["data"]["lines"]["text"].rstrip("\n") + + if virtual_path not in results: + results[virtual_path] = [] + results[virtual_path].append((line_num, line_text)) + except (json.JSONDecodeError, KeyError): + continue + + return results + + def _python_search( + self, pattern: str, base_path: str, include: str | None + ) -> dict[str, list[tuple[int, str]]]: + """Search using Python regex (fallback).""" + 
try: + base_full = self._validate_and_resolve_path(base_path) + except ValueError: + return {} + + if not base_full.exists(): + return {} + + regex = re.compile(pattern) + results: dict[str, list[tuple[int, str]]] = {} + + # Walk directory tree + for file_path in base_full.rglob("*"): + if not file_path.is_file(): + continue + + # Check include filter + if include and not _match_include_pattern(file_path.name, include): + continue + + # Skip files that are too large + if file_path.stat().st_size > self.max_file_size_bytes: + continue + + try: + content = file_path.read_text() + except (UnicodeDecodeError, PermissionError): + continue + + # Search content + for line_num, line in enumerate(content.splitlines(), 1): + if regex.search(line): + virtual_path = "/" + str(file_path.relative_to(self.root_path)) + if virtual_path not in results: + results[virtual_path] = [] + results[virtual_path].append((line_num, line)) + + return results + + def _format_grep_results( + self, + results: dict[str, list[tuple[int, str]]], + output_mode: str, + ) -> str: + """Format grep results based on output mode.""" + if output_mode == "files_with_matches": + # Just return file paths + return "\n".join(sorted(results.keys())) + + if output_mode == "content": + # Return file:line:content format + lines = [] + for file_path in sorted(results.keys()): + for line_num, line in results[file_path]: + lines.append(f"{file_path}:{line_num}:{line}") + return "\n".join(lines) + + if output_mode == "count": + # Return file:count format + lines = [] + for file_path in sorted(results.keys()): + count = len(results[file_path]) + lines.append(f"{file_path}:{count}") + return "\n".join(lines) + + # Default to files_with_matches + return "\n".join(sorted(results.keys())) + + +__all__ = [ + "FilesystemFileSearchMiddleware", + "StateFileSearchMiddleware", +] diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_anthropic_tools.py 
b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_anthropic_tools.py
new file mode 100644
index 0000000000000..4eaa055cbce5f
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_anthropic_tools.py
@@ -0,0 +1,276 @@
+"""Unit tests for Anthropic text editor and memory tool middleware."""
+
+import pytest
+from langchain.agents.middleware.anthropic_tools import (
+    AnthropicToolsState,
+    StateClaudeMemoryMiddleware,
+    StateClaudeTextEditorMiddleware,
+    _validate_path,
+)
+from langchain_core.messages import ToolMessage
+from langgraph.types import Command
+
+
+class TestPathValidation:
+    """Test path validation and security."""
+
+    def test_basic_path_normalization(self) -> None:
+        """Test basic path normalization."""
+        assert _validate_path("/foo/bar") == "/foo/bar"
+        assert _validate_path("foo/bar") == "/foo/bar"
+        assert _validate_path("/foo//bar") == "/foo/bar"
+        assert _validate_path("/foo/./bar") == "/foo/bar"
+
+    def test_path_traversal_blocked(self) -> None:
+        """Test that path traversal attempts are blocked."""
+        with pytest.raises(ValueError, match="Path traversal not allowed"):
+            _validate_path("/foo/../etc/passwd")
+
+        with pytest.raises(ValueError, match="Path traversal not allowed"):
+            _validate_path("../etc/passwd")
+
+        with pytest.raises(ValueError, match="Path traversal not allowed"):
+            _validate_path("~/.ssh/id_rsa")
+
+    def test_allowed_prefixes(self) -> None:
+        """Test path prefix validation."""
+        # Should pass
+        assert (
+            _validate_path("/workspace/file.txt", allowed_prefixes=["/workspace"])
+            == "/workspace/file.txt"
+        )
+
+        # Should fail
+        with pytest.raises(ValueError, match="Path must start with"):
+            _validate_path("/etc/passwd", allowed_prefixes=["/workspace"])
+
+        with pytest.raises(ValueError, match="Path must start with"):
+            _validate_path("/workspacemalicious/file.txt", allowed_prefixes=["/workspace/"])
+
+    def test_memories_prefix(self) -> None:
+        """Test /memories prefix validation for memory tools."""
+        assert (
+            _validate_path("/memories/notes.txt", allowed_prefixes=["/memories"])
+            == "/memories/notes.txt"
+        )
+
+        with pytest.raises(ValueError, match="Path must start with"):
+            _validate_path("/other/notes.txt", allowed_prefixes=["/memories"])
+
+
+class TestTextEditorMiddleware:
+    """Test text editor middleware functionality."""
+
+    def test_middleware_initialization(self) -> None:
+        """Test middleware initializes correctly."""
+        middleware = StateClaudeTextEditorMiddleware()
+        assert middleware.state_schema == AnthropicToolsState
+        assert middleware.tool_type == "text_editor_20250728"
+        assert middleware.tool_name == "str_replace_based_edit_tool"
+        assert middleware.state_key == "text_editor_files"
+
+        # With path restrictions
+        middleware = StateClaudeTextEditorMiddleware(allowed_path_prefixes=["/workspace"])
+        assert middleware.allowed_prefixes == ["/workspace"]
+
+
+class TestMemoryMiddleware:
+    """Test memory middleware functionality."""
+
+    def test_middleware_initialization(self) -> None:
+        """Test middleware initializes correctly."""
+        middleware = StateClaudeMemoryMiddleware()
+        assert middleware.state_schema == AnthropicToolsState
+        assert middleware.tool_type == "memory_20250818"
+        assert middleware.tool_name == "memory"
+        assert middleware.state_key == "memory_files"
+        assert middleware.system_prompt  # Should have default prompt
+
+    def test_custom_system_prompt(self) -> None:
+        """Test custom system prompt can be set."""
+        custom_prompt = "Custom memory instructions"
+        middleware = StateClaudeMemoryMiddleware(system_prompt=custom_prompt)
+        assert middleware.system_prompt == custom_prompt
+
+
+class TestFileOperations:
+    """Test file operation implementations via wrap_tool_call."""
+
+    def test_view_operation(self) -> None:
+        """Test view command execution."""
+        middleware = StateClaudeTextEditorMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/test.txt": {
+                    "content": ["line1", "line2", "line3"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                }
+            },
+        }
+
+        args = {"command": "view", "path": "/test.txt"}
+        result = middleware._handle_view(args, state, "test_id")
+
+        assert isinstance(result, Command)
+        assert result.update is not None
+        messages = result.update.get("messages", [])
+        assert len(messages) == 1
+        assert isinstance(messages[0], ToolMessage)
+        assert messages[0].content == "1|line1\n2|line2\n3|line3"
+        assert messages[0].tool_call_id == "test_id"
+
+    def test_create_operation(self) -> None:
+        """Test create command execution."""
+        middleware = StateClaudeTextEditorMiddleware()
+
+        state: AnthropicToolsState = {"messages": []}
+
+        args = {"command": "create", "path": "/test.txt", "file_text": "line1\nline2"}
+        result = middleware._handle_create(args, state, "test_id")
+
+        assert isinstance(result, Command)
+        assert result.update is not None
+        files = result.update.get("text_editor_files", {})
+        assert "/test.txt" in files
+        assert files["/test.txt"]["content"] == ["line1", "line2"]
+
+    def test_path_prefix_enforcement(self) -> None:
+        """Test that path prefixes are enforced."""
+        middleware = StateClaudeTextEditorMiddleware(allowed_path_prefixes=["/workspace"])
+
+        state: AnthropicToolsState = {"messages": []}
+
+        # Should fail with /etc/passwd
+        args = {"command": "create", "path": "/etc/passwd", "file_text": "test"}
+
+        try:  # NOTE(review): prefer pytest.raises — `assert False` is stripped under -O
+            middleware._handle_create(args, state, "test_id")
+            assert False, "Should have raised ValueError"
+        except ValueError as e:
+            assert "Path must start with" in str(e)
+
+    def test_memories_prefix_enforcement(self) -> None:
+        """Test that /memories prefix is enforced for memory middleware."""
+        middleware = StateClaudeMemoryMiddleware()
+
+        state: AnthropicToolsState = {"messages": []}
+
+        # Should fail with /other/path
+        args = {"command": "create", "path": "/other/path.txt", "file_text": "test"}
+
+        try:  # NOTE(review): prefer pytest.raises — `assert False` is stripped under -O
+            middleware._handle_create(args, state, "test_id")
+            assert False, "Should have raised ValueError"
+        except ValueError as e:
+            assert "/memories" in str(e)
+
+    def test_str_replace_operation(self) -> None:
+        """Test str_replace command execution."""
+        middleware = StateClaudeTextEditorMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/test.txt": {
+                    "content": ["Hello world", "Goodbye world"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                }
+            },
+        }
+
+        args = {
+            "command": "str_replace",
+            "path": "/test.txt",
+            "old_str": "world",
+            "new_str": "universe",
+        }
+        result = middleware._handle_str_replace(args, state, "test_id")
+
+        assert isinstance(result, Command)
+        files = result.update.get("text_editor_files", {})
+        # Should only replace first occurrence
+        assert files["/test.txt"]["content"] == ["Hello universe", "Goodbye world"]
+
+    def test_insert_operation(self) -> None:
+        """Test insert command execution."""
+        middleware = StateClaudeTextEditorMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/test.txt": {
+                    "content": ["line1", "line2"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                }
+            },
+        }
+
+        args = {
+            "command": "insert",
+            "path": "/test.txt",
+            "insert_line": 0,
+            "new_str": "inserted",
+        }
+        result = middleware._handle_insert(args, state, "test_id")
+
+        assert isinstance(result, Command)
+        files = result.update.get("text_editor_files", {})
+        assert files["/test.txt"]["content"] == ["inserted", "line1", "line2"]
+
+    def test_delete_operation(self) -> None:
+        """Test delete command execution (memory only)."""
+        middleware = StateClaudeMemoryMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "memory_files": {
+                "/memories/test.txt": {
+                    "content": ["line1"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                }
+            },
+        }
+
+        args = {"command": "delete", "path": "/memories/test.txt"}
+        result = middleware._handle_delete(args, state, "test_id")
+
+        assert isinstance(result, Command)
+        files = result.update.get("memory_files", {})
+        # Deleted files are marked as None in state
+        assert files.get("/memories/test.txt") is None
+
+    def test_rename_operation(self) -> None:
+        """Test rename command execution (memory only)."""
+        middleware = StateClaudeMemoryMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "memory_files": {
+                "/memories/old.txt": {
+                    "content": ["line1"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                }
+            },
+        }
+
+        args = {
+            "command": "rename",
+            "old_path": "/memories/old.txt",
+            "new_path": "/memories/new.txt",
+        }
+        result = middleware._handle_rename(args, state, "test_id")
+
+        assert isinstance(result, Command)
+        files = result.update.get("memory_files", {})
+        # Old path is marked as None (deleted)
+        assert files.get("/memories/old.txt") is None
+        # New path has the file data
+        assert files.get("/memories/new.txt") is not None
+        assert files["/memories/new.txt"]["content"] == ["line1"]
diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_file_search.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_file_search.py
new file mode 100644
index 0000000000000..7c567730ae755
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_file_search.py
@@ -0,0 +1,530 @@
+"""Unit tests for file search middleware."""
+
+from pathlib import Path
+from typing import Any
+
+import pytest
+from langchain.agents.middleware.anthropic_tools import AnthropicToolsState
+from langchain.agents.middleware.file_search import (
+    FilesystemFileSearchMiddleware,
+    StateFileSearchMiddleware,
+)
+from langchain_core.messages import ToolMessage
+
+
+class TestSearchMiddlewareInitialization:
+    """Test search middleware initialization."""
+
+    def test_middleware_initialization(self) -> None:
+        """Test middleware initializes correctly."""
+        middleware = StateFileSearchMiddleware()
+        assert middleware.state_schema == AnthropicToolsState
+        assert middleware.state_key == "text_editor_files"
+
+    def test_custom_state_key(self) -> None:
+        """Test middleware with custom state key."""
+        middleware = StateFileSearchMiddleware(state_key="memory_files")
+        assert middleware.state_key == "memory_files"
+
+
+class TestGlobSearch:
+    """Test Glob file pattern matching."""
+
+    def test_glob_basic_pattern(self) -> None:
+        """Test basic glob pattern matching."""
+        middleware = StateFileSearchMiddleware()
+
+        test_state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["print('hello')"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/src/utils.py": {
+                    "content": ["def helper(): pass"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/README.md": {
+                    "content": ["# Project"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        # Call tool function directly (state is injected in real usage)
+        result = middleware.glob_search.func(pattern="*.py", state=test_state)
+
+        assert isinstance(result, str)
+        assert "/src/main.py" in result
+        assert "/src/utils.py" in result
+        assert "/README.md" not in result
+
+    def test_glob_recursive_pattern(self) -> None:
+        """Test recursive glob pattern matching."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/src/utils/helper.py": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/tests/test_main.py": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.glob_search.func(pattern="**/*.py", state=state)
+
+        assert isinstance(result, str)
+        lines = result.split("\n")
+        assert len(lines) == 3
+        assert all(".py" in line for line in lines)
+
+    def test_glob_with_base_path(self) -> None:
+        """Test glob with base path restriction."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/tests/test.py": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.glob_search.func(pattern="**/*.py", path="/src", state=state)
+
+        assert isinstance(result, str)
+        assert "/src/main.py" in result
+        assert "/tests/test.py" not in result
+
+    def test_glob_no_matches(self) -> None:
+        """Test glob with no matching files."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.glob_search.func(pattern="*.ts", state=state)
+
+        assert isinstance(result, str)
+        assert result == "No files found"
+
+    def test_glob_sorts_by_modified_time(self) -> None:
+        """Test that glob results are sorted by modification time."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/old.py": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/new.py": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-02T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.glob_search.func(pattern="*.py", state=state)
+
+        lines = result.split("\n")
+        # Most recent first
+        assert lines[0] == "/new.py"
+        assert lines[1] == "/old.py"
+
+
+class TestGrepSearch:
+    """Test Grep content search."""
+
+    def test_grep_files_with_matches_mode(self) -> None:
+        """Test grep with files_with_matches output mode."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["def foo():", "    pass"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/src/utils.py": {
+                    "content": ["def bar():", "    return None"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/README.md": {
+                    "content": ["# Documentation", "No code here"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state)
+
+        assert isinstance(result, str)
+        assert "/src/main.py" in result
+        assert "/src/utils.py" in result
+        assert "/README.md" not in result
+        # Should only have file paths, not line content
+
+    def test_grep_invalid_include_pattern(self) -> None:
+        """Return error when include glob is invalid."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["def foo():"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                }
+            },
+        }
+
+        result = middleware.grep_search.func(pattern=r"def", include="*.{py", state=state)
+
+        assert result == "Invalid include pattern"
+
+
+class TestFilesystemGrepSearch:
+    """Tests for filesystem-backed grep search. NOTE(review): several State-backed grep tests below appear misfiled in this class — confirm intended grouping."""
+
+    def test_grep_invalid_include_pattern(self, tmp_path: Path) -> None:
+        """Return error when include glob cannot be parsed."""
+
+        (tmp_path / "example.py").write_text("print('hello')\n", encoding="utf-8")
+
+        middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
+
+        result = middleware.grep_search.func(pattern="print", include="*.{py")
+
+        assert result == "Invalid include pattern"
+
+    def test_ripgrep_command_uses_literal_pattern(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """Ensure ripgrep receives pattern after ``--`` to avoid option parsing."""
+
+        (tmp_path / "example.py").write_text("print('hello')\n", encoding="utf-8")
+
+        middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=True)
+
+        captured: dict[str, list[str]] = {}
+
+        class DummyResult:
+            stdout = ""
+
+        def fake_run(*args: Any, **kwargs: Any) -> DummyResult:
+            cmd = args[0]
+            captured["cmd"] = cmd
+            return DummyResult()
+
+        monkeypatch.setattr("langchain.agents.middleware.file_search.subprocess.run", fake_run)
+
+        middleware._ripgrep_search("--pattern", "/", None)
+
+        assert "cmd" in captured
+        cmd = captured["cmd"]
+        assert cmd[:2] == ["rg", "--json"]
+        assert "--" in cmd
+        separator_index = cmd.index("--")
+        assert cmd[separator_index + 1] == "--pattern"
+
+    def test_grep_content_mode(self) -> None:
+        """Test grep with content output mode."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["def foo():", "    pass", "def bar():"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.grep_search.func(
+            pattern=r"def \w+\(\):", output_mode="content", state=state
+        )
+
+        assert isinstance(result, str)
+        lines = result.split("\n")
+        assert len(lines) == 2
+        assert lines[0] == "/src/main.py:1:def foo():"
+        assert lines[1] == "/src/main.py:3:def bar():"
+
+    def test_grep_count_mode(self) -> None:
+        """Test grep with count output mode."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["TODO: fix this", "print('hello')", "TODO: add tests"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/src/utils.py": {
+                    "content": ["TODO: implement"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.grep_search.func(pattern=r"TODO", output_mode="count", state=state)
+
+        assert isinstance(result, str)
+        lines = result.split("\n")
+        assert "/src/main.py:2" in lines
+        assert "/src/utils.py:1" in lines
+
+    def test_grep_with_include_filter(self) -> None:
+        """Test grep with include file pattern filter."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["import os"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/src/main.ts": {
+                    "content": ["import os from 'os'"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.grep_search.func(pattern="import", include="*.py", state=state)
+
+        assert isinstance(result, str)
+        assert "/src/main.py" in result
+        assert "/src/main.ts" not in result
+
+    def test_grep_with_brace_expansion_filter(self) -> None:
+        """Test grep with brace expansion in include filter."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.ts": {
+                    "content": ["const x = 1"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/src/App.tsx": {
+                    "content": ["const y = 2"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/src/main.py": {
+                    "content": ["z = 3"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.grep_search.func(pattern="const", include="*.{ts,tsx}", state=state)
+
+        assert isinstance(result, str)
+        assert "/src/main.ts" in result
+        assert "/src/App.tsx" in result
+        assert "/src/main.py" not in result
+
+    def test_grep_with_base_path(self) -> None:
+        """Test grep with base path restriction."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["import foo"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+                "/tests/test.py": {
+                    "content": ["import foo"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.grep_search.func(pattern="import", path="/src", state=state)
+
+        assert isinstance(result, str)
+        assert "/src/main.py" in result
+        assert "/tests/test.py" not in result
+
+    def test_grep_no_matches(self) -> None:
+        """Test grep with no matching content."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["print('hello')"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.grep_search.func(pattern=r"TODO", state=state)
+
+        assert isinstance(result, str)
+        assert result == "No matches found"
+
+    def test_grep_invalid_regex(self) -> None:
+        """Test grep with invalid regex pattern."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {},
+        }
+
+        result = middleware.grep_search.func(pattern=r"[unclosed", state=state)
+
+        assert isinstance(result, str)
+        assert "Invalid regex pattern" in result
+
+
+class TestSearchWithDifferentBackends:
+    """Test searching with different backend configurations."""
+
+    def test_glob_default_backend(self) -> None:
+        """Test that glob searches the default backend (text_editor_files)."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+            "memory_files": {
+                "/memories/notes.txt": {
+                    "content": [],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.glob_search.func(pattern="**/*", state=state)
+
+        assert isinstance(result, str)
+        assert "/src/main.py" in result
+        # Should NOT find memory_files since default backend is text_editor_files
+        assert "/memories/notes.txt" not in result
+
+    def test_grep_default_backend(self) -> None:
+        """Test that grep searches the default backend (text_editor_files)."""
+        middleware = StateFileSearchMiddleware()
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["TODO: implement"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+            "memory_files": {
+                "/memories/tasks.txt": {
+                    "content": ["TODO: review"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.grep_search.func(pattern=r"TODO", state=state)
+
+        assert isinstance(result, str)
+        assert "/src/main.py" in result
+        # Should NOT find memory_files since default backend is text_editor_files
+        assert "/memories/tasks.txt" not in result
+
+    def test_search_with_single_store(self) -> None:
+        """Test searching with a specific state key."""
+        middleware = StateFileSearchMiddleware(state_key="text_editor_files")
+
+        state: AnthropicToolsState = {
+            "messages": [],
+            "text_editor_files": {
+                "/src/main.py": {
+                    "content": ["code"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+            "memory_files": {
+                "/memories/notes.txt": {
+                    "content": ["notes"],
+                    "created_at": "2025-01-01T00:00:00",
+                    "modified_at": "2025-01-01T00:00:00",
+                },
+            },
+        }
+
+        result = middleware.grep_search.func(pattern=r".*", state=state)
+
+        assert isinstance(result, str)
+        assert "/src/main.py" in result
+        assert "/memories/notes.txt" not in result