Skip to content

Commit 4d48faa

Browse files
committed
Guard against malformed or malicious include patterns
1 parent 672b8ec commit 4d48faa

File tree

3 files changed

+157
-45
lines changed

3 files changed

+157
-45
lines changed

libs/langchain_v1/langchain/agents/middleware/anthropic_tools.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,16 @@
1111
from pathlib import Path
1212
from typing import TYPE_CHECKING, Annotated, Any, cast
1313

14-
from langchain_core.messages import AIMessage, ToolMessage
14+
from langchain_core.messages import ToolMessage
1515
from langgraph.types import Command
1616
from typing_extensions import NotRequired, TypedDict
1717

18-
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
18+
from langchain.agents.middleware.types import (
19+
AgentMiddleware,
20+
AgentState,
21+
ModelRequest,
22+
ModelResponse,
23+
)
1924

2025
if TYPE_CHECKING:
2126
from collections.abc import Callable, Sequence
@@ -182,8 +187,8 @@ def __init__(
182187
def wrap_model_call(
183188
self,
184189
request: ModelRequest,
185-
handler: Callable[[ModelRequest], AIMessage],
186-
) -> AIMessage:
190+
handler: Callable[[ModelRequest], ModelResponse],
191+
) -> ModelResponse:
187192
"""Inject tool and optional system prompt."""
188193
# Add tool
189194
tools = list(request.tools or [])
@@ -610,8 +615,8 @@ def __init__(
610615
def wrap_model_call(
611616
self,
612617
request: ModelRequest,
613-
handler: Callable[[ModelRequest], AIMessage],
614-
) -> AIMessage:
618+
handler: Callable[[ModelRequest], ModelResponse],
619+
) -> ModelResponse:
615620
"""Inject tool and optional system prompt."""
616621
# Add tool
617622
tools = list(request.tools or [])

libs/langchain_v1/langchain/agents/middleware/file_search.py

Lines changed: 75 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,70 @@
2121
from langchain.agents.middleware.types import AgentMiddleware
2222

2323

24+
def _expand_include_patterns(pattern: str) -> list[str] | None:
25+
"""Expand brace patterns like ``*.{py,pyi}`` into a list of globs."""
26+
if "}" in pattern and "{" not in pattern:
27+
return None
28+
29+
expanded: list[str] = []
30+
31+
def _expand(current: str) -> None:
32+
start = current.find("{")
33+
if start == -1:
34+
expanded.append(current)
35+
return
36+
37+
end = current.find("}", start)
38+
if end == -1:
39+
raise ValueError
40+
41+
prefix = current[:start]
42+
suffix = current[end + 1 :]
43+
inner = current[start + 1 : end]
44+
if not inner:
45+
raise ValueError
46+
47+
for option in inner.split(","):
48+
_expand(prefix + option + suffix)
49+
50+
try:
51+
_expand(pattern)
52+
except ValueError:
53+
return None
54+
55+
return expanded
56+
57+
58+
def _is_valid_include_pattern(pattern: str) -> bool:
59+
"""Validate glob pattern used for include filters."""
60+
if not pattern:
61+
return False
62+
63+
if any(char in pattern for char in ("\x00", "\n", "\r")):
64+
return False
65+
66+
expanded = _expand_include_patterns(pattern)
67+
if expanded is None:
68+
return False
69+
70+
try:
71+
for candidate in expanded:
72+
re.compile(fnmatch.translate(candidate))
73+
except re.error:
74+
return False
75+
76+
return True
77+
78+
79+
def _match_include_pattern(basename: str, pattern: str) -> bool:
80+
"""Return True if the basename matches the include pattern."""
81+
expanded = _expand_include_patterns(pattern)
82+
if not expanded:
83+
return False
84+
85+
return any(fnmatch.fnmatch(basename, candidate) for candidate in expanded)
86+
87+
2488
class StateFileSearchMiddleware(AgentMiddleware):
2589
"""Provides Glob and Grep search over state-based files.
2690
@@ -159,6 +223,9 @@ def grep_search( # noqa: D417
159223
except re.error as e:
160224
return f"Invalid regex pattern: {e}"
161225

226+
if include and not _is_valid_include_pattern(include):
227+
return "Invalid include pattern"
228+
162229
# Search files
163230
files = cast("dict[str, Any]", state.get(self.state_key, {}))
164231
results: dict[str, list[tuple[int, str]]] = {}
@@ -170,7 +237,7 @@ def grep_search( # noqa: D417
170237
# Check include filter
171238
if include:
172239
basename = Path(file_path).name
173-
if not self._match_include(basename, include):
240+
if not _match_include_pattern(basename, include):
174241
continue
175242

176243
# Search file content
@@ -190,23 +257,6 @@ def grep_search( # noqa: D417
190257
self.grep_search = grep_search
191258
self.tools = [glob_search, grep_search]
192259

193-
def _match_include(self, basename: str, pattern: str) -> bool:
194-
"""Match filename against include pattern."""
195-
# Handle brace expansion {a,b,c}
196-
if "{" in pattern and "}" in pattern:
197-
start = pattern.index("{")
198-
end = pattern.index("}")
199-
prefix = pattern[:start]
200-
suffix = pattern[end + 1 :]
201-
alternatives = pattern[start + 1 : end].split(",")
202-
203-
for alt in alternatives:
204-
expanded = prefix + alt + suffix
205-
if fnmatch.fnmatch(basename, expanded):
206-
return True
207-
return False
208-
return fnmatch.fnmatch(basename, pattern)
209-
210260
def _format_grep_results(
211261
self,
212262
results: dict[str, list[tuple[int, str]]],
@@ -355,6 +405,9 @@ def grep_search(
355405
except re.error as e:
356406
return f"Invalid regex pattern: {e}"
357407

408+
if include and not _is_valid_include_pattern(include):
409+
return "Invalid include pattern"
410+
358411
# Try ripgrep first if enabled
359412
results = None
360413
if self.use_ripgrep:
@@ -416,12 +469,14 @@ def _ripgrep_search(
416469
return {}
417470

418471
# Build ripgrep command
419-
cmd = ["rg", "--json", pattern, str(base_full)]
472+
cmd = ["rg", "--json"]
420473

421474
if include:
422475
# Convert glob pattern to ripgrep glob
423476
cmd.extend(["--glob", include])
424477

478+
cmd.extend(["--", pattern, str(base_full)])
479+
425480
try:
426481
result = subprocess.run( # noqa: S603
427482
cmd,
@@ -475,7 +530,7 @@ def _python_search(
475530
continue
476531

477532
# Check include filter
478-
if include and not self._match_include(file_path.name, include):
533+
if include and not _match_include_pattern(file_path.name, include):
479534
continue
480535

481536
# Skip files that are too large
@@ -497,23 +552,6 @@ def _python_search(
497552

498553
return results
499554

500-
def _match_include(self, basename: str, pattern: str) -> bool:
501-
"""Match filename against include pattern."""
502-
# Handle brace expansion {a,b,c}
503-
if "{" in pattern and "}" in pattern:
504-
start = pattern.index("{")
505-
end = pattern.index("}")
506-
prefix = pattern[:start]
507-
suffix = pattern[end + 1 :]
508-
alternatives = pattern[start + 1 : end].split(",")
509-
510-
for alt in alternatives:
511-
expanded = prefix + alt + suffix
512-
if fnmatch.fnmatch(basename, expanded):
513-
return True
514-
return False
515-
return fnmatch.fnmatch(basename, pattern)
516-
517555
def _format_grep_results(
518556
self,
519557
results: dict[str, list[tuple[int, str]]],

libs/langchain_v1/tests/unit_tests/agents/middleware/test_file_search.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
"""Unit tests for file search middleware."""
22

3+
from pathlib import Path
4+
from typing import Any
5+
36
import pytest
47
from langchain.agents.middleware.anthropic_tools import AnthropicToolsState
5-
from langchain.agents.middleware.file_search import StateFileSearchMiddleware
8+
from langchain.agents.middleware.file_search import (
9+
FilesystemFileSearchMiddleware,
10+
StateFileSearchMiddleware,
11+
)
612
from langchain_core.messages import ToolMessage
713

814

@@ -198,7 +204,70 @@ def test_grep_files_with_matches_mode(self) -> None:
198204
assert "/src/utils.py" in result
199205
assert "/README.md" not in result
200206
# Should only have file paths, not line content
201-
assert "def foo():" not in result
207+
208+
def test_grep_invalid_include_pattern(self) -> None:
209+
"""Return error when include glob is invalid."""
210+
middleware = StateFileSearchMiddleware()
211+
212+
state: AnthropicToolsState = {
213+
"messages": [],
214+
"text_editor_files": {
215+
"/src/main.py": {
216+
"content": ["def foo():"],
217+
"created_at": "2025-01-01T00:00:00",
218+
"modified_at": "2025-01-01T00:00:00",
219+
}
220+
},
221+
}
222+
223+
result = middleware.grep_search.func(pattern=r"def", include="*.{py", state=state)
224+
225+
assert result == "Invalid include pattern"
226+
227+
228+
class TestFilesystemGrepSearch:
229+
"""Tests for filesystem-backed grep search."""
230+
231+
def test_grep_invalid_include_pattern(self, tmp_path: Path) -> None:
232+
"""Return error when include glob cannot be parsed."""
233+
234+
(tmp_path / "example.py").write_text("print('hello')\n", encoding="utf-8")
235+
236+
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
237+
238+
result = middleware.grep_search.func(pattern="print", include="*.{py")
239+
240+
assert result == "Invalid include pattern"
241+
242+
def test_ripgrep_command_uses_literal_pattern(
243+
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
244+
) -> None:
245+
"""Ensure ripgrep receives pattern after ``--`` to avoid option parsing."""
246+
247+
(tmp_path / "example.py").write_text("print('hello')\n", encoding="utf-8")
248+
249+
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=True)
250+
251+
captured: dict[str, list[str]] = {}
252+
253+
class DummyResult:
254+
stdout = ""
255+
256+
def fake_run(*args: Any, **kwargs: Any) -> DummyResult:
257+
cmd = args[0]
258+
captured["cmd"] = cmd
259+
return DummyResult()
260+
261+
monkeypatch.setattr("langchain.agents.middleware.file_search.subprocess.run", fake_run)
262+
263+
middleware._ripgrep_search("--pattern", "/", None)
264+
265+
assert "cmd" in captured
266+
cmd = captured["cmd"]
267+
assert cmd[:2] == ["rg", "--json"]
268+
assert "--" in cmd
269+
separator_index = cmd.index("--")
270+
assert cmd[separator_index + 1] == "--pattern"
202271

203272
def test_grep_content_mode(self) -> None:
204273
"""Test grep with content output mode."""

0 commit comments

Comments
 (0)