Skip to content

Commit c0af3fb

Browse files
committed
feat: add runtime conversion function for ProjectTable
Signed-off-by: Henry Schreiner <[email protected]>
1 parent 36b53e1 commit c0af3fb

File tree

2 files changed

+372
-22
lines changed

2 files changed

+372
-22
lines changed

src/packaging/project_table.py

Lines changed: 127 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,17 @@
99
from typing import Any, Dict, List, Literal, TypedDict, Union
1010

1111
if sys.version_info < (3, 11):
12-
from typing_extensions import Required
12+
if typing.TYPE_CHECKING:
13+
from typing_extensions import Required
14+
else:
15+
try:
16+
from typing_extensions import Required
17+
except ModuleNotFoundError:
18+
V = typing.TypeVar("V")
19+
20+
class Required:
21+
def __class_getitem__(cls, item: V) -> V:
22+
return item
1323
else:
1424
from typing import Required
1525

@@ -23,6 +33,7 @@
2333
"ProjectTable",
2434
"PyProjectTable",
2535
"ReadmeTable",
36+
"to_project_table",
2637
]
2738

2839

@@ -31,11 +42,19 @@ def __dir__() -> list[str]:
3142

3243

3344
class ContactTable(TypedDict, total=False):
45+
"""
46+
Can have either name or email.
47+
"""
48+
3449
name: str
3550
email: str
3651

3752

3853
class LicenseTable(TypedDict, total=False):
54+
"""
55+
Can have either text or file. Legacy.
56+
"""
57+
3958
text: str
4059
file: str
4160

@@ -121,25 +140,111 @@ class LicenseTable(TypedDict, total=False):
121140
total=False,
122141
)
123142

124-
# Tests for type checking
125-
if typing.TYPE_CHECKING:
126-
PyProjectTable(
127-
{
128-
"build-system": BuildSystemTable(
129-
{"build-backend": "one", "requires": ["two"]}
130-
),
131-
"project": ProjectTable(
132-
{
133-
"name": "one",
134-
"version": "0.1.0",
135-
}
136-
),
137-
"tool": {"thing": object()},
138-
"dependency-groups": {
139-
"one": [
140-
"one",
141-
IncludeGroupTable({"include-group": "two"}),
142-
]
143+
T = typing.TypeVar("T")
144+
145+
146+
def is_typed_dict(type_hint: object) -> bool:
147+
if sys.version_info >= (3, 10):
148+
return typing.is_typeddict(type_hint)
149+
return hasattr(type_hint, "__annotations__") and hasattr(type_hint, "__total__")
150+
151+
152+
def _cast(type_hint: type[T], data: object, prefix: str) -> T:
153+
"""
154+
Runtime cast for types.
155+
156+
Just enough to cover the dicts above (not general or public).
157+
"""
158+
159+
# TypedDict
160+
if is_typed_dict(type_hint):
161+
if not isinstance(data, dict):
162+
msg = (
163+
f'"{prefix}" expected dict for {type_hint.__name__},'
164+
f" got {type(data).__name__}"
165+
)
166+
raise TypeError(msg)
167+
168+
hints = typing.get_type_hints(type_hint)
169+
for key, typ in hints.items():
170+
if key in data:
171+
_cast(typ, data[key], prefix + f".{key}" if prefix else key)
172+
# Required keys could be enforced here on 3.11+ eventually
173+
174+
return typing.cast("T", data)
175+
176+
origin = typing.get_origin(type_hint)
177+
# Special case Required on 3.10
178+
if origin is Required:
179+
(type_hint,) = typing.get_args(type_hint)
180+
origin = typing.get_origin(type_hint)
181+
args = typing.get_args(type_hint)
182+
183+
# Literal
184+
if origin is typing.Literal:
185+
if data not in args:
186+
arg_names = ", ".join(repr(a) for a in args)
187+
msg = f'"{prefix}" expected one of {arg_names}, got {data!r}'
188+
raise TypeError(msg)
189+
return typing.cast("T", data)
190+
191+
# Any accepts everything, so no validation
192+
if type_hint is Any:
193+
return typing.cast("T", data)
194+
195+
# List[T]
196+
if origin is list:
197+
if not isinstance(data, list):
198+
msg = f'"{prefix}" expected list, got {type(data).__name__}'
199+
raise TypeError(msg)
200+
item_type = args[0]
201+
return typing.cast(
202+
"T", [_cast(item_type, item, f"{prefix}[]") for item in data]
203+
)
204+
205+
# Dict[str, T]
206+
if origin is dict:
207+
if not isinstance(data, dict):
208+
msg = f'"{prefix}" expected dict, got {type(data).__name__}'
209+
raise TypeError(msg)
210+
_, value_type = args
211+
return typing.cast(
212+
"T",
213+
{
214+
key: _cast(value_type, value, f"{prefix}.{key}")
215+
for key, value in data.items()
143216
},
144-
}
145-
)
217+
)
218+
# Union[T1, T2, ...]
219+
if origin is typing.Union:
220+
for arg in args:
221+
try:
222+
_cast(arg, data, prefix)
223+
return typing.cast("T", data)
224+
except TypeError: # noqa: PERF203
225+
continue
226+
arg_names = " | ".join(a.__name__ for a in args)
227+
msg = f'"{prefix}" does not match any type in {arg_names}'
228+
raise TypeError(msg)
229+
230+
# Base case (str, etc.)
231+
if isinstance(data, origin or type_hint):
232+
return typing.cast("T", data)
233+
234+
msg = f'"{prefix}" expected {type_hint.__name__}, got {type(data).__name__}'
235+
raise TypeError(msg)
236+
237+
238+
def to_project_table(data: dict[str, Any], /) -> PyProjectTable:
239+
"""
240+
Convert a dict to a PyProjectTable, validating types at runtime.
241+
242+
Note that only the types that are affected by a TypedDict are validated;
243+
extra keys are ignored.
244+
"""
245+
# Handling Required here
246+
name = data.get("project", {"name": ""}).get("name")
247+
if name is None:
248+
msg = 'Key "project.name" is required if "project" is present'
249+
raise TypeError(msg)
250+
return _cast(PyProjectTable, data, "")

0 commit comments

Comments
 (0)