|
16 | 16 | ArgKind, |
17 | 17 | Argument, |
18 | 18 | CallExpr, |
| 19 | + Decorator, |
| 20 | + MemberExpr, |
19 | 21 | NameExpr, |
20 | 22 | Var, |
21 | 23 | ) |
|
25 | 27 | AnyType, |
26 | 28 | CallableType, |
27 | 29 | Instance, |
| 30 | + LiteralType, |
| 31 | + NoneType, |
28 | 32 | Overloaded, |
29 | 33 | ParamSpecFlavor, |
30 | 34 | ParamSpecType, |
|
41 | 45 | _ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"} |
42 | 46 |
|
43 | 47 | PARTIAL: Final = "functools.partial" |
| 48 | +LRU_CACHE: Final = "functools.lru_cache" |
44 | 49 |
|
45 | 50 |
|
46 | 51 | class _MethodInfo(NamedTuple): |
@@ -393,3 +398,156 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: |
393 | 398 | ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names) |
394 | 399 |
|
395 | 400 | return result |
| 401 | + |
| 402 | + |
| 403 | +def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type: |
| 404 | + """Infer a more precise return type for functools.lru_cache decorator""" |
| 405 | + if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals |
| 406 | + return ctx.default_return_type |
| 407 | + |
| 408 | + # Only handle the very specific case: @lru_cache (without parentheses) |
| 409 | + # where a single function is passed directly as the only argument |
| 410 | + if ( |
| 411 | + len(ctx.arg_types) == 1 |
| 412 | + and len(ctx.arg_types[0]) == 1 |
| 413 | + and len(ctx.args) == 1 |
| 414 | + and len(ctx.args[0]) == 1 |
| 415 | + ): |
| 416 | + |
| 417 | + first_arg_type = ctx.arg_types[0][0] |
| 418 | + |
| 419 | + proper_first_arg_type = get_proper_type(first_arg_type) |
| 420 | + if isinstance(proper_first_arg_type, (LiteralType, Instance, NoneType)): |
| 421 | + return ctx.default_return_type |
| 422 | + |
| 423 | + # Try to extract callable type |
| 424 | + fn_type = ctx.api.extract_callable_type(first_arg_type, ctx=ctx.default_return_type) |
| 425 | + if fn_type is not None: |
| 426 | + # This is the @lru_cache case (function passed directly) |
| 427 | + return fn_type |
| 428 | + |
| 429 | + # For all other cases (parameterized, multiple args, etc.), don't interfere |
| 430 | + return ctx.default_return_type |
| 431 | + |
| 432 | + |
| 433 | +def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type: |
| 434 | + """Handle calls to functools._lru_cache_wrapper objects to provide parameter validation""" |
| 435 | + if not isinstance(ctx.api, mypy.checker.TypeChecker): |
| 436 | + return ctx.default_return_type |
| 437 | + |
| 438 | + # Safety check: ensure we have the required context |
| 439 | + if not ctx.context or not ctx.args or not ctx.arg_types: |
| 440 | + return ctx.default_return_type |
| 441 | + |
| 442 | + # Try to find the original function signature using AST/symbol table analysis |
| 443 | + original_signature = _find_original_function_signature(ctx) |
| 444 | + |
| 445 | + if original_signature is not None: |
| 446 | + # Validate the call against the original function signature |
| 447 | + actual_args = [] |
| 448 | + actual_arg_kinds = [] |
| 449 | + actual_arg_names = [] |
| 450 | + seen_args = set() |
| 451 | + |
| 452 | + for i, param in enumerate(ctx.args): |
| 453 | + for j, a in enumerate(param): |
| 454 | + if a in seen_args: |
| 455 | + continue |
| 456 | + seen_args.add(a) |
| 457 | + actual_args.append(a) |
| 458 | + actual_arg_kinds.append(ctx.arg_kinds[i][j]) |
| 459 | + actual_arg_names.append(ctx.arg_names[i][j]) |
| 460 | + |
| 461 | + # Check the call against the original signature |
| 462 | + result, _ = ctx.api.expr_checker.check_call( |
| 463 | + callee=original_signature, |
| 464 | + args=actual_args, |
| 465 | + arg_kinds=actual_arg_kinds, |
| 466 | + arg_names=actual_arg_names, |
| 467 | + context=ctx.context, |
| 468 | + ) |
| 469 | + return result |
| 470 | + |
| 471 | + return ctx.default_return_type |
| 472 | + |
| 473 | + |
| 474 | +def _get_callable_from_decorator(decorator_node: Decorator) -> CallableType | None: |
| 475 | + """Extract the CallableType from a Decorator node if available.""" |
| 476 | + func_def = decorator_node.func |
| 477 | + if isinstance(func_def.type, CallableType): |
| 478 | + return func_def.type |
| 479 | + return None |
| 480 | + |
| 481 | + |
| 482 | +def _bind_method(method_type: CallableType, decorator_node: Decorator) -> CallableType: |
| 483 | + """ |
| 484 | + Bind a method by removing the self parameter for instance methods. |
| 485 | +
|
| 486 | + Static and class methods are returned unchanged. |
| 487 | + """ |
| 488 | + func_def = decorator_node.func |
| 489 | + |
| 490 | + # For instance methods, bind self by removing the first parameter |
| 491 | + if not func_def.is_static and not func_def.is_class and method_type.arg_types: |
| 492 | + return method_type.copy_modified( |
| 493 | + arg_types=method_type.arg_types[1:], |
| 494 | + arg_kinds=method_type.arg_kinds[1:], |
| 495 | + arg_names=method_type.arg_names[1:], |
| 496 | + ) |
| 497 | + return method_type |
| 498 | + |
| 499 | + |
| 500 | +def _find_original_function_signature(ctx: mypy.plugin.MethodContext) -> CallableType | None: |
| 501 | + """ |
| 502 | + Find the original function signature from an lru_cache decorated function call. |
| 503 | +
|
| 504 | + Returns the CallableType of the original function if found, None otherwise. |
| 505 | + """ |
| 506 | + if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals |
| 507 | + return None |
| 508 | + |
| 509 | + if not isinstance(ctx.context, CallExpr): |
| 510 | + return None |
| 511 | + |
| 512 | + callee = ctx.context.callee |
| 513 | + |
| 514 | + # Handle method calls (obj.method() or Class.method()) |
| 515 | + if isinstance(callee, MemberExpr): |
| 516 | + method_name = callee.name |
| 517 | + if not method_name: |
| 518 | + return None |
| 519 | + |
| 520 | + # Get the type of the object or class being accessed |
| 521 | + member_type = ctx.api.expr_checker.accept(callee.expr) |
| 522 | + proper_type = get_proper_type(member_type) |
| 523 | + |
| 524 | + if not isinstance(proper_type, Instance): |
| 525 | + return None |
| 526 | + |
| 527 | + # Look up the method in the class |
| 528 | + class_info = proper_type.type |
| 529 | + if method_name not in class_info.names: |
| 530 | + return None |
| 531 | + |
| 532 | + symbol = class_info.names[method_name] |
| 533 | + if not isinstance(symbol.node, Decorator): |
| 534 | + return None |
| 535 | + |
| 536 | + method_type = _get_callable_from_decorator(symbol.node) |
| 537 | + if method_type is None: |
| 538 | + return None |
| 539 | + |
| 540 | + return _bind_method(method_type, symbol.node) |
| 541 | + |
| 542 | + # Handle module-level function calls |
| 543 | + if isinstance(callee, NameExpr) and callee.name: |
| 544 | + if callee.name not in ctx.api.globals: |
| 545 | + return None |
| 546 | + |
| 547 | + symbol = ctx.api.globals[callee.name] |
| 548 | + if not isinstance(symbol.node, Decorator): |
| 549 | + return None |
| 550 | + |
| 551 | + return _get_callable_from_decorator(symbol.node) |
| 552 | + |
| 553 | + return None |
0 commit comments