Skip to content

Commit e37d92d

Browse files
authored
[mypyc] Support iterating over keys/values/items of dict-bound TypeVar and ParamSpec.kwargs (#18789)
Fixes #18784.
1 parent a8b723d commit e37d92d

File tree

4 files changed

+779
-25
lines changed

4 files changed

+779
-25
lines changed

mypyc/irbuild/builder.py

+30-24
Original file line numberDiff line numberDiff line change
@@ -958,38 +958,44 @@ def get_dict_base_type(self, expr: Expression) -> list[Instance]:
958958
959959
This is useful for dict subclasses like SymbolTable.
960960
"""
961-
target_type = get_proper_type(self.types[expr])
961+
return self.get_dict_base_type_from_type(self.types[expr])
962+
963+
def get_dict_base_type_from_type(self, target_type: Type) -> list[Instance]:
964+
target_type = get_proper_type(target_type)
962965
if isinstance(target_type, UnionType):
963-
types = [get_proper_type(item) for item in target_type.items]
966+
return [
967+
inner
968+
for item in target_type.items
969+
for inner in self.get_dict_base_type_from_type(item)
970+
]
971+
if isinstance(target_type, TypeVarLikeType):
972+
# Match behaviour of self.node_type
973+
# We can only reach this point if `target_type` was a TypeVar(bound=dict[...])
974+
# or a ParamSpec.
975+
return self.get_dict_base_type_from_type(target_type.upper_bound)
976+
977+
if isinstance(target_type, TypedDictType):
978+
target_type = target_type.fallback
979+
dict_base = next(
980+
base for base in target_type.type.mro if base.fullname == "typing.Mapping"
981+
)
982+
elif isinstance(target_type, Instance):
983+
dict_base = next(
984+
base for base in target_type.type.mro if base.fullname == "builtins.dict"
985+
)
964986
else:
965-
types = [target_type]
966-
967-
dict_types = []
968-
for t in types:
969-
if isinstance(t, TypedDictType):
970-
t = t.fallback
971-
dict_base = next(base for base in t.type.mro if base.fullname == "typing.Mapping")
972-
else:
973-
assert isinstance(t, Instance), t
974-
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
975-
dict_types.append(map_instance_to_supertype(t, dict_base))
976-
return dict_types
987+
assert False, f"Failed to extract dict base from {target_type}"
988+
return [map_instance_to_supertype(target_type, dict_base)]
977989

978990
def get_dict_key_type(self, expr: Expression) -> RType:
979991
dict_base_types = self.get_dict_base_type(expr)
980-
if len(dict_base_types) == 1:
981-
return self.type_to_rtype(dict_base_types[0].args[0])
982-
else:
983-
rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types]
984-
return RUnion.make_simplified_union(rtypes)
992+
rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types]
993+
return RUnion.make_simplified_union(rtypes)
985994

986995
def get_dict_value_type(self, expr: Expression) -> RType:
987996
dict_base_types = self.get_dict_base_type(expr)
988-
if len(dict_base_types) == 1:
989-
return self.type_to_rtype(dict_base_types[0].args[1])
990-
else:
991-
rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types]
992-
return RUnion.make_simplified_union(rtypes)
997+
rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types]
998+
return RUnion.make_simplified_union(rtypes)
993999

9941000
def get_dict_item_type(self, expr: Expression) -> RType:
9951001
key_type = self.get_dict_key_type(expr)

0 commit comments

Comments
 (0)