Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@
)
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
from mypy.plugins.functools import (
functools_lru_cache_callback,
functools_total_ordering_maker_callback,
functools_total_ordering_makers,
lru_cache_wrapper_call_callback,
partial_call_callback,
partial_new_callback,
)
Expand Down Expand Up @@ -103,6 +105,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return create_singledispatch_function_callback
elif fullname == "functools.partial":
return partial_new_callback
elif fullname == "functools.lru_cache":
return functools_lru_cache_callback
elif fullname == "enum.member":
return enum_member_callback
return None
Expand Down Expand Up @@ -162,6 +166,8 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
return call_singledispatch_function_after_register_argument
elif fullname == "functools.partial.__call__":
return partial_call_callback
elif fullname == "functools._lru_cache_wrapper.__call__":
return lru_cache_wrapper_call_callback
return None

def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
Expand Down
158 changes: 158 additions & 0 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
ArgKind,
Argument,
CallExpr,
Decorator,
MemberExpr,
NameExpr,
Var,
)
Expand All @@ -25,6 +27,8 @@
AnyType,
CallableType,
Instance,
LiteralType,
NoneType,
Overloaded,
ParamSpecFlavor,
ParamSpecType,
Expand All @@ -41,6 +45,7 @@
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}

PARTIAL: Final = "functools.partial"
LRU_CACHE: Final = "functools.lru_cache"


class _MethodInfo(NamedTuple):
Expand Down Expand Up @@ -393,3 +398,156 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)

return result


def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""Infer a more precise return type for functools.lru_cache decorator"""
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
return ctx.default_return_type

# Only handle the very specific case: @lru_cache (without parentheses)
# where a single function is passed directly as the only argument
if (
len(ctx.arg_types) == 1
and len(ctx.arg_types[0]) == 1
and len(ctx.args) == 1
and len(ctx.args[0]) == 1
):

first_arg_type = ctx.arg_types[0][0]

proper_first_arg_type = get_proper_type(first_arg_type)
if isinstance(proper_first_arg_type, (LiteralType, Instance, NoneType)):
return ctx.default_return_type

# Try to extract callable type
fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type)
if fn_type is not None:
# This is the @lru_cache case (function passed directly)
return fn_type

# For all other cases (parameterized, multiple args, etc.), don't interfere
return ctx.default_return_type


def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
"""Handle calls to functools._lru_cache_wrapper objects to provide parameter validation"""
if not isinstance(ctx.api, mypy.checker.TypeChecker):
return ctx.default_return_type

# Safety check: ensure we have the required context
if not ctx.context or not ctx.args or not ctx.arg_types:
return ctx.default_return_type

# Try to find the original function signature using AST/symbol table analysis
original_signature = _find_original_function_signature(ctx)

if original_signature is not None:
# Validate the call against the original function signature
actual_args = []
actual_arg_kinds = []
actual_arg_names = []
seen_args = set()

for i, param in enumerate(ctx.args):
for j, a in enumerate(param):
if a in seen_args:
continue
seen_args.add(a)
actual_args.append(a)
actual_arg_kinds.append(ctx.arg_kinds[i][j])
actual_arg_names.append(ctx.arg_names[i][j])

# Check the call against the original signature
result, _ = ctx.api.expr_checker.check_call(
callee=original_signature,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=ctx.context,
)
return result

return ctx.default_return_type


def _get_callable_from_decorator(decorator_node: Decorator) -> CallableType | None:
"""Extract the CallableType from a Decorator node if available."""
func_def = decorator_node.func
if isinstance(func_def.type, CallableType):
return func_def.type
return None


def _bind_method(method_type: CallableType, decorator_node: Decorator) -> CallableType:
"""
Bind a method by removing the self parameter for instance methods.

Static and class methods are returned unchanged.
"""
func_def = decorator_node.func

# For instance methods, bind self by removing the first parameter
if not func_def.is_static and not func_def.is_class and method_type.arg_types:
return method_type.copy_modified(
arg_types=method_type.arg_types[1:],
arg_kinds=method_type.arg_kinds[1:],
arg_names=method_type.arg_names[1:],
)
return method_type


def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None:
"""
Find the original function signature from an lru_cache decorated function call.

Returns the CallableType of the original function if found, None otherwise.
"""
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
return None

if not isinstance(ctx.context, CallExpr):
return None

callee = ctx.context.callee

# Handle method calls (obj.method() or Class.method())
if isinstance(callee, MemberExpr):
method_name = callee.name
if not method_name:
return None

# Get the type of the object or class being accessed
member_type = ctx.api.expr_checker.accept(callee.expr)
proper_type = get_proper_type(member_type)

if not isinstance(proper_type, Instance):
return None

# Look up the method in the class
class_info = proper_type.type
if method_name not in class_info.names:
return None

symbol = class_info.names[method_name]
if not isinstance(symbol.node, Decorator):
return None

method_type = _get_callable_from_decorator(symbol.node)
if method_type is None:
return None

return _bind_method(method_type, symbol.node)

# Handle module-level function calls
if isinstance(callee, NameExpr) and callee.name:
if callee.name not in ctx.api.globals:
return None

symbol = ctx.api.globals[callee.name]
if not isinstance(symbol.node, Decorator):
return None

return _get_callable_from_decorator(symbol.node)

return None
Loading