2121from langchain .agents .middleware .types import AgentMiddleware
2222
2323
24+ def _expand_include_patterns (pattern : str ) -> list [str ] | None :
25+ """Expand brace patterns like ``*.{py,pyi}`` into a list of globs."""
26+ if "}" in pattern and "{" not in pattern :
27+ return None
28+
29+ expanded : list [str ] = []
30+
31+ def _expand (current : str ) -> None :
32+ start = current .find ("{" )
33+ if start == - 1 :
34+ expanded .append (current )
35+ return
36+
37+ end = current .find ("}" , start )
38+ if end == - 1 :
39+ raise ValueError
40+
41+ prefix = current [:start ]
42+ suffix = current [end + 1 :]
43+ inner = current [start + 1 : end ]
44+ if not inner :
45+ raise ValueError
46+
47+ for option in inner .split ("," ):
48+ _expand (prefix + option + suffix )
49+
50+ try :
51+ _expand (pattern )
52+ except ValueError :
53+ return None
54+
55+ return expanded
56+
57+
def _is_valid_include_pattern(pattern: str) -> bool:
    """Validate glob pattern used for include filters.

    Args:
        pattern: The include glob supplied by the caller.

    Returns:
        ``False`` when the pattern is empty, contains a NUL, line-feed or
        carriage-return character, cannot be brace-expanded by
        ``_expand_include_patterns``, or expands to a glob whose
        ``fnmatch.translate`` output fails to compile as a regular
        expression; ``True`` otherwise.
    """
    if not pattern:
        return False

    # Compare against the bare control characters (no padding) so that any
    # NUL, LF or CR anywhere in the pattern is rejected.
    if any(char in pattern for char in ("\x00", "\n", "\r")):
        return False

    expanded = _expand_include_patterns(pattern)
    if expanded is None:
        return False

    try:
        for candidate in expanded:
            # Compile only to validate; the compiled regex is discarded.
            re.compile(fnmatch.translate(candidate))
    except re.error:
        return False

    return True
77+
78+
def _match_include_pattern(basename: str, pattern: str) -> bool:
    """Report whether ``basename`` matches the include ``pattern``.

    The pattern is brace-expanded with ``_expand_include_patterns`` and the
    basename is tested against each resulting glob with ``fnmatch.fnmatch``.

    Args:
        basename: File name (without directory) to test.
        pattern: Include glob, optionally containing brace groups.

    Returns:
        ``True`` if any expanded glob matches; ``False`` if none match or
        the expansion produced nothing (``None`` or an empty list).
    """
    candidates = _expand_include_patterns(pattern)
    # A failed or empty expansion can never match anything.
    if candidates:
        for glob in candidates:
            if fnmatch.fnmatch(basename, glob):
                return True
    return False
86+
87+
2488class StateFileSearchMiddleware (AgentMiddleware ):
2589 """Provides Glob and Grep search over state-based files.
2690
@@ -159,6 +223,9 @@ def grep_search( # noqa: D417
159223 except re .error as e :
160224 return f"Invalid regex pattern: { e } "
161225
226+ if include and not _is_valid_include_pattern (include ):
227+ return "Invalid include pattern"
228+
162229 # Search files
163230 files = cast ("dict[str, Any]" , state .get (self .state_key , {}))
164231 results : dict [str , list [tuple [int , str ]]] = {}
@@ -170,7 +237,7 @@ def grep_search( # noqa: D417
170237 # Check include filter
171238 if include :
172239 basename = Path (file_path ).name
173- if not self . _match_include (basename , include ):
240+ if not _match_include_pattern (basename , include ):
174241 continue
175242
176243 # Search file content
@@ -190,23 +257,6 @@ def grep_search( # noqa: D417
190257 self .grep_search = grep_search
191258 self .tools = [glob_search , grep_search ]
192259
193- def _match_include (self , basename : str , pattern : str ) -> bool :
194- """Match filename against include pattern."""
195- # Handle brace expansion {a,b,c}
196- if "{" in pattern and "}" in pattern :
197- start = pattern .index ("{" )
198- end = pattern .index ("}" )
199- prefix = pattern [:start ]
200- suffix = pattern [end + 1 :]
201- alternatives = pattern [start + 1 : end ].split ("," )
202-
203- for alt in alternatives :
204- expanded = prefix + alt + suffix
205- if fnmatch .fnmatch (basename , expanded ):
206- return True
207- return False
208- return fnmatch .fnmatch (basename , pattern )
209-
210260 def _format_grep_results (
211261 self ,
212262 results : dict [str , list [tuple [int , str ]]],
@@ -355,6 +405,9 @@ def grep_search(
355405 except re .error as e :
356406 return f"Invalid regex pattern: { e } "
357407
408+ if include and not _is_valid_include_pattern (include ):
409+ return "Invalid include pattern"
410+
358411 # Try ripgrep first if enabled
359412 results = None
360413 if self .use_ripgrep :
@@ -416,12 +469,14 @@ def _ripgrep_search(
416469 return {}
417470
418471 # Build ripgrep command
419- cmd = ["rg" , "--json" , pattern , str ( base_full ) ]
472+ cmd = ["rg" , "--json" ]
420473
421474 if include :
422475 # Convert glob pattern to ripgrep glob
423476 cmd .extend (["--glob" , include ])
424477
478+ cmd .extend (["--" , pattern , str (base_full )])
479+
425480 try :
426481 result = subprocess .run ( # noqa: S603
427482 cmd ,
@@ -475,7 +530,7 @@ def _python_search(
475530 continue
476531
477532 # Check include filter
478- if include and not self . _match_include (file_path .name , include ):
533+ if include and not _match_include_pattern (file_path .name , include ):
479534 continue
480535
481536 # Skip files that are too large
@@ -497,23 +552,6 @@ def _python_search(
497552
498553 return results
499554
500- def _match_include (self , basename : str , pattern : str ) -> bool :
501- """Match filename against include pattern."""
502- # Handle brace expansion {a,b,c}
503- if "{" in pattern and "}" in pattern :
504- start = pattern .index ("{" )
505- end = pattern .index ("}" )
506- prefix = pattern [:start ]
507- suffix = pattern [end + 1 :]
508- alternatives = pattern [start + 1 : end ].split ("," )
509-
510- for alt in alternatives :
511- expanded = prefix + alt + suffix
512- if fnmatch .fnmatch (basename , expanded ):
513- return True
514- return False
515- return fnmatch .fnmatch (basename , pattern )
516-
517555 def _format_grep_results (
518556 self ,
519557 results : dict [str , list [tuple [int , str ]]],
0 commit comments