Skip to content

Commit 0bf96e6

Browse files
committed
feat: add runtime conversion function for ProjectTable
Signed-off-by: Henry Schreiner <[email protected]>
1 parent 7da2b4f commit 0bf96e6

File tree

2 files changed

+370
-22
lines changed

2 files changed

+370
-22
lines changed

src/packaging/project_table.py

Lines changed: 125 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,17 @@
99
from typing import Any, Dict, List, Literal, TypedDict, Union
1010

1111
if sys.version_info < (3, 11):
12-
from typing_extensions import Required
12+
if typing.TYPE_CHECKING:
13+
from typing_extensions import Required
14+
else:
15+
try:
16+
from typing_extensions import Required
17+
except ModuleNotFoundError:
18+
V = typing.TypeVar("V")
19+
20+
class Required:
21+
def __class_getitem__(cls, item: V) -> V:
22+
return item
1323
else:
1424
from typing import Required
1525

@@ -23,6 +33,7 @@
2333
"ProjectTable",
2434
"PyProjectTable",
2535
"ReadmeTable",
36+
"to_project_table",
2637
]
2738

2839

@@ -31,11 +42,17 @@ def __dir__() -> list[str]:
3142

3243

3344
class ContactTable(TypedDict, total=False):
45+
"""
46+
Can have either name or email.
47+
"""
3448
name: str
3549
email: str
3650

3751

3852
class LicenseTable(TypedDict, total=False):
53+
"""
54+
Can have either text or file. Legacy.
55+
"""
3956
text: str
4057
file: str
4158

@@ -121,25 +138,111 @@ class LicenseTable(TypedDict, total=False):
121138
total=False,
122139
)
123140

124-
# Tests for type checking
125-
if typing.TYPE_CHECKING:
126-
PyProjectTable(
127-
{
128-
"build-system": BuildSystemTable(
129-
{"build-backend": "one", "requires": ["two"]}
130-
),
131-
"project": ProjectTable(
132-
{
133-
"name": "one",
134-
"version": "0.1.0",
135-
}
136-
),
137-
"tool": {"thing": object()},
138-
"dependency-groups": {
139-
"one": [
140-
"one",
141-
IncludeGroupTable({"include-group": "two"}),
142-
]
141+
T = typing.TypeVar("T")
142+
143+
144+
def is_typed_dict(type_hint: Any) -> bool:
145+
if sys.version_info >= (3, 10):
146+
return typing.is_typeddict(type_hint)
147+
return hasattr(type_hint, "__annotations__") and hasattr(type_hint, "__total__")
148+
149+
150+
def _cast(type_hint: type[T], data: Any, prefix: str) -> T:
151+
"""
152+
Runtime cast for types.
153+
154+
Just enough to cover the dicts above (not general or public).
155+
"""
156+
157+
# TypedDict
158+
if is_typed_dict(type_hint):
159+
if not isinstance(data, dict):
160+
msg = (
161+
f'"{prefix}" expected dict for {type_hint.__name__},'
162+
f" got {type(data).__name__}"
163+
)
164+
raise TypeError(msg)
165+
166+
hints = typing.get_type_hints(type_hint)
167+
for key, typ in hints.items():
168+
if key in data:
169+
_cast(typ, data[key], prefix + f".{key}" if prefix else key)
170+
# Required keys could be enforced here on 3.11+ eventually
171+
172+
return typing.cast("T", data)
173+
174+
origin = typing.get_origin(type_hint)
175+
# Special case Required on 3.10
176+
if origin is Required:
177+
type_hint, = typing.get_args(type_hint)
178+
origin = typing.get_origin(type_hint)
179+
args = typing.get_args(type_hint)
180+
181+
# Literal
182+
if origin is typing.Literal:
183+
if data not in args:
184+
arg_names = ", ".join(repr(a) for a in args)
185+
msg = f'"{prefix}" expected one of {arg_names}, got {data!r}'
186+
raise TypeError(msg)
187+
return typing.cast("T", data)
188+
189+
# Any accepts everything, so no validation
190+
if type_hint is Any:
191+
return typing.cast("T", data)
192+
193+
# List[T]
194+
if origin is list:
195+
if not isinstance(data, list):
196+
msg = f'"{prefix}" expected list, got {type(data).__name__}'
197+
raise TypeError(msg)
198+
item_type = args[0]
199+
return typing.cast(
200+
"T", [_cast(item_type, item, f"{prefix}[]") for item in data]
201+
)
202+
203+
# Dict[str, T]
204+
if origin is dict:
205+
if not isinstance(data, dict):
206+
msg = f'"{prefix}" expected dict, got {type(data).__name__}'
207+
raise TypeError(msg)
208+
_, value_type = args
209+
return typing.cast(
210+
"T",
211+
{
212+
key: _cast(value_type, value, f"{prefix}.{key}")
213+
for key, value in data.items()
143214
},
144-
}
145-
)
215+
)
216+
# Union[T1, T2, ...]
217+
if origin is typing.Union:
218+
for arg in args:
219+
try:
220+
_cast(arg, data, prefix)
221+
return typing.cast("T", data)
222+
except TypeError: # noqa: PERF203
223+
continue
224+
arg_names = " | ".join(a.__name__ for a in args)
225+
msg = f'"{prefix}" does not match any type in {arg_names}'
226+
raise TypeError(msg)
227+
228+
# Base case (str, etc.)
229+
if isinstance(data, origin or type_hint):
230+
return data
231+
232+
msg = f'"{prefix}" expected {type_hint.__name__}, got {type(data).__name__}'
233+
raise TypeError(msg)
234+
235+
236+
def to_project_table(data: dict[str, Any], /) -> PyProjectTable:
237+
"""
238+
Convert a dict to a PyProjectTable, validating types at runtime.
239+
240+
Note that only the types that are affected by a TypedDict are validated;
241+
extra keys are ignored.
242+
"""
243+
# Handling Required here
244+
name = data.get("project", {"name": ""}).get("name")
245+
if name is None:
246+
msg = 'Key "project.name" is required if "project" is present'
247+
raise TypeError(msg)
248+
return _cast(PyProjectTable, data, "")

0 commit comments

Comments
 (0)