Skip to content

Support non-trivial meets and joins of callables #18647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2322,6 +2322,7 @@ def type_checker(self) -> TypeChecker:
manager.plugin,
self.per_line_checking_time_ns,
)
type_state.object_type = self._type_checker.named_type("builtins.object")
return self._type_checker

def type_map(self) -> dict[Expression, Type]:
Expand Down
214 changes: 130 additions & 84 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import overload
from typing import Callable, overload

import mypy.typeops
from mypy.expandtype import expand_type
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY, ArgKind
from mypy.state import state
from mypy.subtypes import (
SubtypeContext,
Expand Down Expand Up @@ -52,6 +52,7 @@
get_proper_types,
split_with_prefix_and_suffix,
)
from mypy.typestate import type_state


class InstanceJoiner:
Expand Down Expand Up @@ -306,17 +307,7 @@ def visit_unpack_type(self, t: UnpackType) -> UnpackType:

def visit_parameters(self, t: Parameters) -> ProperType:
if isinstance(self.s, Parameters):
if not is_similar_params(t, self.s):
# TODO: it would be prudent to return [*object, **object] instead of Any.
return self.default(self.s)
from mypy.meet import meet_types

return t.copy_modified(
arg_types=[
meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)
],
arg_names=combine_arg_names(self.s, t),
)
return join_parameters(self.s, t) or self.default(self.s)
else:
return self.default(self.s)

Expand Down Expand Up @@ -354,10 +345,12 @@ def visit_instance(self, t: Instance) -> ProperType:
return self.default(self.s)

def visit_callable_type(self, t: CallableType) -> ProperType:
if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
if isinstance(self.s, CallableType):
if is_equivalent(t, self.s):
return combine_similar_callables(t, self.s)
result = join_similar_callables(t, self.s)
if result is None:
return join_types(t.fallback, self.s)
# We set the from_type_type flag to suppress error when a collection of
# concrete class objects gets inferred as their common abstract superclass.
if not (
Expand Down Expand Up @@ -416,11 +409,12 @@ def visit_overloaded(self, t: Overloaded) -> ProperType:
# The interesting case where both types are function types.
for t_item in t.items:
for s_item in s.items:
if is_similar_callables(t_item, s_item):
if is_equivalent(t_item, s_item):
result.append(combine_similar_callables(t_item, s_item))
elif is_subtype(t_item, s_item):
result.append(s_item)
if is_equivalent(t_item, s_item):
result.append(combine_similar_callables(t_item, s_item))
elif is_subtype(t_item, s_item):
result.append(s_item)
elif (true_join := join_similar_callables(s_item, t_item)) is not None:
result.append(true_join)
if result:
# TODO: Simplify redundancies from the result.
if len(result) == 1:
Expand Down Expand Up @@ -638,6 +632,8 @@ def default(self, typ: Type) -> ProperType:
return self.default(typ.upper_bound)
elif isinstance(typ, ParamSpecType):
return self.default(typ.upper_bound)
elif type_state.object_type is not None:
return type_state.object_type
else:
return AnyType(TypeOfAny.special_form)

Expand Down Expand Up @@ -665,26 +661,6 @@ def normalize_callables(s: ProperType, t: ProperType) -> tuple[ProperType, Prope
return s, t


def is_similar_callables(t: CallableType, s: CallableType) -> bool:
"""Return True if t and s have identical numbers of
arguments, default arguments and varargs.
"""
return (
len(t.arg_types) == len(s.arg_types)
and t.min_args == s.min_args
and t.is_var_arg == s.is_var_arg
)


def is_similar_params(t: Parameters, s: Parameters) -> bool:
# This matches the logic in is_similar_callables() above.
return (
len(t.arg_types) == len(s.arg_types)
and t.min_args == s.min_args
and (t.var_arg() is not None) == (s.var_arg() is not None)
)


def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType:
tv_map = {}
tvs = []
Expand Down Expand Up @@ -712,21 +688,112 @@ def match_generic_callables(t: CallableType, s: CallableType) -> tuple[CallableT
return update_callable_ids(t, new_ids), update_callable_ids(s, new_ids)


def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
def join_parameters(s: Parameters, t: Parameters) -> Parameters | None:
from mypy.meet import meet_types

return combine_parameters_with(s, t, arg_transformer=meet_types, allow_uninhabited=False)


def combine_parameters_with(
s: Parameters,
t: Parameters,
arg_transformer: Callable[[Type, Type], Type],
allow_uninhabited: bool,
) -> Parameters | None:
if sum(k.is_required() for k in s.arg_kinds) < sum(k.is_required() for k in t.arg_kinds):
return join_parameters(t, s)

args_meet = []
arg_names: list[str | None] = []
arg_kinds: list[ArgKind] = []
vararg = None
kwarg = None
if (s_var := s.var_arg()) is not None and (t_var := t.var_arg()) is not None:
vararg = arg_transformer(s_var.typ, t_var.typ)
if isinstance(get_proper_type(vararg), UninhabitedType):
vararg = None
if (s_kw := s.kw_arg()) is not None and (t_kw := t.kw_arg()) is not None:
kwarg = arg_transformer(s_kw.typ, t_kw.typ)
if isinstance(get_proper_type(kwarg), UninhabitedType):
kwarg = None
spent_names: set[str] = set()
for s_kind, s_a in zip(s.arg_kinds, s.formal_arguments(include_star_args=True)):
if vararg is not None and s_kind == ArgKind.ARG_STAR:
args_meet.append(vararg)
arg_names.append(None)
arg_kinds.append(s_kind)
vararg = None
continue
if kwarg is not None and s_kind == ArgKind.ARG_STAR2:
args_meet.append(kwarg)
arg_names.append(None)
arg_kinds.append(s_kind)
kwarg = None
continue
if s_kind.is_star():
continue

raw_candidates = [t.argument_by_position(s_a.pos)]
if s_a.name is not None and s_a.name not in spent_names:
raw_candidates.append(t.argument_by_name(s_a.name))
candidates = [c for c in raw_candidates if c is not None]
if not candidates:
if s_a.required:
return None
continue

for t_a in candidates:
typ = arg_transformer(s_a.typ, t_a.typ)
if not isinstance(get_proper_type(typ), UninhabitedType):
break
else:
if not s_a.required:
continue
if not allow_uninhabited:
return None
if t_a.name is not None:
spent_names.add(t_a.name)
args_meet.append(typ)
arg_names.append(s_a.name if s_a.name == t_a.name else None)
kinds = [ArgKind.ARG_OPT, ArgKind.ARG_POS, ArgKind.ARG_NAMED_OPT, ArgKind.ARG_NAMED]
if s_a.pos != t_a.pos or s_a.pos is None or t_a.pos is None:
kinds = [k for k in kinds if not k.is_positional()]
if s_a.name != t_a.name or s_a.name is None or t_a.name is None:
kinds = [k for k in kinds if not k.is_named()]
if s_a.required or t_a.required:
kinds = [k for k in kinds if k.is_required()]
arg_kinds.append(kinds[0])
return t.copy_modified(arg_types=args_meet, arg_names=arg_names, arg_kinds=arg_kinds)


def join_similar_callables(t: CallableType, s: CallableType) -> CallableType | None:
if s.param_spec() != t.param_spec():
return None

t, s = match_generic_callables(t, s)
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(safe_meet(t.arg_types[i], s.arg_types[i]))

joined_params = join_parameters(
Parameters(t.arg_types, t.arg_kinds, t.arg_names),
Parameters(s.arg_types, s.arg_kinds, s.arg_names),
)
if joined_params is None:
return None

# TODO in combine_similar_callables also applies here (names and kinds; user metaclasses)
# The fallback type can be either 'function', 'type', or some user-provided metaclass.
# The result should always use 'function' as a fallback if either operands are using it.
fallback: ProperType
if t.fallback.type.fullname == "builtins.function":
fallback = t.fallback
else:
elif s.fallback.type.fullname == "builtins.function":
fallback = s.fallback
else:
fallback = join_types(s.fallback, t.fallback)
assert isinstance(fallback, Instance)
return t.copy_modified(
arg_types=arg_types,
arg_names=combine_arg_names(t, s),
arg_types=joined_params.arg_types,
arg_names=joined_params.arg_names,
arg_kinds=joined_params.arg_kinds,
ret_type=join_types(t.ret_type, s.ret_type),
fallback=fallback,
name=None,
Expand Down Expand Up @@ -767,57 +834,36 @@ def safe_meet(t: Type, s: Type) -> Type:

def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
t, s = match_generic_callables(t, s)
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(safe_join(t.arg_types[i], s.arg_types[i]))
# TODO kinds and argument names

joined_params = combine_parameters_with(
Parameters(t.arg_types, t.arg_kinds, t.arg_names),
Parameters(s.arg_types, s.arg_kinds, s.arg_names),
arg_transformer=safe_join,
allow_uninhabited=True,
)
assert joined_params is not None

# TODO what should happen if one fallback is 'type' and the other is a user-provided metaclass?
# The fallback type can be either 'function', 'type', or some user-provided metaclass.
# The result should always use 'function' as a fallback if either operands are using it.
fallback: ProperType
if t.fallback.type.fullname == "builtins.function":
fallback = t.fallback
else:
elif s.fallback.type.fullname == "builtins.function":
fallback = s.fallback
else:
fallback = join_types(s.fallback, t.fallback)
assert isinstance(fallback, Instance)
return t.copy_modified(
arg_types=arg_types,
arg_names=combine_arg_names(t, s),
arg_types=joined_params.arg_types,
arg_names=joined_params.arg_names,
arg_kinds=joined_params.arg_kinds,
ret_type=join_types(t.ret_type, s.ret_type),
fallback=fallback,
name=None,
)


def combine_arg_names(
t: CallableType | Parameters, s: CallableType | Parameters
) -> list[str | None]:
"""Produces a list of argument names compatible with both callables.

For example, suppose 't' and 's' have the following signatures:

- t: (a: int, b: str, X: str) -> None
- s: (a: int, b: str, Y: str) -> None

This function would return ["a", "b", None]. This information
is then used above to compute the join of t and s, which results
in a signature of (a: int, b: str, str) -> None.

Note that the third argument's name is omitted and 't' and 's'
are both valid subtypes of this inferred signature.

Precondition: is_similar_types(t, s) is true.
"""
num_args = len(t.arg_types)
new_names = []
for i in range(num_args):
t_name = t.arg_names[i]
s_name = s.arg_names[i]
if t_name == s_name or t.arg_kinds[i].is_named() or s.arg_kinds[i].is_named():
new_names.append(t_name)
else:
new_names.append(None)
return new_names


def object_from_instance(instance: Instance) -> Instance:
"""Construct the type 'builtins.object' from an instance type."""
# Use the fact that 'object' is always the last class in the mro.
Expand Down
43 changes: 33 additions & 10 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,14 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
return self.default(self.s)

def visit_unpack_type(self, t: UnpackType) -> ProperType:
raise NotImplementedError
if isinstance(self.s, UnpackType):
res = UnpackType(
meet_types(self.s.type, t.type), from_star_syntax=self.s.from_star_syntax
)
res.set_line(self.s)
return res
else:
return self.default(self.s)

def visit_parameters(self, t: Parameters) -> ProperType:
if isinstance(self.s, Parameters):
Expand Down Expand Up @@ -889,10 +896,12 @@ def visit_instance(self, t: Instance) -> ProperType:
return self.default(self.s)

def visit_callable_type(self, t: CallableType) -> ProperType:
if isinstance(self.s, CallableType) and join.is_similar_callables(t, self.s):
if isinstance(self.s, CallableType):
if is_equivalent(t, self.s):
return join.combine_similar_callables(t, self.s)
result = meet_similar_callables(t, self.s)
if result is None:
return self.default(self.s)
# We set the from_type_type flag to suppress error when a collection of
# concrete class objects gets inferred as their common abstract superclass.
if not (
Expand Down Expand Up @@ -1099,22 +1108,36 @@ def default(self, typ: Type) -> ProperType:
return NoneType()


def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType:
from mypy.join import match_generic_callables, safe_join
def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType | None:
from mypy.join import combine_parameters_with, match_generic_callables, safe_join

if s.param_spec() != t.param_spec():
return None

t, s = match_generic_callables(t, s)
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(safe_join(t.arg_types[i], s.arg_types[i]))
joined_params = combine_parameters_with(
Parameters(t.arg_types, t.arg_kinds, t.arg_names),
Parameters(s.arg_types, s.arg_kinds, s.arg_names),
arg_transformer=safe_join,
allow_uninhabited=True,
)
if joined_params is None:
return None
# TODO in combine_similar_callables also applies here (names and kinds)
# The fallback type can be either 'function' or 'type'. The result should have 'function' as
# fallback only if both operands have it as 'function'.
if t.fallback.type.fullname != "builtins.function":
fallback: ProperType
if t.fallback.type.fullname == "builtins.function":
fallback = s.fallback
elif s.fallback.type.fullname == "builtins.function":
fallback = t.fallback
else:
fallback = s.fallback
fallback = meet_types(s.fallback, t.fallback)
assert isinstance(fallback, Instance)
return t.copy_modified(
arg_types=arg_types,
arg_types=joined_params.arg_types,
arg_names=joined_params.arg_names,
arg_kinds=joined_params.arg_kinds,
ret_type=meet_types(t.ret_type, s.ret_type),
fallback=fallback,
name=None,
Expand Down
Loading
Loading