Skip to content

Commit 007d6de

Browse files
authored
Do not treat match value patterns as isinstance checks (#20146)
Fixes #20358. As discovered by @randolf-scholz in #20142, `mypy` treats value patterns in match statement in a completely wrong way. From [PEP 622](https://peps.python.org/pep-0622/#literal-patterns): > Literal pattern uses equality with literal on the right hand side, so that in the above example number == 0 and then possibly number == 1, etc will be evaluated. Existing tests for the feature are invalid: for example, ```python foo: object match foo: case 1: reveal_type(foo) ``` must reveal `object`, and test `testMatchValuePatternNarrows` asserts the opposite. Here's a runtime example: ``` >>> class A: ... def __eq__(self,o): return True ... >>> match A(): ... case 1: ... print("eq") ... eq ``` I have updated the existing tests accordingly. The idea is that value patterns are essentially equivalent to `if foo == SomeValue` checks, not `isinstance` checks modelled by `conditional_types_wit_intersection`. The original implementation was introduced in #10191.
1 parent f0b0fe7 commit 007d6de

File tree

4 files changed

+91
-27
lines changed

4 files changed

+91
-27
lines changed

mypy/checker.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6654,6 +6654,28 @@ def equality_type_narrowing_helper(
66546654
narrowable_operand_index_to_hash: dict[int, tuple[Key, ...]],
66556655
) -> tuple[TypeMap, TypeMap]:
66566656
"""Calculate type maps for '==', '!=', 'is' or 'is not' expression."""
6657+
# If we haven't been able to narrow types yet, we might be dealing with a
6658+
# explicit type(x) == some_type check
6659+
if_map, else_map = self.narrow_type_by_equality(
6660+
operator,
6661+
operands,
6662+
operand_types,
6663+
expr_indices,
6664+
narrowable_operand_index_to_hash.keys(),
6665+
)
6666+
if if_map == {} and else_map == {} and node is not None:
6667+
if_map, else_map = self.find_type_equals_check(node, expr_indices)
6668+
return if_map, else_map
6669+
6670+
def narrow_type_by_equality(
6671+
self,
6672+
operator: str,
6673+
operands: list[Expression],
6674+
operand_types: list[Type],
6675+
expr_indices: list[int],
6676+
narrowable_indices: AbstractSet[int],
6677+
) -> tuple[TypeMap, TypeMap]:
6678+
"""Calculate type maps for '==', '!=', 'is' or 'is not' expression, ignoring `type(x)` checks."""
66576679
# is_valid_target:
66586680
# Controls which types we're allowed to narrow exprs to. Note that
66596681
# we cannot use 'is_literal_type_like' in both cases since doing
@@ -6699,20 +6721,15 @@ def has_no_custom_eq_checks(t: Type) -> bool:
66996721
operands,
67006722
operand_types,
67016723
expr_indices,
6702-
narrowable_operand_index_to_hash.keys(),
6724+
narrowable_indices,
67036725
is_valid_target,
67046726
coerce_only_in_literal_context,
67056727
)
67066728

67076729
if if_map == {} and else_map == {}:
67086730
if_map, else_map = self.refine_away_none_in_comparison(
6709-
operands, operand_types, expr_indices, narrowable_operand_index_to_hash.keys()
6731+
operands, operand_types, expr_indices, narrowable_indices
67106732
)
6711-
6712-
# If we haven't been able to narrow types yet, we might be dealing with a
6713-
# explicit type(x) == some_type check
6714-
if if_map == {} and else_map == {}:
6715-
if_map, else_map = self.find_type_equals_check(node, expr_indices)
67166733
return if_map, else_map
67176734

67186735
def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap:
@@ -6947,6 +6964,11 @@ def should_coerce_inner(typ: Type) -> bool:
69476964
for i in chain_indices:
69486965
expr_type = operand_types[i]
69496966
if should_coerce:
6967+
# TODO: doing this prevents narrowing a single-member Enum to literal
6968+
# of its member, because we expand it here and then refuse to add equal
6969+
# types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6970+
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6971+
# See testMatchEnumSingleChoice
69506972
expr_type = coerce_to_literal(expr_type)
69516973
if not is_valid_target(get_proper_type(expr_type)):
69526974
continue

mypy/checker_shared.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from abc import abstractmethod
6-
from collections.abc import Iterator, Sequence
6+
from collections.abc import Iterator, Sequence, Set as AbstractSet
77
from contextlib import contextmanager
88
from typing import NamedTuple, overload
99

@@ -245,6 +245,17 @@ def conditional_types_with_intersection(
245245
) -> tuple[Type | None, Type | None]:
246246
raise NotImplementedError
247247

248+
@abstractmethod
249+
def narrow_type_by_equality(
250+
self,
251+
operator: str,
252+
operands: list[Expression],
253+
operand_types: list[Type],
254+
expr_indices: list[int],
255+
narrowable_indices: AbstractSet[int],
256+
) -> tuple[dict[Expression, Type] | None, dict[Expression, Type] | None]:
257+
raise NotImplementedError
258+
248259
@abstractmethod
249260
def check_deprecated(self, node: Node | None, context: Context) -> None:
250261
raise NotImplementedError

mypy/checkpattern.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from mypy.maptype import map_instance_to_supertype
1515
from mypy.meet import narrow_declared_type
1616
from mypy.messages import MessageBuilder
17-
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, Var
17+
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TempNode, TypeAlias, Var
1818
from mypy.options import Options
1919
from mypy.patterns import (
2020
AsPattern,
@@ -39,7 +39,6 @@
3939
AnyType,
4040
FunctionLike,
4141
Instance,
42-
LiteralType,
4342
NoneType,
4443
ProperType,
4544
TupleType,
@@ -206,12 +205,15 @@ def visit_value_pattern(self, o: ValuePattern) -> PatternType:
206205
current_type = self.type_context[-1]
207206
typ = self.chk.expr_checker.accept(o.expr)
208207
typ = coerce_to_literal(typ)
209-
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
210-
current_type, [get_type_range(typ)], o, default=get_proper_type(typ)
208+
node = TempNode(current_type)
209+
# Value patterns are essentially a syntactic sugar on top of `if x == Value`.
210+
# They should be treated equivalently.
211+
ok_map, rest_map = self.chk.narrow_type_by_equality(
212+
"==", [node, TempNode(typ)], [current_type, typ], [0, 1], {0}
211213
)
212-
if not isinstance(get_proper_type(narrowed_type), (LiteralType, UninhabitedType)):
213-
return PatternType(narrowed_type, UnionType.make_union([narrowed_type, rest_type]), {})
214-
return PatternType(narrowed_type, rest_type, {})
214+
ok_type = ok_map.get(node, current_type) if ok_map is not None else UninhabitedType()
215+
rest_type = rest_map.get(node, current_type) if rest_map is not None else UninhabitedType()
216+
return PatternType(ok_type, rest_type, {})
215217

216218
def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
217219
current_type = self.type_context[-1]

test-data/unit/check-python310.test

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ m: Any
3030

3131
match m:
3232
case 1:
33-
reveal_type(m) # N: Revealed type is "Literal[1]"
33+
reveal_type(m) # N: Revealed type is "Any"
3434
case 2:
35-
reveal_type(m) # N: Revealed type is "Literal[2]"
35+
reveal_type(m) # N: Revealed type is "Any"
3636
case other:
3737
reveal_type(other) # N: Revealed type is "Any"
3838

@@ -61,7 +61,7 @@ m: object
6161

6262
match m:
6363
case b.b:
64-
reveal_type(m) # N: Revealed type is "builtins.int"
64+
reveal_type(m) # N: Revealed type is "builtins.object"
6565
[file b.py]
6666
b: int
6767

@@ -83,7 +83,7 @@ m: A
8383

8484
match m:
8585
case b.b:
86-
reveal_type(m) # N: Revealed type is "__main__.<subclass of "__main__.A" and "b.B">"
86+
reveal_type(m) # N: Revealed type is "__main__.A"
8787
[file b.py]
8888
class B: ...
8989
b: B
@@ -96,7 +96,7 @@ m: int
9696

9797
match m:
9898
case b.b:
99-
reveal_type(m)
99+
reveal_type(m) # N: Revealed type is "builtins.int"
100100
[file b.py]
101101
b: str
102102
[builtins fixtures/primitives.pyi]
@@ -1742,14 +1742,15 @@ from typing import NoReturn
17421742
def assert_never(x: NoReturn) -> None: ...
17431743

17441744
class Medal(Enum):
1745-
gold = 1
1745+
GOLD = 1
17461746

17471747
def f(m: Medal) -> None:
17481748
always_assigned: int | None = None
17491749
match m:
1750-
case Medal.gold:
1750+
case Medal.GOLD:
17511751
always_assigned = 1
1752-
reveal_type(m) # N: Revealed type is "Literal[__main__.Medal.gold]"
1752+
# This should narrow to literal, see TODO in checker::refine_identity_comparison_expression
1753+
reveal_type(m) # N: Revealed type is "__main__.Medal"
17531754
case _:
17541755
assert_never(m)
17551756

@@ -1785,6 +1786,34 @@ def g(m: Medal) -> int:
17851786
return 2
17861787
[builtins fixtures/enum.pyi]
17871788

1789+
[case testMatchLiteralOrValuePattern]
1790+
# flags: --warn-unreachable
1791+
from typing import Literal
1792+
1793+
def test1(x: Literal[1,2,3]) -> None:
1794+
match x:
1795+
case 1:
1796+
reveal_type(x) # N: Revealed type is "Literal[1]"
1797+
case other:
1798+
reveal_type(x) # N: Revealed type is "Union[Literal[2], Literal[3]]"
1799+
1800+
def test2(x: Literal[1,2,3]) -> None:
1801+
match x:
1802+
case 1:
1803+
reveal_type(x) # N: Revealed type is "Literal[1]"
1804+
case 2:
1805+
reveal_type(x) # N: Revealed type is "Literal[2]"
1806+
case 3:
1807+
reveal_type(x) # N: Revealed type is "Literal[3]"
1808+
case other:
1809+
1 # E: Statement is unreachable
1810+
1811+
def test3(x: Literal[1,2,3]) -> None:
1812+
match x:
1813+
case 1 | 3:
1814+
reveal_type(x) # N: Revealed type is "Union[Literal[1], Literal[3]]"
1815+
case other:
1816+
reveal_type(x) # N: Revealed type is "Literal[2]"
17881817

17891818
[case testMatchLiteralPatternEnumWithTypedAttribute]
17901819
from enum import Enum
@@ -2813,7 +2842,7 @@ match A().foo:
28132842
def int_literal() -> None:
28142843
match 12:
28152844
case 1 as s:
2816-
reveal_type(s) # N: Revealed type is "Literal[1]"
2845+
reveal_type(s) # E: Statement is unreachable
28172846
case int(i):
28182847
reveal_type(i) # N: Revealed type is "Literal[12]?"
28192848
case other:
@@ -2822,7 +2851,7 @@ def int_literal() -> None:
28222851
def str_literal() -> None:
28232852
match 'foo':
28242853
case 'a' as s:
2825-
reveal_type(s) # N: Revealed type is "Literal['a']"
2854+
reveal_type(s) # E: Statement is unreachable
28262855
case str(i):
28272856
reveal_type(i) # N: Revealed type is "Literal['foo']?"
28282857
case other:
@@ -2909,9 +2938,9 @@ T_Choice = TypeVar("T_Choice", bound=b.One | b.Two)
29092938
def switch(choice: type[T_Choice]) -> None:
29102939
match choice:
29112940
case b.One:
2912-
reveal_type(choice) # N: Revealed type is "def () -> b.One"
2941+
reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]"
29132942
case b.Two:
2914-
reveal_type(choice) # N: Revealed type is "def () -> b.Two"
2943+
reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]"
29152944
case _:
29162945
reveal_type(choice) # N: Revealed type is "type[T_Choice`-1]"
29172946

0 commit comments

Comments
 (0)