From 615ea268420e2f608c7bcdfc53fe899d1358b2cb Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 21 Oct 2024 01:42:06 +0200 Subject: [PATCH 1/8] Add support for literal addition --- mypy/checkexpr.py | 54 +++++++++++++ test-data/unit/check-literal.test | 100 +++++++++++++++++++++++-- test-data/unit/cmdline.test | 4 +- test-data/unit/fixtures/primitives.pyi | 1 + test-data/unit/typexport-basic.test | 2 +- 5 files changed, 150 insertions(+), 11 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 447afc16c464..5f645879b483 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3510,6 +3510,9 @@ def visit_op_expr(self, e: OpExpr) -> Type: items=proper_left_type.items + [UnpackType(mapped)] ) + if e.op == "+" and (result := self.literal_expression_addition(e, left_type)): + return result + use_reverse: UseReverse = USE_REVERSE_DEFAULT if e.op == "|": if is_named_instance(proper_left_type, "builtins.dict"): @@ -3570,6 +3573,57 @@ def visit_op_expr(self, e: OpExpr) -> Type: else: raise RuntimeError(f"Unknown operator {e.op}") + def literal_value_from_expr( + self, expr: Expression, typ: Type | None = None + ) -> tuple[list[str | int], str] | None: + if isinstance(expr, StrExpr): + return [expr.value], "builtins.str" + if isinstance(expr, IntExpr): + return [expr.value], "builtins.int" + if isinstance(expr, BytesExpr): + return [expr.value], "builtins.bytes" + + typ = typ or self.accept(expr) + ptype = get_proper_type(typ) + + if isinstance(ptype, LiteralType) and not isinstance(ptype.value, (bool, float)): + return [ptype.value], ptype.fallback.type.fullname + + if isinstance(ptype, UnionType): + fallback: str | None = None + values: list[str | int] = [] + for item in ptype.items: + pitem = get_proper_type(item) + if not isinstance(pitem, LiteralType) or isinstance(pitem.value, (float, bool)): + break + if fallback is None: + fallback = pitem.fallback.type.fullname + if fallback != pitem.fallback.type.fullname: + break + values.append(pitem.value) + else: + assert fallback is not None + return values, fallback + return None + + def literal_expression_addition(self, e: OpExpr, left_type: Type) -> Type | None: + """Check if literal values can be combined with addition.""" + assert e.op == "+" + if not (lvalue := self.literal_value_from_expr(e.left, left_type)): + return None + if not (rvalue := self.literal_value_from_expr(e.right)) or lvalue[1] != rvalue[1]: + return None + + values: list[int | str] = sorted( + { + val[0] + val[1] # type: ignore[operator] + for val in itertools.product(lvalue[0], rvalue[0]) + } + ) + if len(values) == 1: + return LiteralType(values[0], self.named_type(lvalue[1])) + return UnionType([LiteralType(val, self.named_type(lvalue[1])) for val in values]) + def visit_comparison_expr(self, e: ComparisonExpr) -> Type: """Type check a comparison expression. diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index d4774420ad89..2c4504a1c070 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -1407,19 +1407,19 @@ c: Literal[4] d: Literal['foo'] e: str -reveal_type(a + a) # N: Revealed type is "builtins.int" +reveal_type(a + a) # N: Revealed type is "Literal[6]" reveal_type(a + b) # N: Revealed type is "builtins.int" reveal_type(b + a) # N: Revealed type is "builtins.int" -reveal_type(a + 1) # N: Revealed type is "builtins.int" -reveal_type(1 + a) # N: Revealed type is "builtins.int" -reveal_type(a + c) # N: Revealed type is "builtins.int" -reveal_type(c + a) # N: Revealed type is "builtins.int" +reveal_type(a + 1) # N: Revealed type is "Literal[4]" +reveal_type(1 + a) # N: Revealed type is "Literal[4]" +reveal_type(a + c) # N: Revealed type is "Literal[7]" +reveal_type(c + a) # N: Revealed type is "Literal[7]" -reveal_type(d + d) # N: Revealed type is "builtins.str" +reveal_type(d + d) # N: Revealed type is "Literal['foofoo']" reveal_type(d + e) # N: Revealed type is "builtins.str" reveal_type(e + d) # N: Revealed type is "builtins.str" -reveal_type(d + 'foo') # N: Revealed type is "builtins.str" -reveal_type('foo' + d) # N: Revealed type is "builtins.str" +reveal_type(d + 'bar') # N: Revealed type is "Literal['foobar']" +reveal_type('bar' + d) # N: Revealed type is "Literal['barfoo']" reveal_type(a.__add__(b)) # N: Revealed type is "builtins.int" reveal_type(b.__add__(a)) # N: Revealed type is "builtins.int" @@ -2998,3 +2998,87 @@ def check(obj: A[Literal[1]]) -> None: reveal_type(g('', obj)) # E: Cannot infer value of type parameter "T" of "g" \ # N: Revealed type is "Any" [builtins fixtures/tuple.pyi] + +[case testLiteralAddition] +from typing import Union +from typing_extensions import Literal + +str_a: Literal["a"] +str_b: Literal["b"] +str_union_1: Literal["a", "b"] +str_union_2: Literal["c", "d"] +s: str +int_1: Literal[1] +int_2: Literal[2] +int_union_1: Literal[1, 2] +int_union_2: Literal[3, 4] +i: int +bytes_a: Literal[b"a"] +bytes_b: Literal[b"b"] +bytes_union_1: Literal[b"a", b"b"] +bytes_union_2: Literal[b"c", b"d"] +b: bytes + +misc_union: Literal["a", 1] + +reveal_type(str_a + str_b) # N: Revealed type is "Literal['ab']" +reveal_type(str_a + "b") # N: Revealed type is "Literal['ab']" +reveal_type("a" + str_b) # N: Revealed type is "Literal['ab']" +reveal_type(str_union_1 + "b") # N: Revealed type is "Literal['ab'] | Literal['bb']" +reveal_type(str_union_1 + str_b) # N: Revealed type is "Literal['ab'] | Literal['bb']" +reveal_type("a" + str_union_1) # N: Revealed type is "Literal['aa'] | Literal['ab']" +reveal_type(str_a + str_union_1) # N: Revealed type is "Literal['aa'] | Literal['ab']" +reveal_type(str_union_1 + str_union_2) # N: Revealed type is "Literal['ac'] | Literal['ad'] | Literal['bc'] | Literal['bd']" +reveal_type(str_a + s) # N: Revealed type is "builtins.str" +reveal_type(s + str_a) # N: Revealed type is "builtins.str" +reveal_type(str_union_1 + s) # N: Revealed type is "builtins.str" +reveal_type(s + str_union_1) # N: Revealed type is "builtins.str" + +reveal_type(int_1 + int_2) # N: Revealed type is "Literal[3]" +reveal_type(int_1 + 1) # N: Revealed type is "Literal[2]" +reveal_type(1 + int_1) # N: Revealed type is "Literal[2]" +reveal_type(int_union_1 + 1) # N: Revealed type is "Literal[2] | Literal[3]" +reveal_type(int_union_1 + int_1) # N: Revealed type is "Literal[2] | Literal[3]" +reveal_type(1 + int_union_1) # N: Revealed type is "Literal[2] | Literal[3]" +reveal_type(int_1 + int_union_1) # N: Revealed type is "Literal[2] | Literal[3]" +reveal_type(int_union_1 + int_union_2) # N: Revealed type is "Literal[4] | Literal[5] | Literal[6]" +reveal_type(int_1 + i) # N: Revealed type is "builtins.int" +reveal_type(i + int_1) # N: Revealed type is "builtins.int" +reveal_type(int_union_1 + i) # N: Revealed type is "builtins.int" +reveal_type(i + int_union_1) # N: Revealed type is "builtins.int" + +reveal_type(bytes_a + bytes_b) # N: Revealed type is "Literal[b'ab']" +reveal_type(bytes_a + b"b") # N: Revealed type is "Literal[b'ab']" +reveal_type(b"a" + bytes_b) # N: Revealed type is "Literal[b'ab']" +reveal_type(bytes_union_1 + b"b") # N: Revealed type is "Literal[b'ab'] | Literal[b'bb']" +reveal_type(bytes_union_1 + bytes_b) # N: Revealed type is "Literal[b'ab'] | Literal[b'bb']" +reveal_type(b"a" + bytes_union_1) # N: Revealed type is "Literal[b'aa'] | Literal[b'ab']" +reveal_type(bytes_a + bytes_union_1) # N: Revealed type is "Literal[b'aa'] | Literal[b'ab']" +reveal_type(bytes_union_1 + bytes_union_2) # N: Revealed type is "Literal[b'ac'] | Literal[b'ad'] | Literal[b'bc'] | Literal[b'bd']" +reveal_type(bytes_a + b) # N: Revealed type is "builtins.bytes" +reveal_type(b + bytes_a) # N: Revealed type is "builtins.bytes" +reveal_type(bytes_union_1 + b) # N: Revealed type is "builtins.bytes" +reveal_type(b + bytes_union_1) # N: Revealed type is "builtins.bytes" + +reveal_type(misc_union + "a") # N: Revealed type is "builtins.str | builtins.int" \ + # E: Unsupported operand types for + ("Literal[1]" and "str") \ + # N: Left operand is of type "Literal['a', 1]" +reveal_type("a" + misc_union) # E: Unsupported operand types for + ("str" and "Literal[1]") \ + # N: Right operand is of type "Literal['a', 1]" \ + # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] + +[case testLiteralAdditionTypedDict] +from typing import TypedDict +from typing_extensions import Literal + +class LookupDict(TypedDict): + top_var: str + bottom_var: str + var: str + +def func(d: LookupDict, pos: Literal["top_", "bottom_", ""]) -> str: + return d[pos + "var"] + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] diff --git a/test-data/unit/cmdline.test b/test-data/unit/cmdline.test index 85d9fc920d19..66abf5cea163 100644 --- a/test-data/unit/cmdline.test +++ b/test-data/unit/cmdline.test @@ -923,8 +923,8 @@ test_between(1 + 1) tabs.py:2: error: Incompatible return value type (got "None", expected "str") return None ^~~~ -tabs.py:4: error: Argument 1 to "test_between" has incompatible type "int"; -expected "str" +tabs.py:4: error: Argument 1 to "test_between" has incompatible type +"Literal[2]"; expected "str" test_between(1 + 1) ^~~~~~~~~~~~ diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index 98e604e9e81e..8aa9ba9ec872 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -36,6 +36,7 @@ class str(Sequence[str]): def __getitem__(self, item: int) -> str: pass def format(self, *args: object, **kwargs: object) -> str: pass class bytes(Sequence[int]): + def __add__(self, x: bytes) -> bytes: pass def __iter__(self) -> Iterator[int]: pass def __contains__(self, other: object) -> bool: pass def __getitem__(self, item: int) -> int: pass diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index 77e7763824d6..a385f31926ba 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -142,7 +142,7 @@ class str: pass class list: pass class dict: pass [out] -OpExpr(3) : builtins.int +OpExpr(3) : Literal[3] OpExpr(4) : builtins.float OpExpr(5) : builtins.float OpExpr(6) : builtins.float From e4e90ba957fe06d017e933be6dad42ebac533f29 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 22 Oct 2024 00:08:37 +0200 Subject: [PATCH 2/8] Only infer LiteralType if one of thhe operands is a Literal --- mypy/checkexpr.py | 18 +++++++++++------- test-data/unit/check-literal.test | 3 +++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 5f645879b483..91b03a2ae8d6 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3575,19 +3575,19 @@ def visit_op_expr(self, e: OpExpr) -> Type: def literal_value_from_expr( self, expr: Expression, typ: Type | None = None - ) -> tuple[list[str | int], str] | None: + ) -> tuple[list[str | int], str, bool] | None: if isinstance(expr, StrExpr): - return [expr.value], "builtins.str" + return [expr.value], "builtins.str", False if isinstance(expr, IntExpr): - return [expr.value], "builtins.int" + return [expr.value], "builtins.int", False if isinstance(expr, BytesExpr): - return [expr.value], "builtins.bytes" + return [expr.value], "builtins.bytes", False typ = typ or self.accept(expr) ptype = get_proper_type(typ) if isinstance(ptype, LiteralType) and not isinstance(ptype.value, (bool, float)): - return [ptype.value], ptype.fallback.type.fullname + return [ptype.value], ptype.fallback.type.fullname, True if isinstance(ptype, UnionType): fallback: str | None = None @@ -3603,7 +3603,7 @@ def literal_value_from_expr( values.append(pitem.value) else: assert fallback is not None - return values, fallback + return values, fallback, True return None def literal_expression_addition(self, e: OpExpr, left_type: Type) -> Type | None: @@ -3611,7 +3611,11 @@ def literal_expression_addition(self, e: OpExpr, left_type: Type) -> Type | None assert e.op == "+" if not (lvalue := self.literal_value_from_expr(e.left, left_type)): return None - if not (rvalue := self.literal_value_from_expr(e.right)) or lvalue[1] != rvalue[1]: + if ( + not (rvalue := self.literal_value_from_expr(e.right)) + or lvalue[1] != rvalue[1] # different fallback + or lvalue[2] + rvalue[2] == 0 # no LiteralType + ): return None values: list[int | str] = sorted( diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 2c4504a1c070..2210ddf95e9c 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -3021,6 +3021,7 @@ b: bytes misc_union: Literal["a", 1] +reveal_type("a" + "b") # N: Revealed type is "builtins.str" reveal_type(str_a + str_b) # N: Revealed type is "Literal['ab']" reveal_type(str_a + "b") # N: Revealed type is "Literal['ab']" reveal_type("a" + str_b) # N: Revealed type is "Literal['ab']" @@ -3034,6 +3035,7 @@ reveal_type(s + str_a) # N: Revealed type is "builtins.str" reveal_type(str_union_1 + s) # N: Revealed type is "builtins.str" reveal_type(s + str_union_1) # N: Revealed type is "builtins.str" +reveal_type(1 + 2) # N: Revealed type is "builtins.int" reveal_type(int_1 + int_2) # N: Revealed type is "Literal[3]" reveal_type(int_1 + 1) # N: Revealed type is "Literal[2]" reveal_type(1 + int_1) # N: Revealed type is "Literal[2]" @@ -3047,6 +3049,7 @@ reveal_type(i + int_1) # N: Revealed type is "builtins.int" reveal_type(int_union_1 + i) # N: Revealed type is "builtins.int" reveal_type(i + int_union_1) # N: Revealed type is "builtins.int" +reveal_type(b"a" + b"b") # N: Revealed type is "builtins.bytes" reveal_type(bytes_a + bytes_b) # N: Revealed type is "Literal[b'ab']" reveal_type(bytes_a + b"b") # N: Revealed type is "Literal[b'ab']" reveal_type(b"a" + bytes_b) # N: Revealed type is "Literal[b'ab']" From 5cabf62f2d58423a5231af3812264a3d6c05800c Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 22 Oct 2024 00:35:34 +0200 Subject: [PATCH 3/8] Fix tests --- test-data/unit/cmdline.test | 4 ++-- test-data/unit/typexport-basic.test | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test-data/unit/cmdline.test b/test-data/unit/cmdline.test index 66abf5cea163..85d9fc920d19 100644 --- a/test-data/unit/cmdline.test +++ b/test-data/unit/cmdline.test @@ -923,8 +923,8 @@ test_between(1 + 1) tabs.py:2: error: Incompatible return value type (got "None", expected "str") return None ^~~~ -tabs.py:4: error: Argument 1 to "test_between" has incompatible type -"Literal[2]"; expected "str" +tabs.py:4: error: Argument 1 to "test_between" has incompatible type "int"; +expected "str" test_between(1 + 1) ^~~~~~~~~~~~ diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index a385f31926ba..77e7763824d6 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -142,7 +142,7 @@ class str: pass class list: pass class dict: pass [out] -OpExpr(3) : Literal[3] +OpExpr(3) : builtins.int OpExpr(4) : builtins.float OpExpr(5) : builtins.float OpExpr(6) : builtins.float From 694211e34f2f37376e3143975fdcc1cc4eb7b7e8 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 22 Oct 2024 02:50:36 +0200 Subject: [PATCH 4/8] Minor improvements --- mypy/checkexpr.py | 4 +++- test-data/unit/check-literal.test | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 91b03a2ae8d6..9a95f1ca67d5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3626,7 +3626,9 @@ def literal_expression_addition(self, e: OpExpr, left_type: Type) -> Type | None ) if len(values) == 1: return LiteralType(values[0], self.named_type(lvalue[1])) - return UnionType([LiteralType(val, self.named_type(lvalue[1])) for val in values]) + return make_simplified_union( + [LiteralType(val, self.named_type(lvalue[1])) for val in values] + ) def visit_comparison_expr(self, e: ComparisonExpr) -> Type: """Type check a comparison expression. diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 2210ddf95e9c..804295fd3be4 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -3006,17 +3006,17 @@ from typing_extensions import Literal str_a: Literal["a"] str_b: Literal["b"] str_union_1: Literal["a", "b"] -str_union_2: Literal["c", "d"] +str_union_2: Literal["d", "c"] s: str int_1: Literal[1] int_2: Literal[2] int_union_1: Literal[1, 2] -int_union_2: Literal[3, 4] +int_union_2: Literal[4, 3] i: int bytes_a: Literal[b"a"] bytes_b: Literal[b"b"] bytes_union_1: Literal[b"a", b"b"] -bytes_union_2: Literal[b"c", b"d"] +bytes_union_2: Literal[b"d", b"c"] b: bytes misc_union: Literal["a", 1] From 295c7b1b421d8adf172fabba94887b3a4a4a3399 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:18:53 +0100 Subject: [PATCH 5/8] Move accept call to caller --- mypy/checkexpr.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9a95f1ca67d5..4274392aeadb 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3477,12 +3477,13 @@ def visit_op_expr(self, e: OpExpr) -> Type: if isinstance(e.left, StrExpr): return self.strfrm_checker.check_str_interpolation(e.left, e.right) left_type = self.accept(e.left) - + right_type = self.accept(e.right) proper_left_type = get_proper_type(left_type) + proper_right_type = get_proper_type(right_type) + if isinstance(proper_left_type, TupleType) and e.op == "+": left_add_method = proper_left_type.partial_fallback.type.get("__add__") if left_add_method and left_add_method.fullname == "builtins.tuple.__add__": - proper_right_type = get_proper_type(self.accept(e.right)) if isinstance(proper_right_type, TupleType): right_radd_method = proper_right_type.partial_fallback.type.get("__radd__") if right_radd_method is None: @@ -3510,7 +3511,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: items=proper_left_type.items + [UnpackType(mapped)] ) - if e.op == "+" and (result := self.literal_expression_addition(e, left_type)): + if e.op == "+" and (result := self.literal_expression_addition(e, left_type, right_type)): return result use_reverse: UseReverse = USE_REVERSE_DEFAULT @@ -3519,14 +3520,12 @@ def visit_op_expr(self, e: OpExpr) -> Type: # This is a special case for `dict | TypedDict`. # 1. Find `dict | TypedDict` case # 2. Switch `dict.__or__` to `TypedDict.__ror__` (the same from both runtime and typing perspective) - proper_right_type = get_proper_type(self.accept(e.right)) if isinstance(proper_right_type, TypedDictType): use_reverse = USE_REVERSE_ALWAYS if isinstance(proper_left_type, TypedDictType): # This is the reverse case: `TypedDict | dict`, # simply do not allow the reverse checking: # do not call `__dict__.__ror__`. - proper_right_type = get_proper_type(self.accept(e.right)) if is_named_instance(proper_right_type, "builtins.dict"): use_reverse = USE_REVERSE_NEVER @@ -3537,7 +3536,6 @@ def visit_op_expr(self, e: OpExpr) -> Type: and isinstance(proper_left_type, Instance) and proper_left_type.type.fullname == "builtins.tuple" ): - proper_right_type = get_proper_type(self.accept(e.right)) if ( isinstance(proper_right_type, TupleType) and proper_right_type.partial_fallback.type.fullname == "builtins.tuple" @@ -3561,7 +3559,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: result, method_type = self.check_op( # The reverse operator here gives better error messages: operators.reverse_op_methods[method], - base_type=self.accept(e.right), + base_type=right_type, arg=e.left, context=e, allow_reverse=False, @@ -3574,7 +3572,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: raise RuntimeError(f"Unknown operator {e.op}") def literal_value_from_expr( - self, expr: Expression, typ: Type | None = None + self, expr: Expression, typ: Type ) -> tuple[list[str | int], str, bool] | None: if isinstance(expr, StrExpr): return [expr.value], "builtins.str", False @@ -3583,7 +3581,6 @@ def literal_value_from_expr( if isinstance(expr, BytesExpr): return [expr.value], "builtins.bytes", False - typ = typ or self.accept(expr) ptype = get_proper_type(typ) if isinstance(ptype, LiteralType) and not isinstance(ptype.value, (bool, float)): @@ -3601,18 +3598,20 @@ def literal_value_from_expr( if fallback != pitem.fallback.type.fullname: break values.append(pitem.value) - else: + else: # no break assert fallback is not None return values, fallback, True return None - def literal_expression_addition(self, e: OpExpr, left_type: Type) -> Type | None: + def literal_expression_addition( + self, e: OpExpr, left_type: Type, right_type: Type + ) -> Type | None: """Check if literal values can be combined with addition.""" assert e.op == "+" if not (lvalue := self.literal_value_from_expr(e.left, left_type)): return None if ( - not (rvalue := self.literal_value_from_expr(e.right)) + not (rvalue := self.literal_value_from_expr(e.right, right_type)) or lvalue[1] != rvalue[1] # different fallback or lvalue[2] + rvalue[2] == 0 # no LiteralType ): From 5d696776f3e871ec71f2f95c3c1641e7948d14da Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:19:36 +0100 Subject: [PATCH 6/8] Add additional test cases --- test-data/unit/check-literal.test | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 804295fd3be4..1aac63f3ea9d 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -1423,6 +1423,7 @@ reveal_type('bar' + d) # N: Revealed type is "Literal['barfoo']" reveal_type(a.__add__(b)) # N: Revealed type is "builtins.int" reveal_type(b.__add__(a)) # N: Revealed type is "builtins.int" +reveal_type(a.__add__(a)) # N: Revealed type is "builtins.int" a *= b # E: Incompatible types in assignment (expression has type "int", variable has type "Literal[3]") b *= a @@ -3000,13 +3001,19 @@ def check(obj: A[Literal[1]]) -> None: [builtins fixtures/tuple.pyi] [case testLiteralAddition] -from typing import Union +from typing import Any, Union from typing_extensions import Literal +class A: + def __add__(self, other: str) -> str: ... + def __radd__(self, other: str) -> str: ... + str_a: Literal["a"] str_b: Literal["b"] str_union_1: Literal["a", "b"] str_union_2: Literal["d", "c"] +str_union_mixed_1: Union[Literal["a"], Any] +str_union_mixed_2: Union[Literal["a"], A] s: str int_1: Literal[1] int_2: Literal[2] @@ -3034,6 +3041,10 @@ reveal_type(str_a + s) # N: Revealed type is "builtins.str" reveal_type(s + str_a) # N: Revealed type is "builtins.str" reveal_type(str_union_1 + s) # N: Revealed type is "builtins.str" reveal_type(s + str_union_1) # N: Revealed type is "builtins.str" +reveal_type(str_a + str_union_mixed_1) # N: Revealed type is "builtins.str" +reveal_type(str_union_mixed_1 + str_a) # N: Revealed type is "builtins.str | Any" +reveal_type(str_a + str_union_mixed_2) # N: Revealed type is "builtins.str" +reveal_type(str_union_mixed_2 + str_a) # N: Revealed type is "builtins.str" reveal_type(1 + 2) # N: Revealed type is "builtins.int" reveal_type(int_1 + int_2) # N: Revealed type is "Literal[3]" @@ -3071,6 +3082,21 @@ reveal_type("a" + misc_union) # E: Unsupported operand types for + ("str" and " # N: Revealed type is "builtins.str" [builtins fixtures/primitives.pyi] +[case testLiteralAdditionInheritance] +class A: + a = "" + +class B(A): + a = "a" + "b" + +class C: + a = "a" + "b" + +reveal_type(A.a) # N: Revealed type is "builtins.str" +reveal_type(B.a) # N: Revealed type is "builtins.str" +reveal_type(C.a) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] + [case testLiteralAdditionTypedDict] from typing import TypedDict from typing_extensions import Literal From 0964122900a0448c2115baa7335a58ee8e682331 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:08:23 +0100 Subject: [PATCH 7/8] Add guard against too many literal values --- mypy/checkexpr.py | 6 ++++++ test-data/unit/check-literal.test | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 4274392aeadb..f2548213a6da 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -226,6 +226,10 @@ # see https://github.com/python/mypy/pull/5255#discussion_r196896335 for discussion. MAX_UNIONS: Final = 5 +# Use fallback type if literal addition of unions results in too many literal +# values. Explicitly set on the safe side to prevent accidental issues. +MAX_LITERAL_ADDITION_VALUES: Final = 15 + # Types considered safe for comparisons with --strict-equality due to known behaviour of __eq__. # NOTE: All these types are subtypes of AbstractSet. @@ -3625,6 +3629,8 @@ def literal_expression_addition( ) if len(values) == 1: return LiteralType(values[0], self.named_type(lvalue[1])) + elif len(values) > MAX_LITERAL_ADDITION_VALUES: + return None return make_simplified_union( [LiteralType(val, self.named_type(lvalue[1])) for val in values] ) diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 1aac63f3ea9d..1f472a3de936 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -3111,3 +3111,12 @@ def func(d: LookupDict, pos: Literal["top_", "bottom_", ""]) -> str: [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + +[case testLiteralAdditionGuardMaxValues] +from typing_extensions import Literal + +HexDigit = Literal["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F"] + +def foo(a: HexDigit, b: HexDigit, c: HexDigit) -> None: + reveal_type(a + b + c) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] From 655b81dcd57b5a2c3a3d3a88c2d36f82922ceaf5 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 19 May 2025 12:08:59 +0200 Subject: [PATCH 8/8] Remove type ignore --- mypy/checkexpr.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index f2548213a6da..2fde26f1b35f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -8,7 +8,7 @@ from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence from contextlib import contextmanager, nullcontext -from typing import ClassVar, Final, TypeAlias as _TypeAlias, cast, overload +from typing import Any, ClassVar, Final, TypeAlias as _TypeAlias, cast, overload from typing_extensions import assert_never import mypy.checker @@ -3577,7 +3577,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: def literal_value_from_expr( self, expr: Expression, typ: Type - ) -> tuple[list[str | int], str, bool] | None: + ) -> tuple[list[Any], str, bool] | None: if isinstance(expr, StrExpr): return [expr.value], "builtins.str", False if isinstance(expr, IntExpr): @@ -3622,10 +3622,7 @@ def literal_expression_addition( return None values: list[int | str] = sorted( - { - val[0] + val[1] # type: ignore[operator] - for val in itertools.product(lvalue[0], rvalue[0]) - } + {val[0] + val[1] for val in itertools.product(lvalue[0], rvalue[0])} ) if len(values) == 1: return LiteralType(values[0], self.named_type(lvalue[1]))