From d13592531418375f6e465c37a82892e4702df40d Mon Sep 17 00:00:00 2001 From: root Date: Sun, 1 Jun 2025 13:00:19 -0300 Subject: [PATCH 1/2] refactor: compact and optimize infer_overload_return_type while preserving behavior and comments --- mypy/checkexpr.py | 82 ++++++++++++++++++++--------------------------- 1 file changed, 34 insertions(+), 48 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ace8f09bee48..6750bf728ae7 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2876,69 +2876,55 @@ def infer_overload_return_type( Assumes all of the given targets have argument counts compatible with the caller. """ - matches: list[CallableType] = [] return_types: list[Type] = [] inferred_types: list[Type] = [] - args_contain_any = any(map(has_any_type, arg_types)) type_maps: list[dict[Expression, Type]] = [] + args_contain_any = any(map(has_any_type, arg_types)) for typ in plausible_targets: assert self.msg is self.chk.msg - with self.msg.filter_errors() as w: - with self.chk.local_type_map() as m: - ret_type, infer_type = self.check_call( - callee=typ, - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type, - ) - is_match = not w.has_new_errors() - if is_match: - # Return early if possible; otherwise record info, so we can - # check for ambiguity due to 'Any' below. - if not args_contain_any: - self.chk.store_types(m) - return ret_type, infer_type - p_infer_type = get_proper_type(infer_type) - if isinstance(p_infer_type, CallableType): - # Prefer inferred types if possible, this will avoid false triggers for - # Any-ambiguity caused by arguments with Any passed to generic overloads. - matches.append(p_infer_type) - else: - matches.append(typ) - return_types.append(ret_type) - inferred_types.append(infer_type) - type_maps.append(m) + with self.msg.filter_errors() as w, self.chk.local_type_map() as m: + ret_type, infer_type = self.check_call( + callee=typ, args=args, arg_kinds=arg_kinds, arg_names=arg_names, + context=context, callable_name=callable_name, object_type=object_type) + if w.has_new_errors(): continue + + # Return early if possible; otherwise record info, so we can + # check for ambiguity due to 'Any' below. + if not args_contain_any: + self.chk.store_types(m) + return ret_type, infer_type + + # Prefer inferred types if possible, this will avoid false triggers for + # Any-ambiguity caused by arguments with Any passed to generic overloads. + p = get_proper_type(infer_type) + matches.append(p if isinstance(p, CallableType) else typ) + return_types.append(ret_type) + inferred_types.append(infer_type) + type_maps.append(m) if not matches: return None - elif any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names): - # An argument of type or containing the type 'Any' caused ambiguity. - # We try returning a precise type if we can. If not, we give up and just return 'Any'. + + # An argument of type or containing the type 'Any' caused ambiguity. + # We try returning a precise type if we can. If not, we give up and just return 'Any'. + if any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names): if all_same_types(return_types): self.chk.store_types(type_maps[0]) return return_types[0], inferred_types[0] - elif all_same_types([erase_type(typ) for typ in return_types]): + erased = [erase_type(t) for t in return_types] + if all_same_types(cast(list[Type], erased)): self.chk.store_types(type_maps[0]) return erase_type(return_types[0]), erase_type(inferred_types[0]) - else: - return self.check_call( - callee=AnyType(TypeOfAny.special_form), - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type, - ) - else: - # Success! No ambiguity; return the first match. - self.chk.store_types(type_maps[0]) - return return_types[0], inferred_types[0] + return self.check_call( + callee=AnyType(TypeOfAny.special_form), args=args, arg_kinds=arg_kinds, + arg_names=arg_names, context=context, callable_name=callable_name, + object_type=object_type) + + # Success! No ambiguity; return the first match. + self.chk.store_types(type_maps[0]) + return return_types[0], inferred_types[0] def overload_erased_call_targets( self, From d1757535d02439c4f669d72d2cb3bfa67d2a45db Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Jun 2025 16:09:56 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checkexpr.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 6750bf728ae7..71a905e80a8f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2886,9 +2886,16 @@ def infer_overload_return_type( assert self.msg is self.chk.msg with self.msg.filter_errors() as w, self.chk.local_type_map() as m: ret_type, infer_type = self.check_call( - callee=typ, args=args, arg_kinds=arg_kinds, arg_names=arg_names, - context=context, callable_name=callable_name, object_type=object_type) - if w.has_new_errors(): continue + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) + if w.has_new_errors(): + continue # Return early if possible; otherwise record info, so we can # check for ambiguity due to 'Any' below. @@ -2918,9 +2925,14 @@ def infer_overload_return_type( self.chk.store_types(type_maps[0]) return erase_type(return_types[0]), erase_type(inferred_types[0]) return self.check_call( - callee=AnyType(TypeOfAny.special_form), args=args, arg_kinds=arg_kinds, - arg_names=arg_names, context=context, callable_name=callable_name, - object_type=object_type) + callee=AnyType(TypeOfAny.special_form), + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) # Success! No ambiguity; return the first match. self.chk.store_types(type_maps[0])