Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 105 additions & 137 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@
fixup_partial_type,
function_type,
is_literal_type_like,
is_singleton_type,
is_singleton_equality_type,
is_singleton_identity_type,
make_simplified_union,
true_only,
try_expanding_sum_type_to_union,
Expand Down Expand Up @@ -6676,30 +6677,57 @@ def narrow_type_by_equality(
expr_indices: list[int],
narrowable_indices: AbstractSet[int],
) -> tuple[TypeMap, TypeMap]:
"""Calculate type maps for '==', '!=', 'is' or 'is not' expression, ignoring `type(x)` checks."""
# is_valid_target:
# Controls which types we're allowed to narrow exprs to. Note that
# we cannot use 'is_literal_type_like' in both cases since doing
# 'x = 10000 + 1; x is 10001' is not always True in all Python
# implementations.
#
# coerce_only_in_literal_context:
# If true, coerce types into literal types only if one or more of
# the provided exprs contains an explicit Literal type. This could
# technically be set to any arbitrary value, but it seems being liberal
# with narrowing when using 'is' and conservative when using '==' seems
# to break the least amount of real-world code.
#
"""
Calculate type maps for '==', '!=', 'is' or 'is not' expression, ignoring `type(x)` checks.

The 'operands' and 'operand_types' lists should be the full list of operands used
in the overall comparison expression. The 'chain_indices' list is the list of indices
actually used within this identity comparison chain.

So if we have the expression:

a <= b is c is d <= e

...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices'
would be the list [1, 2, 3].

The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
to refine the types of: that is, all operands that will potentially be a part of
the output TypeMaps.

"""
# should_narrow_by_identity_equality:
# Set to 'false' only if the user defines custom __eq__ or __ne__ methods
# that could cause identity-based narrowing to produce invalid results.
# If operator is "==" or "!=", we cannot narrow if we detect the presence of a user defined
# custom __eq__ or __ne__ method
should_narrow_by_identity_equality: bool

# is_target_for_value_narrowing:
# If the operator returns True when compared to this target, do we narrow in else branch?
# E.g. if operator is "==", then:
# - is_target_for_value_narrowing(str) == False
# - is_target_for_value_narrowing(Literal["asdf"]) == True
is_target_for_value_narrowing: Callable[[ProperType], bool]

# should_coerce_literals:
# Ideally, we should always attempt to have this set to True. Unfortunately, for now,
# performing this coercion can sometimes result in overly aggressive narrowing when taking
# in the context of other type checker behaviour.
should_coerce_literals: bool

if operator in {"is", "is not"}:
is_valid_target: Callable[[Type], bool] = is_singleton_type
coerce_only_in_literal_context = False
is_target_for_value_narrowing = is_singleton_identity_type
should_coerce_literals = True
should_narrow_by_identity_equality = True

elif operator in {"==", "!="}:
is_valid_target = is_singleton_value
coerce_only_in_literal_context = True
is_target_for_value_narrowing = is_singleton_equality_type

should_coerce_literals = False
for i in expr_indices:
typ = get_proper_type(operand_types[i])
if is_literal_type_like(typ) or (isinstance(typ, Instance) and typ.type.is_enum):
should_coerce_literals = True
break

expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity_equality = not any(
Expand All @@ -6708,21 +6736,63 @@ def narrow_type_by_equality(
else:
raise AssertionError

if should_narrow_by_identity_equality:
return self.narrow_identity_equality_comparison(
operands,
operand_types,
expr_indices,
narrowable_indices,
is_valid_target,
coerce_only_in_literal_context,
if not should_narrow_by_identity_equality:
# This is a bit of a legacy code path that might be a little unsound since it ignores
# custom __eq__. We should see if we can get rid of it in favour of `return {}, {}`
return self.refine_away_none_in_comparison(
operands, operand_types, expr_indices, narrowable_indices
)

# This is a bit of a legacy code path that might be a little unsound since it ignores
# custom __eq__. We should see if we can get rid of it.
return self.refine_away_none_in_comparison(
operands, operand_types, expr_indices, narrowable_indices
)
value_targets = []
type_targets = []
for i in expr_indices:
expr_type = operand_types[i]
if should_coerce_literals:
# TODO: doing this prevents narrowing a single-member Enum to literal
# of its member, because we expand it here and then refuse to add equal
# types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
# See testMatchEnumSingleChoice
expr_type = coerce_to_literal(expr_type)
if is_target_for_value_narrowing(get_proper_type(expr_type)):
value_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))
else:
type_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))

partial_type_maps = []

if value_targets:
for i in expr_indices:
if i not in narrowable_indices:
continue
for j, target in value_targets:
if i == j:
continue
expr_type = coerce_to_literal(operand_types[i])
expr_type = try_expanding_sum_type_to_union(expr_type, None)
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
partial_type_maps.append((if_map, else_map))

if type_targets:
for i in expr_indices:
if i not in narrowable_indices:
continue
for j, target in type_targets:
if i == j:
continue
expr_type = operand_types[i]
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
if if_map:
else_map = {} # this is the big difference compared to the above
partial_type_maps.append((if_map, else_map))

# We will not have duplicate entries in our type maps if we only have two operands,
# so we can skip running meets on the intersections
return reduce_conditional_maps(partial_type_maps, use_meet=len(operands) > 2)

def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap:
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
Expand Down Expand Up @@ -6905,103 +6975,6 @@ def _propagate_walrus_assignments(
return parent_expr
return expr

def narrow_identity_equality_comparison(
self,
operands: list[Expression],
operand_types: list[Type],
chain_indices: list[int],
narrowable_operand_indices: AbstractSet[int],
is_valid_target: Callable[[ProperType], bool],
coerce_only_in_literal_context: bool,
) -> tuple[TypeMap, TypeMap]:
"""Produce conditional type maps refining expressions by an identity/equality comparison.

The 'operands' and 'operand_types' lists should be the full list of operands used
in the overall comparison expression. The 'chain_indices' list is the list of indices
actually used within this identity comparison chain.

So if we have the expression:

a <= b is c is d <= e

...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices'
would be the list [1, 2, 3].

The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
to refine the types of: that is, all operands that will potentially be a part of
the output TypeMaps.

Although this function could theoretically try setting the types of the operands
in the chains to the meet, doing that causes too many issues in real-world code.
Instead, we use 'is_valid_target' to identify which of the given chain types
we could plausibly use as the refined type for the expressions in the chain.

Similarly, 'coerce_only_in_literal_context' controls whether we should try coercing
expressions in the chain to a Literal type. Performing this coercion is sometimes
too aggressive of a narrowing, depending on context.
"""

if coerce_only_in_literal_context:
should_coerce = False
for i in chain_indices:
typ = get_proper_type(operand_types[i])
if is_literal_type_like(typ) or (isinstance(typ, Instance) and typ.type.is_enum):
should_coerce = True
break
else:
should_coerce = True

value_targets = []
type_targets = []
for i in chain_indices:
expr_type = operand_types[i]
if should_coerce:
# TODO: doing this prevents narrowing a single-member Enum to literal
# of its member, because we expand it here and then refuse to add equal
# types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
# See testMatchEnumSingleChoice
expr_type = coerce_to_literal(expr_type)
if is_valid_target(get_proper_type(expr_type)):
value_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))
else:
type_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))

partial_type_maps = []

if value_targets:
for i in chain_indices:
if i not in narrowable_operand_indices:
continue
for j, target in value_targets:
if i == j:
continue
expr_type = coerce_to_literal(operand_types[i])
expr_type = try_expanding_sum_type_to_union(expr_type, None)
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
partial_type_maps.append((if_map, else_map))

if type_targets:
for i in chain_indices:
if i not in narrowable_operand_indices:
continue
for j, target in type_targets:
if i == j:
continue
expr_type = operand_types[i]
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
if if_map:
else_map = {}
partial_type_maps.append((if_map, else_map))

# We will not have duplicate entries in our type maps if we only have two operands,
# so we can skip running meets on the intersections
return reduce_conditional_maps(partial_type_maps, use_meet=len(operands) > 2)

def refine_away_none_in_comparison(
self,
operands: list[Expression],
Expand All @@ -7012,7 +6985,7 @@ def refine_away_none_in_comparison(
"""Produces conditional type maps refining away None in an identity/equality chain.

For more details about what the different arguments mean, see the
docstring of 'refine_identity_comparison_expression' up above.
docstring of 'narrow_type_by_equality' up above.
"""

non_optional_types = []
Expand Down Expand Up @@ -8596,11 +8569,6 @@ def reduce_conditional_maps(
return final_if_map, final_else_map


def is_singleton_value(t: Type) -> bool:
t = get_proper_type(t)
return isinstance(t, LiteralType) or t.is_singleton_type()


BUILTINS_CUSTOM_EQ_CHECKS: Final = {
"builtins.bytes",
"builtins.bytearray",
Expand Down
35 changes: 21 additions & 14 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from mypy.state import state
from mypy.types import (
ELLIPSIS_TYPE_NAMES,
AnyType,
CallableType,
ExtraAttrs,
Expand Down Expand Up @@ -985,24 +986,30 @@ def is_literal_type_like(t: Type | None) -> bool:
return False


def is_singleton_type(typ: Type) -> bool:
"""Returns 'true' if this type is a "singleton type" -- if there exists
exactly only one runtime value associated with this type.
def is_singleton_identity_type(typ: ProperType) -> bool:
"""
Returns True if every value of this type is identical to every other value of this type,
as judged by the `is` operator.

That is, given two values 'a' and 'b' that have the same type 't',
'is_singleton_type(t)' returns True if and only if the expression 'a is b' is
always true.
Note that this is not true of certain LiteralType, such as Literal[100001] or Literal["string"]
"""
if isinstance(typ, NoneType):
return True
if isinstance(typ, Instance):
return (typ.type.is_enum and len(typ.type.enum_members) == 1) or (
typ.type.fullname in ELLIPSIS_TYPE_NAMES
)
if isinstance(typ, LiteralType):
return typ.is_enum_literal() or isinstance(typ.value, bool)
return False

Currently, this returns True when given NoneTypes, enum LiteralTypes,
enum types with a single value and ... (Ellipses).

Note that other kinds of LiteralTypes cannot count as singleton types. For
example, suppose we do 'a = 100000 + 1' and 'b = 100001'. It is not guaranteed
that 'a is b' will always be true -- some implementations of Python will end up
constructing two distinct instances of 100001.
def is_singleton_equality_type(typ: ProperType) -> bool:
"""
typ = get_proper_type(typ)
return typ.is_singleton_type()
Returns True if every value of this type compares equal to every other value of this type,
as judged by the `==` operator.
"""
return isinstance(typ, LiteralType) or is_singleton_identity_type(typ)


def try_expanding_sum_type_to_union(typ: Type, target_fullname: str | None) -> Type:
Expand Down
18 changes: 0 additions & 18 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,6 @@ def write(self, data: WriteBuffer) -> None:
def read(cls, data: ReadBuffer) -> Type:
raise NotImplementedError(f"Cannot deserialize {cls.__name__} instance")

def is_singleton_type(self) -> bool:
return False


class TypeAliasType(Type):
"""A type alias to another type.
Expand Down Expand Up @@ -1479,9 +1476,6 @@ def read(cls, data: ReadBuffer) -> NoneType:
assert read_tag(data) == END_TAG
return NoneType()

def is_singleton_type(self) -> bool:
return True


# NoneType used to be called NoneTyp so to avoid needlessly breaking
# external plugins we keep that alias here.
Expand Down Expand Up @@ -1848,15 +1842,6 @@ def copy_with_extra_attr(self, name: str, typ: Type) -> Instance:
new.extra_attrs = existing_attrs
return new

def is_singleton_type(self) -> bool:
# TODO:
# Also make this return True if the type corresponds to NotImplemented?
return (
self.type.is_enum
and len(self.type.enum_members) == 1
or self.type.fullname in ELLIPSIS_TYPE_NAMES
)


class InstanceCache:
def __init__(self) -> None:
Expand Down Expand Up @@ -3332,9 +3317,6 @@ def read(cls, data: ReadBuffer) -> LiteralType:
assert read_tag(data) == END_TAG
return ret

def is_singleton_type(self) -> bool:
return self.is_enum_literal() or isinstance(self.value, bool)


class UnionType(ProperType):
"""The union type Union[T1, ..., Tn] (at least one type argument)."""
Expand Down