diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 47387530de30..d6073b86c7a8 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -50,8 +50,10 @@ ) from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback from mypy.plugins.functools import ( + functools_lru_cache_callback, functools_total_ordering_maker_callback, functools_total_ordering_makers, + lru_cache_wrapper_call_callback, partial_call_callback, partial_new_callback, ) @@ -103,6 +105,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] return create_singledispatch_function_callback elif fullname == "functools.partial": return partial_new_callback + elif fullname == "functools.lru_cache": + return functools_lru_cache_callback elif fullname == "enum.member": return enum_member_callback return None @@ -162,6 +166,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No return call_singledispatch_function_after_register_argument elif fullname == "functools.partial.__call__": return partial_call_callback + elif fullname == "functools._lru_cache_wrapper.__call__": + return lru_cache_wrapper_call_callback return None def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None: diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index c8b370f15e6d..50d4a2fd3cd1 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -16,6 +16,8 @@ ArgKind, Argument, CallExpr, + Decorator, + MemberExpr, NameExpr, Var, ) @@ -25,6 +27,8 @@ AnyType, CallableType, Instance, + LiteralType, + NoneType, Overloaded, ParamSpecFlavor, ParamSpecType, @@ -41,6 +45,7 @@ _ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"} PARTIAL: Final = "functools.partial" +LRU_CACHE: Final = "functools.lru_cache" class _MethodInfo(NamedTuple): @@ -393,3 +398,156 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names) return result + + +def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type: + """Infer a more precise return type for functools.lru_cache decorator""" + if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals + return ctx.default_return_type + + # Only handle the very specific case: @lru_cache (without parentheses) + # where a single function is passed directly as the only argument + if ( + len(ctx.arg_types) == 1 + and len(ctx.arg_types[0]) == 1 + and len(ctx.args) == 1 + and len(ctx.args[0]) == 1 + ): + + first_arg_type = ctx.arg_types[0][0] + + proper_first_arg_type = get_proper_type(first_arg_type) + if isinstance(proper_first_arg_type, (LiteralType, Instance, NoneType)): + return ctx.default_return_type + + # Try to extract callable type + fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type) + if fn_type is not None: + # This is the @lru_cache case (function passed directly) + return fn_type + + # For all other cases (parameterized, multiple args, etc.), don't interfere + return ctx.default_return_type + + +def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type: + """Handle calls to functools._lru_cache_wrapper objects to provide parameter validation""" + if not isinstance(ctx.api, mypy.checker.TypeChecker): + return ctx.default_return_type + + # Safety check: ensure we have the required context + if not ctx.context or not ctx.args or not ctx.arg_types: + return ctx.default_return_type + + # Try to find the original function signature using AST/symbol table analysis + original_signature = _find_original_function_signature(ctx) + + if original_signature is not None: + # Validate the call against the original function signature + actual_args = [] + actual_arg_kinds = [] + actual_arg_names = [] + seen_args = set() + + for i, param in enumerate(ctx.args): + for j, a in enumerate(param): + if a in seen_args: + continue + seen_args.add(a) + actual_args.append(a) + actual_arg_kinds.append(ctx.arg_kinds[i][j]) + actual_arg_names.append(ctx.arg_names[i][j]) + + # Check the call against the original signature + result, _ = ctx.api.expr_checker.check_call( + callee=original_signature, + args=actual_args, + arg_kinds=actual_arg_kinds, + arg_names=actual_arg_names, + context=ctx.context, + ) + return result + + return ctx.default_return_type + + +def _get_callable_from_decorator(decorator_node: Decorator) -> CallableType | None: + """Extract the CallableType from a Decorator node if available.""" + func_def = decorator_node.func + if isinstance(func_def.type, CallableType): + return func_def.type + return None + + +def _bind_method(method_type: CallableType, decorator_node: Decorator) -> CallableType: + """ + Bind a method by removing the self parameter for instance methods. + + Static and class methods are returned unchanged. + """ + func_def = decorator_node.func + + # For instance methods, bind self by removing the first parameter + if not func_def.is_static and not func_def.is_class and method_type.arg_types: + return method_type.copy_modified( + arg_types=method_type.arg_types[1:], + arg_kinds=method_type.arg_kinds[1:], + arg_names=method_type.arg_names[1:], + ) + return method_type + + +def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None: + """ + Find the original function signature from an lru_cache decorated function call. + + Returns the CallableType of the original function if found, None otherwise. + """ + if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals + return None + + if not isinstance(ctx.context, CallExpr): + return None + + callee = ctx.context.callee + + # Handle method calls (obj.method() or Class.method()) + if isinstance(callee, MemberExpr): + method_name = callee.name + if not method_name: + return None + + # Get the type of the object or class being accessed + member_type = ctx.api.expr_checker.accept(callee.expr) + proper_type = get_proper_type(member_type) + + if not isinstance(proper_type, Instance): + return None + + # Look up the method in the class + class_info = proper_type.type + if method_name not in class_info.names: + return None + + symbol = class_info.names[method_name] + if not isinstance(symbol.node, Decorator): + return None + + method_type = _get_callable_from_decorator(symbol.node) + if method_type is None: + return None + + return _bind_method(method_type, symbol.node) + + # Handle module-level function calls + if isinstance(callee, NameExpr) and callee.name: + if callee.name not in ctx.api.globals: + return None + + symbol = ctx.api.globals[callee.name] + if not isinstance(symbol.node, Decorator): + return None + + return _get_callable_from_decorator(symbol.node) + + return None diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test index ffd0a97b6988..7d7f89678177 100644 --- a/test-data/unit/check-functools.test +++ b/test-data/unit/check-functools.test @@ -726,3 +726,250 @@ def outer_c(arg: Tc) -> None: use_int_callable(partial(inner, b="")) # E: Argument 1 to "use_int_callable" has incompatible type "partial[str]"; expected "Callable[[int], int]" \ # N: "partial[str].__call__" has type "def __call__(__self, *args: Any, **kwargs: Any) -> str" [builtins fixtures/tuple.pyi] + +[case testLruCacheBasicValidation] +from functools import lru_cache + +@lru_cache +def f(v: str, at: int) -> str: + return v + +f() # E: Missing positional arguments "v", "at" in call to "f" +f("abc") # E: Missing positional argument "at" in call to "f" +f("abc", 123) # OK +f("abc", at=123) # OK +f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testLruCacheWithReturnType] +from functools import lru_cache + +@lru_cache +def multiply(x: int, y: int) -> int: + return 42 + +reveal_type(multiply) # N: Revealed type is "def (x: builtins.int, y: builtins.int) -> builtins.int" +reveal_type(multiply(2, 3)) # N: Revealed type is "builtins.int" +multiply("a", 3) # E: Argument 1 to "multiply" has incompatible type "str"; expected "int" +multiply(2, "b") # E: Argument 2 to "multiply" has incompatible type "str"; expected "int" +multiply(2) # E: Missing positional argument "y" in call to "multiply" +multiply(1, 2, 3) # E: Too many arguments for "multiply" +[builtins fixtures/dict.pyi] + +[case testLruCacheWithOptionalArgs] +from functools import lru_cache + +@lru_cache +def greet(name: str, greeting: str = "Hello") -> str: + return "result" + +greet("World") # OK +greet("World", "Hi") # OK +greet("World", greeting="Hi") # OK +greet() # E: Missing positional argument "name" in call to "greet" +greet(123) # E: Argument 1 to "greet" has incompatible type "int"; expected "str" +greet("World", 123) # E: Argument 2 to "greet" has incompatible type "int"; expected "str" +[builtins fixtures/dict.pyi] + +[case testLruCacheGenericFunction] +from functools import lru_cache +from typing import TypeVar + +T = TypeVar('T') + +@lru_cache +def identity(x: T) -> T: + return x + +reveal_type(identity(42)) # N: Revealed type is "builtins.int" +reveal_type(identity("hello")) # N: Revealed type is "builtins.str" +identity() # E: Missing positional argument "x" in call to "identity" +[builtins fixtures/dict.pyi] + +[case testLruCacheWithParentheses] +from functools import lru_cache + +@lru_cache() +def f(v: str, at: int) -> str: + return v + +f() # E: Missing positional arguments "v", "at" in call to "f" +f("abc") # E: Missing positional argument "at" in call to "f" +f("abc", 123) # OK +f("abc", at=123) # OK +f("abc", at="wrong_type") # E: Argument "at" to "f" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testLruCacheWithMaxsize] +from functools import lru_cache + +@lru_cache(maxsize=128) +def g(v: str, at: int) -> str: + return v + +g() # E: Missing positional arguments "v", "at" in call to "g" +g("abc") # E: Missing positional argument "at" in call to "g" +g("abc", 123) # OK +g("abc", at=123) # OK +g("abc", at="wrong_type") # E: Argument "at" to "g" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testLruCacheGenericWithParameters] +from functools import lru_cache +from typing import TypeVar + +T = TypeVar('T') + +@lru_cache() +def identity_empty(x: T) -> T: + return x + +@lru_cache(maxsize=128) +def identity_maxsize(x: T) -> T: + return x + +reveal_type(identity_empty(42)) # N: Revealed type is "builtins.int" +reveal_type(identity_maxsize("hello")) # N: Revealed type is "builtins.str" +identity_empty() # E: Missing positional argument "x" in call to "identity_empty" +identity_maxsize() # E: Missing positional argument "x" in call to "identity_maxsize" +[builtins fixtures/dict.pyi] + +[case testLruCacheMaxsizeNone] +from functools import lru_cache + +@lru_cache(maxsize=None) +def unlimited_cache(x: int, y: str) -> str: + return y + +unlimited_cache(42, "test") # OK +unlimited_cache() # E: Missing positional arguments "x", "y" in call to "unlimited_cache" +unlimited_cache(42) # E: Missing positional argument "y" in call to "unlimited_cache" +unlimited_cache("wrong", "test") # E: Argument 1 to "unlimited_cache" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testLruCacheMaxsizeZero] +from functools import lru_cache + +@lru_cache(maxsize=0) +def no_cache(value: str) -> str: + return value + +no_cache("hello") # OK +no_cache() # E: Missing positional argument "value" in call to "no_cache" +no_cache(123) # E: Argument 1 to "no_cache" has incompatible type "int"; expected "str" +[builtins fixtures/dict.pyi] + +[case testLruCacheOnInstanceMethod] +from functools import lru_cache + +class MyClass: + @lru_cache + def method(self, x: int) -> str: + return "" + +obj = MyClass() +obj.method(1) # OK +obj.method("bad") # E: Argument 1 to "method" of "MyClass" has incompatible type "str"; expected "int" +obj.method() # E: Missing positional argument "x" in call to "method" of "MyClass" +reveal_type(obj.method(42)) # N: Revealed type is "builtins.str" +[builtins fixtures/dict.pyi] + +[case testLruCacheOnClassMethod] +from functools import lru_cache + +class MyClass: + @classmethod + @lru_cache + def method(cls, x: int) -> str: + return "" + +MyClass.method(1) # OK +MyClass.method("bad") # E: Argument 1 to "method" of "MyClass" has incompatible type "str"; expected "int" +MyClass.method() # E: Missing positional argument "x" in call to "method" of "MyClass" +[builtins fixtures/classmethod.pyi] + +[case testLruCacheOnStaticMethod] +from functools import lru_cache + +class MyClass: + @staticmethod + @lru_cache + def method(x: int) -> str: + return "" + +MyClass.method(1) # OK +MyClass.method("bad") # E: Argument 1 to "method" of "MyClass" has incompatible type "str"; expected "int" +obj = MyClass() +obj.method(1) # OK +obj.method("bad") # E: Argument 1 to "method" of "MyClass" has incompatible type "str"; expected "int" +[builtins fixtures/staticmethod.pyi] + +[case testLruCacheOnInstanceMethodWithParameters] +from functools import lru_cache + +class MyClass: + @lru_cache(maxsize=128) + def compute(self, x: int, y: str) -> str: + return "" + +obj = MyClass() +obj.compute(2, "hi") # OK +obj.compute("wrong", "hi") # E: Argument 1 to "compute" of "MyClass" has incompatible type "str"; expected "int" +obj.compute(2, 3) # E: Argument 2 to "compute" of "MyClass" has incompatible type "int"; expected "str" +[builtins fixtures/dict.pyi] + +[case testLruCacheNestedFunction] +from functools import lru_cache + +def outer(x: int) -> str: + @lru_cache + def inner(y: str) -> str: + return "" # Simplified + + result1 = inner("test") # OK + result2 = inner(123) # E: Argument 1 to "inner" has incompatible type "int"; expected "str" + return result1 +[builtins fixtures/dict.pyi] + +[case testLruCacheMultipleDecorators] +from functools import lru_cache + +class MyClass: + @property + @lru_cache + def prop(self) -> int: + return 42 + +obj = MyClass() +reveal_type(obj.prop) # N: Revealed type is "builtins.int" +[builtins fixtures/property.pyi] + + +[case testLruCacheOnInstanceMethodNoMaxSize] +from functools import lru_cache + +class MyClass: + @lru_cache(maxsize=None) + def compute(self, x: int, y: str) -> str: + return "" + +obj = MyClass() +obj.compute(2, "hi") # OK +obj.compute("wrong", "hi") # E: Argument 1 to "compute" of "MyClass" has incompatible type "str"; expected "int" +obj.compute(2, 3) # E: Argument 2 to "compute" of "MyClass" has incompatible type "int"; expected "str" +[builtins fixtures/dict.pyi] + + +[case testLruCacheOnInstanceTyped] +from functools import lru_cache + +class MyClass: + @lru_cache(typed=True) + def compute(self, x: int, y: str) -> str: + return "" + +obj = MyClass() +obj.compute(2, "hi") # OK +obj.compute("wrong", "hi") # E: Argument 1 to "compute" of "MyClass" has incompatible type "str"; expected "int" +obj.compute(2, 3) # E: Argument 2 to "compute" of "MyClass" has incompatible type "int"; expected "str" +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/lib-stub/functools.pyi b/test-data/unit/lib-stub/functools.pyi index b8a15fe60b74..029fbfa708e5 100644 --- a/test-data/unit/lib-stub/functools.pyi +++ b/test-data/unit/lib-stub/functools.pyi @@ -39,3 +39,11 @@ class partial(Generic[_T]): def __call__(__self, *args: Any, **kwargs: Any) -> _T: ... def wraps(wrapped: _T) -> partial[_T]: ... + +class _lru_cache_wrapper(Generic[_T]): + def __call__(__self, *args: Any, **kwargs: Any) -> _T: ... + +@overload +def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ... +@overload +def lru_cache(__func: Callable[..., _T]) -> _lru_cache_wrapper[_T]: ...