Skip to content

Commit d9466f0

Browse files
committed
Add functools.lru_cache plugin support
- Add lru_cache callback to functools plugin for type validation - Register callbacks in default plugin for decorator and wrapper calls - Support different lru_cache patterns: @lru_cache, @lru_cache(), @lru_cache(maxsize=N) Fixes issue #16261
1 parent b69309b commit d9466f0

File tree

4 files changed

+419
-0
lines changed

4 files changed

+419
-0
lines changed

mypy/plugins/default.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@
5050
)
5151
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
5252
from mypy.plugins.functools import (
53+
functools_lru_cache_callback,
5354
functools_total_ordering_maker_callback,
5455
functools_total_ordering_makers,
56+
lru_cache_wrapper_call_callback,
5557
partial_call_callback,
5658
partial_new_callback,
5759
)
@@ -103,6 +105,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
103105
return create_singledispatch_function_callback
104106
elif fullname == "functools.partial":
105107
return partial_new_callback
108+
elif fullname == "functools.lru_cache":
109+
return functools_lru_cache_callback
106110
elif fullname == "enum.member":
107111
return enum_member_callback
108112
return None
@@ -162,6 +166,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
162166
return call_singledispatch_function_after_register_argument
163167
elif fullname == "functools.partial.__call__":
164168
return partial_call_callback
169+
elif fullname == "functools._lru_cache_wrapper.__call__":
170+
return lru_cache_wrapper_call_callback
165171
return None
166172

167173
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:

mypy/plugins/functools.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
ArgKind,
1717
Argument,
1818
CallExpr,
19+
Decorator,
20+
MemberExpr,
1921
NameExpr,
2022
Var,
2123
)
@@ -25,6 +27,8 @@
2527
AnyType,
2628
CallableType,
2729
Instance,
30+
LiteralType,
31+
NoneType,
2832
Overloaded,
2933
ParamSpecFlavor,
3034
ParamSpecType,
@@ -41,6 +45,7 @@
4145
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
4246

4347
PARTIAL: Final = "functools.partial"
48+
LRU_CACHE: Final = "functools.lru_cache"
4449

4550

4651
class _MethodInfo(NamedTuple):
@@ -393,3 +398,156 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
393398
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
394399

395400
return result
401+
402+
403+
def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type:
404+
"""Infer a more precise return type for functools.lru_cache decorator"""
405+
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
406+
return ctx.default_return_type
407+
408+
# Only handle the very specific case: @lru_cache (without parentheses)
409+
# where a single function is passed directly as the only argument
410+
if (
411+
len(ctx.arg_types) == 1
412+
and len(ctx.arg_types[0]) == 1
413+
and len(ctx.args) == 1
414+
and len(ctx.args[0]) == 1
415+
):
416+
417+
first_arg_type = ctx.arg_types[0][0]
418+
419+
proper_first_arg_type = get_proper_type(first_arg_type)
420+
if isinstance(proper_first_arg_type, (LiteralType, Instance, NoneType)):
421+
return ctx.default_return_type
422+
423+
# Try to extract callable type
424+
fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type)
425+
if fn_type is not None:
426+
# This is the @lru_cache case (function passed directly)
427+
return fn_type
428+
429+
# For all other cases (parameterized, multiple args, etc.), don't interfere
430+
return ctx.default_return_type
431+
432+
433+
def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
434+
"""Handle calls to functools._lru_cache_wrapper objects to provide parameter validation"""
435+
if not isinstance(ctx.api, mypy.checker.TypeChecker):
436+
return ctx.default_return_type
437+
438+
# Safety check: ensure we have the required context
439+
if not ctx.context or not ctx.args or not ctx.arg_types:
440+
return ctx.default_return_type
441+
442+
# Try to find the original function signature using AST/symbol table analysis
443+
original_signature = _find_original_function_signature(ctx)
444+
445+
if original_signature is not None:
446+
# Validate the call against the original function signature
447+
actual_args = []
448+
actual_arg_kinds = []
449+
actual_arg_names = []
450+
seen_args = set()
451+
452+
for i, param in enumerate(ctx.args):
453+
for j, a in enumerate(param):
454+
if a in seen_args:
455+
continue
456+
seen_args.add(a)
457+
actual_args.append(a)
458+
actual_arg_kinds.append(ctx.arg_kinds[i][j])
459+
actual_arg_names.append(ctx.arg_names[i][j])
460+
461+
# Check the call against the original signature
462+
result, _ = ctx.api.expr_checker.check_call(
463+
callee=original_signature,
464+
args=actual_args,
465+
arg_kinds=actual_arg_kinds,
466+
arg_names=actual_arg_names,
467+
context=ctx.context,
468+
)
469+
return result
470+
471+
return ctx.default_return_type
472+
473+
474+
def _get_callable_from_decorator(decorator_node: Decorator) -> CallableType | None:
475+
"""Extract the CallableType from a Decorator node if available."""
476+
func_def = decorator_node.func
477+
if isinstance(func_def.type, CallableType):
478+
return func_def.type
479+
return None
480+
481+
482+
def _bind_method(method_type: CallableType, decorator_node: Decorator) -> CallableType:
483+
"""
484+
Bind a method by removing the self parameter for instance methods.
485+
486+
Static and class methods are returned unchanged.
487+
"""
488+
func_def = decorator_node.func
489+
490+
# For instance methods, bind self by removing the first parameter
491+
if not func_def.is_static and not func_def.is_class and method_type.arg_types:
492+
return method_type.copy_modified(
493+
arg_types=method_type.arg_types[1:],
494+
arg_kinds=method_type.arg_kinds[1:],
495+
arg_names=method_type.arg_names[1:],
496+
)
497+
return method_type
498+
499+
500+
def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None:
501+
"""
502+
Find the original function signature from an lru_cache decorated function call.
503+
504+
Returns the CallableType of the original function if found, None otherwise.
505+
"""
506+
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
507+
return None
508+
509+
if not isinstance(ctx.context, CallExpr):
510+
return None
511+
512+
callee = ctx.context.callee
513+
514+
# Handle method calls (obj.method() or Class.method())
515+
if isinstance(callee, MemberExpr):
516+
method_name = callee.name
517+
if not method_name:
518+
return None
519+
520+
# Get the type of the object or class being accessed
521+
member_type = ctx.api.expr_checker.accept(callee.expr)
522+
proper_type = get_proper_type(member_type)
523+
524+
if not isinstance(proper_type, Instance):
525+
return None
526+
527+
# Look up the method in the class
528+
class_info = proper_type.type
529+
if method_name not in class_info.names:
530+
return None
531+
532+
symbol = class_info.names[method_name]
533+
if not isinstance(symbol.node, Decorator):
534+
return None
535+
536+
method_type = _get_callable_from_decorator(symbol.node)
537+
if method_type is None:
538+
return None
539+
540+
return _bind_method(method_type, symbol.node)
541+
542+
# Handle module-level function calls
543+
if isinstance(callee, NameExpr) and callee.name:
544+
if callee.name not in ctx.api.globals:
545+
return None
546+
547+
symbol = ctx.api.globals[callee.name]
548+
if not isinstance(symbol.node, Decorator):
549+
return None
550+
551+
return _get_callable_from_decorator(symbol.node)
552+
553+
return None

0 commit comments

Comments
 (0)