Skip to content

Commit c7a7b44

Browse files
authored
feat: add PEP 695 type alias support and improve type handling (#711)
1 parent 216b8a5 commit c7a7b44

File tree

8 files changed

+614
-5
lines changed

8 files changed

+614
-5
lines changed

polyfactory/factories/base.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,14 @@
6565
unwrap_optional,
6666
)
6767
from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage
68-
from polyfactory.utils.predicates import get_type_origin, is_literal, is_optional, is_safe_subclass, is_union
68+
from polyfactory.utils.predicates import (
69+
get_type_origin,
70+
is_literal,
71+
is_optional,
72+
is_safe_subclass,
73+
is_type_var,
74+
is_union,
75+
)
6976
from polyfactory.utils.types import NoneType
7077
from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage
7178
from polyfactory.value_generators.constrained_collections import (
@@ -280,7 +287,7 @@ class Foo(ModelFactory[MyModel]): # <<< MyModel
280287
b for b in get_original_bases(cls) if get_origin(b) and issubclass(get_origin(b), BaseFactory)
281288
)
282289
generic_args: Sequence[type[T]] = [
283-
arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar)
290+
arg for factory_base in factory_bases for arg in get_args(factory_base) if not is_type_var(arg)
284291
]
285292
if len(generic_args) != 1:
286293
return None
@@ -851,7 +858,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
851858
if provider := (provider_map.get(field_meta.annotation) or provider_map.get(unwrapped_annotation)):
852859
return provider()
853860

854-
if isinstance(unwrapped_annotation, TypeVar):
861+
if is_type_var(unwrapped_annotation):
855862
return create_random_string(cls.__random__, min_length=1, max_length=10)
856863

857864
if callable(unwrapped_annotation):
@@ -933,7 +940,7 @@ def get_field_value_coverage( # noqa: C901,PLR0912
933940
):
934941
yield CoverageContainerCallable(provider)
935942

936-
elif isinstance(unwrapped_annotation, TypeVar):
943+
elif is_type_var(unwrapped_annotation):
937944
yield create_random_string(cls.__random__, min_length=1, max_length=10)
938945

939946
elif callable(unwrapped_annotation):

polyfactory/factories/pydantic_factory.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from polyfactory.field_meta import Constraints, FieldMeta, Null
1717
from polyfactory.utils.deprecation import check_for_deprecated_parameters
1818
from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional
19+
from polyfactory.utils.normalize_type import normalize_type
1920
from polyfactory.utils.predicates import is_annotated, is_optional, is_safe_subclass, is_union
2021
from polyfactory.utils.types import NoneType
2122
from polyfactory.value_generators.primitives import create_random_bytes
@@ -167,6 +168,10 @@ def from_field_info(
167168
("random", random),
168169
),
169170
)
171+
field_info = FieldInfo.merge_field_infos(
172+
field_info, FieldInfo.from_annotation(normalize_type(field_info.annotation))
173+
)
174+
170175
if callable(field_info.default_factory):
171176
default_value = field_info.default_factory
172177
else:
@@ -220,7 +225,7 @@ def from_field_info(
220225
if is_json:
221226
constraints["json"] = True
222227

223-
result = PydanticFieldMeta.from_type(
228+
result = super().from_type(
224229
annotation=annotation,
225230
children=children,
226231
constraints=cast("Constraints", {k: v for k, v in constraints.items() if v is not None}) or None,

polyfactory/field_meta.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from polyfactory.constants import DEFAULT_RANDOM, TYPE_MAPPING
1010
from polyfactory.utils.deprecation import check_for_deprecated_parameters
1111
from polyfactory.utils.helpers import get_annotation_metadata, is_dataclass_instance, unwrap_annotated, unwrap_new_type
12+
from polyfactory.utils.normalize_type import normalize_type
1213
from polyfactory.utils.predicates import is_annotated
1314
from polyfactory.utils.types import NoneType
1415

@@ -139,6 +140,7 @@ def from_type(
139140
),
140141
)
141142

143+
annotation = normalize_type(annotation)
142144
annotated = is_annotated(annotation)
143145
if not constraints and annotated:
144146
metadata = cls.get_constraints_metadata(annotation)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Mapping, Union
4+
5+
from typing_extensions import Annotated, get_args, get_origin
6+
7+
from polyfactory.utils.predicates import (
8+
is_generic_alias,
9+
is_type_alias,
10+
is_type_var,
11+
is_union,
12+
)
13+
14+
15+
def normalize_type(type_annotation: Any) -> Any:
16+
"""Convert modern Python 3.12+ type syntax to standard annotations.
17+
18+
Handles TypeAliasType and GenericAlias types introduced in Python 3.12+,
19+
converting them to standard type annotations when needed.
20+
21+
Args:
22+
type_annotation: Type to normalize (convert if needed or pass through).
23+
24+
Returns:
25+
Normalized type annotation with resolved type aliases and substituted parameters.
26+
27+
Example:
28+
```python
29+
# Python 3.12+
30+
>> from typing import Annotated
31+
>> import annotated_types as at
32+
33+
>> type NegativeInt = Annotated[int, at.Lt(0)]
34+
>> type NonEmptyList[T] = Annotated[list[T], at.Len(1)]
35+
36+
>> normalize_type(NonEmptyList[NegativeInt])
37+
# typing.Annotated[list[typing.Annotated[int, Lt(lt=0)]], Len(min_length=1, max_length=None)]
38+
```
39+
"""
40+
41+
if is_type_alias(type_annotation):
42+
return type_annotation.__value__
43+
44+
if not is_generic_alias(type_annotation):
45+
return type_annotation
46+
47+
origin = get_origin(type_annotation)
48+
args = get_args(type_annotation)
49+
50+
if is_type_alias(origin):
51+
return __handle_generic_type_alias(origin, args)
52+
53+
if args:
54+
normalized_args = tuple(normalize_type(arg) for arg in args)
55+
if normalized_args != args:
56+
return origin[normalized_args[0] if len(normalized_args) == 1 else normalized_args]
57+
58+
return type_annotation
59+
60+
61+
def __handle_generic_type_alias(origin: Any, args: tuple) -> Any:
62+
"""Handle generic type alias with parameters."""
63+
template = origin.__value__
64+
type_params = origin.__type_params__
65+
66+
if not (type_params and args):
67+
return template
68+
69+
normalized_args = tuple(normalize_type(arg) for arg in args)
70+
substitutions = dict(zip(type_params, normalized_args))
71+
72+
if get_origin(template) is Annotated:
73+
base_type, *metadata = get_args(template)
74+
template_result = Annotated[tuple([__apply_substitutions(base_type, substitutions)] + metadata)] # type: ignore[valid-type]
75+
else:
76+
template_result = __apply_substitutions(template, substitutions)
77+
78+
return template_result
79+
80+
81+
def __apply_substitutions(target: Any, subs: Mapping[Any, Any]) -> Any:
82+
if is_type_var(target):
83+
return subs.get(target, target)
84+
85+
if is_union(target):
86+
args = tuple(__apply_substitutions(arg, subs) for arg in get_args(target))
87+
return Union[args]
88+
89+
origin = get_origin(target)
90+
args = get_args(target)
91+
92+
if is_type_alias(origin):
93+
sub_args = tuple(__apply_substitutions(arg, subs) for arg in args) if args else ()
94+
return normalize_type(origin[sub_args] if sub_args else origin)
95+
96+
if origin and args:
97+
sub_args = tuple(__apply_substitutions(arg, subs) for arg in args)
98+
return origin[sub_args[0] if len(sub_args) == 1 else sub_args]
99+
100+
return target

polyfactory/utils/predicates.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,28 @@ def is_type_alias(annotation: Any) -> TypeGuard[TypeAliasType]:
147147
return isinstance(annotation, AllTypeAliasTypes)
148148

149149

150+
def is_generic_alias(annotation: Any) -> bool:
151+
"""Determine if the given type annotation is a generic alias.
152+
153+
:param annotation: A type annotation.
154+
155+
:returns: A boolean
156+
"""
157+
return hasattr(annotation, "__origin__") and hasattr(annotation, "__args__")
158+
159+
160+
def is_type_var(annotation: Any) -> TypeGuard[TypeVar]:
161+
"""Determine if the given type annotation is a TypeVar.
162+
163+
Args:
164+
annotation: A type annotation.
165+
166+
Returns:
167+
A boolean.
168+
"""
169+
return isinstance(annotation, TypeVar)
170+
171+
150172
def get_type_origin(annotation: Any) -> Any:
151173
"""Get the type origin of an annotation - safely.
152174

tests/test_dataclass_factory.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import dataclasses
2+
import sys
3+
import textwrap
24
from dataclasses import dataclass as vanilla_dataclass
35
from dataclasses import field
46
from types import ModuleType
@@ -351,3 +353,23 @@ class AFactory(DataclassFactory[A]):
351353

352354
if isinstance(a.c, list):
353355
assert all(isinstance(value, int) for value in a.c)
356+
357+
358+
@pytest.mark.skipif(sys.version_info < (3, 12), reason="PEP 695 requires Python 3.12+")
359+
def test_pep695_dict_union_types(create_module: Callable[[str], ModuleType]) -> None:
360+
"""Test type aliases with dict unions."""
361+
module = create_module(
362+
textwrap.dedent("""
363+
from dataclasses import dataclass
364+
365+
type IntDict = dict[str, int]
366+
type StrDict = dict[str, str]
367+
type MixedDict = IntDict | StrDict
368+
369+
@dataclass
370+
class Foo:
371+
data: MixedDict
372+
nested: dict[str, MixedDict]
373+
""")
374+
)
375+
DataclassFactory.create_factory(module.Foo).build()

0 commit comments

Comments
 (0)