diff --git a/litestar/_openapi/schema_generation/schema.py b/litestar/_openapi/schema_generation/schema.py index 9464599108..12cfd015fc 100644 --- a/litestar/_openapi/schema_generation/schema.py +++ b/litestar/_openapi/schema_generation/schema.py @@ -52,6 +52,7 @@ from litestar.types import Empty from litestar.types.builtin_types import NoneType from litestar.typing import FieldDefinition +from litestar.utils import deprecated from litestar.utils.helpers import get_name from litestar.utils.predicates import ( is_class_and_subclass, @@ -309,8 +310,6 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re if field_definition.is_new_type: result = self.for_new_type(field_definition) - elif field_definition.is_type_alias_type: - result = self.for_type_alias_type(field_definition) elif plugin_for_annotation := self.get_plugin_for(field_definition): result = self.for_plugin(field_definition, plugin_for_annotation) elif _should_create_literal_schema(field_definition): @@ -353,6 +352,7 @@ def for_new_type(self, field_definition: FieldDefinition) -> Schema | Reference: ) ) + @deprecated(version="2.15", removal_in="3.0", info="TypeAliasType is supported natively") def for_type_alias_type(self, field_definition: FieldDefinition) -> Schema | Reference: return self.for_field_definition( FieldDefinition.from_kwarg( diff --git a/litestar/typing.py b/litestar/typing.py index 37dec75825..5877a4b488 100644 --- a/litestar/typing.py +++ b/litestar/typing.py @@ -272,7 +272,7 @@ def is_new_type(self) -> bool: @property def is_type_alias_type(self) -> bool: """Whether the annotation is a ``TypeAliasType``""" - return isinstance(self.annotation, TypeAliasType) + return TypeAliasType in self.type_wrappers @property def is_type_var(self) -> bool: diff --git a/litestar/utils/typing.py b/litestar/utils/typing.py index bf6de688e6..77aded3dd1 100644 --- a/litestar/utils/typing.py +++ b/litestar/utils/typing.py @@ -37,7 +37,16 @@ cast, ) -from typing_extensions import Annotated, NewType, NotRequired, Required, get_args, get_origin, get_type_hints +from typing_extensions import ( + Annotated, + NewType, + NotRequired, + Required, + TypeAliasType, + get_args, + get_origin, + get_type_hints, +) from litestar.types.builtin_types import NoneType, UnionTypes @@ -128,7 +137,7 @@ ``collections.abc.Mapping``, are not valid generic types in Python 3.8. """ -wrapper_type_set = {Annotated, Required, NotRequired} +wrapper_type_set = {Annotated, Required, NotRequired, TypeAliasType} """Types that always contain a wrapped type annotation as their first arg.""" @@ -151,7 +160,7 @@ def make_non_optional_union(annotation: UnionT | None) -> UnionT: def unwrap_annotation(annotation: Any) -> tuple[Any, tuple[Any, ...], set[Any]]: - """Remove "wrapper" annotation types, such as ``Annotated``, ``Required``, and ``NotRequired``. + """Remove "wrapper" annotation types, such as ``Annotated``, ``Required``, ``NotRequired`` or ``TypeAliasType``. Note: ``annotation`` should have been retrieved from :func:`get_type_hints()` with ``include_extras=True``. This @@ -163,14 +172,30 @@ def unwrap_annotation(annotation: Any) -> tuple[Any, tuple[Any, ...], set[Any]]: Returns: A tuple of the unwrapped annotation and any ``Annotated`` metadata, and a set of any wrapper types encountered. """ - origin = get_origin(annotation) - wrappers = set() + metadata = [] - while origin in wrapper_type_set: - wrappers.add(origin) - annotation, *meta = get_args(annotation) - metadata.extend(meta) - origin = get_origin(annotation) + wrappers = set() + + stack = [annotation] + + while stack: + ann = stack.pop() + + if isinstance(ann, TypeAliasType): + wrappers.add(TypeAliasType) + stack.append(ann.__value__) + continue + + origin = get_origin(ann) + if origin in wrapper_type_set: + ann, *meta = get_args(ann) + metadata.extend(meta) + wrappers.add(origin) + stack.append(ann) + continue + + return ann, tuple(metadata), wrappers + return annotation, tuple(metadata), wrappers diff --git a/tests/unit/test_typing.py b/tests/unit/test_typing.py index bb05952839..e52bf37f93 100644 --- a/tests/unit/test_typing.py +++ b/tests/unit/test_typing.py @@ -12,6 +12,7 @@ from litestar.exceptions import LitestarWarning from litestar.params import DependencyKwarg, KwargDefinition, Parameter, ParameterKwarg from litestar.typing import FieldDefinition +from litestar.utils.typing import unwrap_annotation from tests.unit.test_utils.test_signature import T, _check_field_definition, field_definition_int, test_type_hints @@ -465,6 +466,7 @@ def handler(foo: Annotated[int, Parameter(default=1)]) -> None: def test_is_type_alias_type() -> None: field_definition = FieldDefinition.from_annotation(TypeAliasType("IntAlias", int)) # pyright: ignore assert field_definition.is_type_alias_type + assert field_definition.annotation == int @pytest.mark.skipif(sys.version_info < (3, 12), reason="type keyword not available before 3.12") @@ -474,3 +476,68 @@ def test_unwrap_type_alias_type_keyword() -> None: annotation = ctx["IntAlias"] field_definition = FieldDefinition.from_annotation(annotation) assert field_definition.is_type_alias_type + assert field_definition.annotation == int + + +def type_kw_or(src: str) -> Any: + if sys.version_info < (3, 12): + return None + + ctx: dict[str, Any] = {} # type: ignore[unreachable] + + exec(src, ctx, None) + return ctx["Alias"] + + +@pytest.mark.parametrize( + "annotation", + [ + TypeAliasType("SomeAlias", int), + pytest.param( + type_kw_or("type Alias = int"), + marks=pytest.mark.skipif(sys.version_info < (3, 12), reason="type keyword not available before 3.12"), + ), + ], +) +def test_unwrap_annotation_type_alias_type(annotation: Any) -> None: + unwrapped, metadata, wrappers = unwrap_annotation(annotation) + assert unwrapped == int + assert not metadata + assert TypeAliasType in wrappers + + +NestedAlias = TypeAliasType("NestedAlias", Union[Annotated[int, "meta"], List["NestedAlias"]]) # type: ignore[misc] + + +@pytest.mark.parametrize( + "annotation, expected_meta, expected_type", + [ + (Annotated[TypeAliasType("SomeAlias", int), "meta"], ("meta",), int), + (TypeAliasType("SomeAlias", TypeAliasType("InnerAlias", int)), (), int), + (TypeAliasType("SomeAlias", Annotated[int, "meta"]), ("meta",), int), + (TypeAliasType("SomeAlias", Annotated[TypeAliasType("InnerAlias", int), "meta"]), ("meta",), int), + (Annotated[TypeAliasType("SomeAlias", Annotated[int, "inner meta"]), "meta"], ("meta", "inner meta"), int), + (NestedAlias, ("meta",), NestedAlias), + ], +) +def test_unwrap_annotation_type_alias_type_nested( + annotation: Any, expected_meta: tuple[str, ...], expected_type: Any +) -> None: + annotation, metadata, wrappers = unwrap_annotation(annotation) + assert annotation == expected_type + assert metadata == expected_meta + assert TypeAliasType in wrappers + + +def test_unwrap_annotation_type_alias_type_nested_with_type_kw() -> None: + annotation = Annotated[type_kw_or("type Alias = int"), "meta"] # type: ignore[valid-type] + unwrapped, metadata, wrappers = unwrap_annotation(annotation) + assert unwrapped == int + assert metadata == ("meta",) + assert TypeAliasType in wrappers + + +def test_unwrap_annotation_type_alias_type_undefined() -> None: + annotation = type_kw_or("type Alias = NonExistent") + with pytest.raises(NameError): + unwrap_annotation(annotation)