diff --git a/federation-compatibility/schema.py b/federation-compatibility/schema.py index 7502f59d93..bce12f6942 100644 --- a/federation-compatibility/schema.py +++ b/federation-compatibility/schema.py @@ -121,13 +121,13 @@ class Custom: ... @strawberry.federation.type(extend=True, keys=["email"]) class User: email: strawberry.ID = strawberry.federation.field(external=True) - name: Optional[str] = strawberry.federation.field(override="users") - total_products_created: Optional[int] = strawberry.federation.field(external=True) + name: str | None = strawberry.federation.field(override="users") + total_products_created: int | None = strawberry.federation.field(external=True) years_of_employment: int = strawberry.federation.field(external=True) # TODO: the camel casing will be fixed in a future release of Strawberry @strawberry.federation.field(requires=["totalProductsCreated", "yearsOfEmployment"]) - def average_products_created_per_year(self) -> Optional[int]: + def average_products_created_per_year(self) -> int | None: if self.total_products_created is not None: return round(self.total_products_created / self.years_of_employment) @@ -150,9 +150,9 @@ def resolve_reference(cls, **data: Any) -> Optional["User"]: @strawberry.federation.type(shareable=True) class ProductDimension: - size: Optional[str] - weight: Optional[float] - unit: Optional[str] = strawberry.federation.field(inaccessible=True) + size: str | None + weight: float | None + unit: str | None = strawberry.federation.field(inaccessible=True) @strawberry.type @@ -163,13 +163,13 @@ class ProductVariation: @strawberry.type class CaseStudy: case_number: strawberry.ID - description: Optional[str] + description: str | None @strawberry.federation.type(keys=["study { caseNumber }"]) class ProductResearch: study: CaseStudy - outcome: Optional[str] + outcome: str | None @classmethod def from_data(cls, data: dict) -> "ProductResearch": @@ -206,8 +206,8 @@ def resolve_reference(cls, **data: Any) -> Optional["ProductResearch"]: class DeprecatedProduct: sku: str package: str - reason: Optional[str] - created_by: Optional[User] + reason: str | None + created_by: User | None @classmethod def resolve_reference(cls, **data: Any) -> Optional["DeprecatedProduct"]: @@ -231,12 +231,12 @@ def resolve_reference(cls, **data: Any) -> Optional["DeprecatedProduct"]: ) class Product: id: strawberry.ID - sku: Optional[str] - package: Optional[str] + sku: str | None + package: str | None variation_id: strawberry.Private[str] @strawberry.field - def variation(self) -> Optional[ProductVariation]: + def variation(self) -> ProductVariation | None: return ( ProductVariation(strawberry.ID(self.variation_id)) if self.variation_id @@ -244,14 +244,14 @@ def variation(self) -> Optional[ProductVariation]: ) @strawberry.field - def dimensions(self) -> Optional[ProductDimension]: + def dimensions(self) -> ProductDimension | None: return ProductDimension(**dimension) @strawberry.federation.field(provides=["totalProductsCreated"]) - def created_by(self) -> Optional[User]: + def created_by(self) -> User | None: return User(**user) - notes: Optional[str] = strawberry.federation.field(tags=["internal"]) + notes: str | None = strawberry.federation.field(tags=["internal"]) research: list[ProductResearch] @classmethod @@ -301,10 +301,10 @@ def resolve_reference(cls, id: strawberry.ID) -> "Inventory": @strawberry.federation.type(extend=True) class Query: - product: Optional[Product] = strawberry.field(resolver=get_product_by_id) + product: Product | None = strawberry.field(resolver=get_product_by_id) @strawberry.field(deprecation_reason="Use product query instead") - def deprecated_product(self, sku: str, package: str) -> Optional[DeprecatedProduct]: + def deprecated_product(self, sku: str, package: str) -> DeprecatedProduct | None: return None diff --git a/pyproject.toml b/pyproject.toml index 73d44dd2b2..b4c15f3e92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -225,13 +225,6 @@ src = ["strawberry", "tests"] [tool.ruff.lint] select = ["ALL"] ignore = [ - # https://github.com/astral-sh/ruff/pull/4427 - # equivalent to keep-runtime-typing. We might want to enable those - # after we drop support for Python 3.9 - "UP006", - "UP007", - "UP045", - # we use asserts in tests and to hint mypy "S101", @@ -372,6 +365,8 @@ ignore = [ "TCH002", "TCH003", "TRY002", + "UP007", + "UP045", ] [tool.ruff.lint.isort] diff --git a/strawberry/aiohttp/test/client.py b/strawberry/aiohttp/test/client.py index 75b1fda4d6..0367486c37 100644 --- a/strawberry/aiohttp/test/client.py +++ b/strawberry/aiohttp/test/client.py @@ -4,7 +4,6 @@ from typing import ( TYPE_CHECKING, Any, - Optional, ) from strawberry.test.client import BaseGraphQLTestClient, Response @@ -17,11 +16,11 @@ class GraphQLTestClient(BaseGraphQLTestClient): async def query( self, query: str, - variables: Optional[dict[str, Mapping]] = None, - headers: Optional[dict[str, object]] = None, - asserts_errors: Optional[bool] = None, - files: Optional[dict[str, object]] = None, - assert_no_errors: Optional[bool] = True, + variables: dict[str, Mapping] | None = None, + headers: dict[str, object] | None = None, + asserts_errors: bool | None = None, + files: dict[str, object] | None = None, + assert_no_errors: bool | None = True, ) -> Response: body = self._build_body(query, variables, files) @@ -54,8 +53,8 @@ async def query( async def request( self, body: dict[str, object], - headers: Optional[dict[str, object]] = None, - files: Optional[dict[str, object]] = None, + headers: dict[str, object] | None = None, + files: dict[str, object] | None = None, ) -> Any: return await self._client.post( self.url, diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index e41bfba977..0d234e99e0 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -6,9 +6,7 @@ from json.decoder import JSONDecodeError from typing import ( TYPE_CHECKING, - Optional, TypeGuard, - Union, ) from lia import AiohttpHTTPRequestAdapter, HTTPException @@ -72,7 +70,7 @@ async def close(self, code: int, reason: str) -> None: class GraphQLView( AsyncBaseHTTPView[ web.Request, - Union[web.Response, web.StreamResponse], + web.Response | web.StreamResponse, web.Response, web.Request, web.WebSocketResponse, @@ -91,8 +89,8 @@ class GraphQLView( def __init__( self, schema: BaseSchema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = True, keep_alive_interval: float = 1, @@ -131,12 +129,12 @@ def is_websocket_request(self, request: web.Request) -> TypeGuard[web.Request]: ws = web.WebSocketResponse(protocols=self.subscription_protocols) return ws.can_prepare(request).ok - async def pick_websocket_subprotocol(self, request: web.Request) -> Optional[str]: + async def pick_websocket_subprotocol(self, request: web.Request) -> str | None: ws = web.WebSocketResponse(protocols=self.subscription_protocols) return ws.can_prepare(request).protocol async def create_websocket_response( - self, request: web.Request, subprotocol: Optional[str] + self, request: web.Request, subprotocol: str | None ) -> web.WebSocketResponse: protocols = [subprotocol] if subprotocol else [] ws = web.WebSocketResponse(protocols=protocols) @@ -152,17 +150,17 @@ async def __call__(self, request: web.Request) -> web.StreamResponse: status=e.status_code, ) - async def get_root_value(self, request: web.Request) -> Optional[RootValue]: + async def get_root_value(self, request: web.Request) -> RootValue | None: return None async def get_context( - self, request: web.Request, response: Union[web.Response, web.WebSocketResponse] + self, request: web.Request, response: web.Response | web.WebSocketResponse ) -> Context: return {"request": request, "response": response} # type: ignore def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: web.Response, ) -> web.Response: sub_response.text = self.encode_json(response_data) diff --git a/strawberry/annotation.py b/strawberry/annotation.py index 4b9bca1bf0..fea5d5cbb1 100644 --- a/strawberry/annotation.py +++ b/strawberry/annotation.py @@ -11,7 +11,6 @@ Annotated, Any, ForwardRef, - Optional, TypeVar, Union, cast, @@ -61,14 +60,14 @@ class StrawberryAnnotation: def __init__( self, - annotation: Union[object, str], + annotation: object | str, *, - namespace: Optional[dict[str, Any]] = None, + namespace: dict[str, Any] | None = None, ) -> None: self.raw_annotation = annotation self.namespace = namespace - self.__resolve_cache__: Optional[Union[StrawberryType, type]] = None + self.__resolve_cache__: StrawberryType | type | None = None def __eq__(self, other: object) -> bool: if not isinstance(other, StrawberryAnnotation): @@ -81,8 +80,8 @@ def __hash__(self) -> int: @staticmethod def from_annotation( - annotation: object, namespace: Optional[dict[str, Any]] = None - ) -> Optional[StrawberryAnnotation]: + annotation: object, namespace: dict[str, Any] | None = None + ) -> StrawberryAnnotation | None: if annotation is None: return None @@ -91,7 +90,7 @@ def from_annotation( return annotation @property - def annotation(self) -> Union[object, str]: + def annotation(self) -> object | str: """Return evaluated type on success or fallback to raw (string) annotation.""" try: return self.evaluate() @@ -101,7 +100,7 @@ def annotation(self) -> Union[object, str]: return self.raw_annotation @annotation.setter - def annotation(self, value: Union[object, str]) -> None: + def annotation(self, value: object | str) -> None: self.raw_annotation = value self.__resolve_cache__ = None @@ -131,8 +130,8 @@ def _get_type_with_args( def resolve( self, *, - type_definition: Optional[StrawberryObjectDefinition] = None, - ) -> Union[StrawberryType, type]: + type_definition: StrawberryObjectDefinition | None = None, + ) -> StrawberryType | type: """Return resolved (transformed) annotation.""" if (resolved := self.__resolve_cache__) is None: resolved = self._resolve() @@ -161,11 +160,11 @@ def resolve( return resolved - def _resolve(self) -> Union[StrawberryType, type]: + def _resolve(self) -> StrawberryType | type: evaled_type = cast("Any", self.evaluate()) return self._resolve_evaled_type(evaled_type) - def _resolve_evaled_type(self, evaled_type: Any) -> Union[StrawberryType, type]: + def _resolve_evaled_type(self, evaled_type: Any) -> StrawberryType | type: if is_private(evaled_type): return evaled_type @@ -247,7 +246,7 @@ def create_optional(self, evaled_type: Any) -> StrawberryOptional: # passed as we can safely use `Union` for both optional types # (e.g. `Optional[str]`) and optional unions (e.g. # `Optional[Union[TypeA, TypeB]]`) - child_type = Union[non_optional_types] # type: ignore + child_type = Union[non_optional_types] # type: ignore # noqa: UP007 of_type = StrawberryAnnotation( annotation=child_type, @@ -324,7 +323,7 @@ def _is_enum(cls, annotation: Any) -> bool: return issubclass(annotation, Enum) @classmethod - def _is_type_generic(cls, type_: Union[StrawberryType, type]) -> bool: + def _is_type_generic(cls, type_: StrawberryType | type) -> bool: """Returns True if `resolver_type` is generic else False.""" from strawberry.types.base import StrawberryType diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 7cdfa9df71..7a1a7d2326 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -5,9 +5,7 @@ from json import JSONDecodeError from typing import ( TYPE_CHECKING, - Optional, TypeGuard, - Union, ) from lia import HTTPException, StarletteRequestAdapter @@ -103,8 +101,8 @@ class GraphQL( def __init__( self, schema: BaseSchema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, @@ -149,19 +147,17 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: else: # pragma: no cover raise ValueError("Unknown scope type: {!r}".format(scope["type"])) - async def get_root_value( - self, request: Union[Request, WebSocket] - ) -> Optional[RootValue]: + async def get_root_value(self, request: Request | WebSocket) -> RootValue | None: return None async def get_context( - self, request: Union[Request, WebSocket], response: Union[Response, WebSocket] + self, request: Request | WebSocket, response: Response | WebSocket ) -> Context: return {"request": request, "response": response} # type: ignore async def get_sub_response( self, - request: Union[Request, WebSocket], + request: Request | WebSocket, ) -> Response: sub_response = Response() sub_response.status_code = None # type: ignore @@ -174,7 +170,7 @@ async def render_graphql_ide(self, request: Request) -> Response: def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: Response, ) -> Response: response = Response( @@ -210,18 +206,18 @@ async def create_streaming_response( ) def is_websocket_request( - self, request: Union[Request, WebSocket] + self, request: Request | WebSocket ) -> TypeGuard[WebSocket]: return request.scope["type"] == "websocket" - async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]: + async def pick_websocket_subprotocol(self, request: WebSocket) -> str | None: protocols = request["subprotocols"] intersection = set(protocols) & set(self.protocols) sorted_intersection = sorted(intersection, key=protocols.index) return next(iter(sorted_intersection), None) async def create_websocket_response( - self, request: WebSocket, subprotocol: Optional[str] + self, request: WebSocket, subprotocol: str | None ) -> WebSocket: await request.accept(subprotocol=subprotocol) return request diff --git a/strawberry/asgi/test/client.py b/strawberry/asgi/test/client.py index 75a073ce38..f75246ef34 100644 --- a/strawberry/asgi/test/client.py +++ b/strawberry/asgi/test/client.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from strawberry.test import BaseGraphQLTestClient @@ -14,8 +14,8 @@ class GraphQLTestClient(BaseGraphQLTestClient): def _build_body( self, query: str, - variables: Optional[dict[str, Mapping]] = None, - files: Optional[dict[str, object]] = None, + variables: dict[str, Mapping] | None = None, + files: dict[str, object] | None = None, ) -> dict[str, object]: body: dict[str, object] = {"query": query} @@ -36,8 +36,8 @@ def _build_body( def request( self, body: dict[str, object], - headers: Optional[dict[str, object]] = None, - files: Optional[dict[str, object]] = None, + headers: dict[str, object] | None = None, + files: dict[str, object] | None = None, ) -> Any: return self._client.post( self.url, diff --git a/strawberry/chalice/views.py b/strawberry/chalice/views.py index 7c2449c632..9be13a1c79 100644 --- a/strawberry/chalice/views.py +++ b/strawberry/chalice/views.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from lia import ChaliceHTTPRequestAdapter, HTTPException @@ -25,8 +25,8 @@ class GraphQLView( def __init__( self, schema: BaseSchema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, ) -> None: self.allow_queries_via_get = allow_queries_via_get @@ -41,7 +41,7 @@ def __init__( else: self.graphql_ide = graphql_ide - def get_root_value(self, request: Request) -> Optional[RootValue]: + def get_root_value(self, request: Request) -> RootValue | None: return None def render_graphql_ide(self, request: Request) -> Response: @@ -58,7 +58,7 @@ def get_context(self, request: Request, response: TemporalResponse) -> Context: def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: TemporalResponse, ) -> Response: status_code = 200 diff --git a/strawberry/channels/handlers/base.py b/strawberry/channels/handlers/base.py index 07ce11aa65..5eb5ffe534 100644 --- a/strawberry/channels/handlers/base.py +++ b/strawberry/channels/handlers/base.py @@ -6,7 +6,6 @@ from typing import ( Any, Literal, - Optional, ) from typing_extensions import Protocol, TypedDict from weakref import WeakSet @@ -54,7 +53,7 @@ class ChannelsConsumer(AsyncConsumer): """Base channels async consumer.""" channel_name: str - channel_layer: Optional[ChannelsLayer] + channel_layer: ChannelsLayer | None channel_receive: Callable[[], Awaitable[dict]] def __init__(self, *args: str, **kwargs: Any) -> None: @@ -80,7 +79,7 @@ async def channel_listen( self, type: str, *, - timeout: Optional[float] = None, + timeout: float | None = None, groups: Sequence[str] = (), ) -> AsyncGenerator[Any, None]: """Listen for messages sent to this consumer. @@ -139,7 +138,7 @@ async def listen_to_channel( self, type: str, *, - timeout: Optional[float] = None, + timeout: float | None = None, groups: Sequence[str] = (), ) -> AsyncGenerator[Any, None]: """Listen for messages sent to this consumer. @@ -188,7 +187,7 @@ async def listen_to_channel( await self.channel_layer.group_discard(group, self.channel_name) async def _listen_to_channel_generator( - self, queue: asyncio.Queue, timeout: Optional[float] + self, queue: asyncio.Queue, timeout: float | None ) -> AsyncGenerator[Any, None]: """Generator for listen_to_channel method. diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 88290c9b72..974e0310c5 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -5,7 +5,7 @@ import warnings from functools import cached_property from io import BytesIO -from typing import TYPE_CHECKING, Any, Optional, TypeGuard, Union +from typing import TYPE_CHECKING, Any, TypeGuard from typing_extensions import assert_never from urllib.parse import parse_qs @@ -77,7 +77,7 @@ def method(self) -> HTTPMethod: return self.consumer.scope["method"].upper() @property - def content_type(self) -> Optional[str]: + def content_type(self) -> str | None: return self.headers.get("content-type", None) @cached_property @@ -123,7 +123,7 @@ def headers(self) -> Mapping[str, str]: return self.request.headers @property - def content_type(self) -> Optional[str]: + def content_type(self) -> str | None: return self.request.content_type @property @@ -163,7 +163,7 @@ def body(self) -> bytes: return self.request.body @property - def post_data(self) -> Mapping[str, Union[str, bytes]]: + def post_data(self) -> Mapping[str, str | bytes]: return self.request.form_data.form @property @@ -176,13 +176,13 @@ def get_form_data(self) -> FormData: class BaseGraphQLHTTPConsumer(ChannelsConsumer, AsyncHttpConsumer): graphql_ide_html: str - graphql_ide: Optional[GraphQL_IDE] = "graphiql" + graphql_ide: GraphQL_IDE | None = "graphiql" def __init__( self, schema: BaseSchema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, multipart_uploads_enabled: bool = False, **kwargs: Any, @@ -205,7 +205,7 @@ def __init__( def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: TemporalResponse, ) -> ChannelsResponse: return ChannelsResponse( @@ -247,7 +247,7 @@ class GraphQLHTTPConsumer( BaseGraphQLHTTPConsumer, AsyncBaseHTTPView[ ChannelsRequest, - Union[ChannelsResponse, MultipartChannelsResponse], + ChannelsResponse | MultipartChannelsResponse, TemporalResponse, ChannelsRequest, TemporalResponse, @@ -279,7 +279,7 @@ class GraphQLHTTPConsumer( allow_queries_via_get: bool = True request_adapter_class = ChannelsRequestAdapter - async def get_root_value(self, request: ChannelsRequest) -> Optional[RootValue]: + async def get_root_value(self, request: ChannelsRequest) -> RootValue | None: return None # pragma: no cover async def get_context( @@ -322,13 +322,11 @@ def is_websocket_request( ) -> TypeGuard[ChannelsRequest]: return False - async def pick_websocket_subprotocol( - self, request: ChannelsRequest - ) -> Optional[str]: + async def pick_websocket_subprotocol(self, request: ChannelsRequest) -> str | None: return None async def create_websocket_response( - self, request: ChannelsRequest, subprotocol: Optional[str] + self, request: ChannelsRequest, subprotocol: str | None ) -> TemporalResponse: raise NotImplementedError @@ -353,7 +351,7 @@ class SyncGraphQLHTTPConsumer( allow_queries_via_get: bool = True request_adapter_class = SyncChannelsRequestAdapter - def get_root_value(self, request: ChannelsRequest) -> Optional[RootValue]: + def get_root_value(self, request: ChannelsRequest) -> RootValue | None: return None # pragma: no cover def get_context( @@ -381,7 +379,7 @@ def run( self, request: ChannelsRequest, context: Context = UNSET, - root_value: Optional[RootValue] = UNSET, + root_value: RootValue | None = UNSET, ) -> ChannelsResponse | MultipartChannelsResponse: return super().run(request, context, root_value) diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index c39e2b2933..408443f528 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -5,10 +5,8 @@ import json from typing import ( TYPE_CHECKING, - Optional, TypedDict, TypeGuard, - Union, ) from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter @@ -62,7 +60,7 @@ async def close(self, code: int, reason: str) -> None: class MessageQueueData(TypedDict): - message: Union[str, None] + message: str | None disconnected: bool @@ -113,7 +111,7 @@ def __init__( GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL, ), - connection_init_wait_timeout: Optional[datetime.timedelta] = None, + connection_init_wait_timeout: datetime.timedelta | None = None, ) -> None: if connection_init_wait_timeout is None: connection_init_wait_timeout = datetime.timedelta(minutes=1) @@ -123,7 +121,7 @@ def __init__( self.keep_alive_interval = keep_alive_interval self.protocols = subscription_protocols self.message_queue: asyncio.Queue[MessageQueueData] = asyncio.Queue() - self.run_task: Optional[asyncio.Task] = None + self.run_task: asyncio.Task | None = None super().__init__() @@ -131,7 +129,7 @@ async def connect(self) -> None: self.run_task = asyncio.create_task(self.run(self)) async def receive( - self, text_data: Optional[str] = None, bytes_data: Optional[bytes] = None + self, text_data: str | None = None, bytes_data: bytes | None = None ) -> None: if text_data: self.message_queue.put_nowait({"message": text_data, "disconnected": False}) @@ -143,7 +141,7 @@ async def disconnect(self, code: int) -> None: assert self.run_task await self.run_task - async def get_root_value(self, request: GraphQLWSConsumer) -> Optional[RootValue]: + async def get_root_value(self, request: GraphQLWSConsumer) -> RootValue | None: return None async def get_context( @@ -163,7 +161,7 @@ async def get_sub_response(self, request: GraphQLWSConsumer) -> GraphQLWSConsume def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: GraphQLWSConsumer, ) -> GraphQLWSConsumer: raise NotImplementedError @@ -178,14 +176,14 @@ def is_websocket_request( async def pick_websocket_subprotocol( self, request: GraphQLWSConsumer - ) -> Optional[str]: + ) -> str | None: protocols = request.scope["subprotocols"] intersection = set(protocols) & set(self.protocols) sorted_intersection = sorted(intersection, key=protocols.index) return next(iter(sorted_intersection), None) async def create_websocket_response( - self, request: GraphQLWSConsumer, subprotocol: Optional[str] + self, request: GraphQLWSConsumer, subprotocol: str | None ) -> GraphQLWSConsumer: await request.accept(subprotocol=subprotocol) return request diff --git a/strawberry/channels/router.py b/strawberry/channels/router.py index 414902cd40..1c9e1a022c 100644 --- a/strawberry/channels/router.py +++ b/strawberry/channels/router.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from django.urls import re_path @@ -47,7 +47,7 @@ class GraphQLProtocolTypeRouter(ProtocolTypeRouter): def __init__( self, schema: BaseSchema, - django_application: Optional[str] = None, + django_application: str | None = None, url_pattern: str = "^graphql", ) -> None: http_urls = [re_path(url_pattern, GraphQLHTTPConsumer.as_asgi(schema=schema))] diff --git a/strawberry/channels/testing.py b/strawberry/channels/testing.py index f1807ca52b..4844fadfdc 100644 --- a/strawberry/channels/testing.py +++ b/strawberry/channels/testing.py @@ -4,8 +4,6 @@ from typing import ( TYPE_CHECKING, Any, - Optional, - Union, ) from graphql import GraphQLError, GraphQLFormattedError @@ -53,7 +51,7 @@ def __init__( self, application: ASGIApplication, path: str, - headers: Optional[list[tuple[bytes, bytes]]] = None, + headers: list[tuple[bytes, bytes]] | None = None, protocol: str = GRAPHQL_TRANSPORT_WS_PROTOCOL, connection_params: dict | None = None, **kwargs: Any, @@ -83,9 +81,9 @@ async def __aenter__(self) -> Self: async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: await self.disconnect() @@ -114,8 +112,8 @@ async def gql_init(self) -> None: # get transformed into `FormattedExecutionResult` on the wire, but we attempt # to do a limited representation of them here, to make testing simpler. async def subscribe( - self, query: str, variables: Optional[dict] = None - ) -> Union[ExecutionResult, AsyncIterator[ExecutionResult]]: + self, query: str, variables: dict | None = None + ) -> ExecutionResult | AsyncIterator[ExecutionResult]: id_ = uuid.uuid4().hex if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: diff --git a/strawberry/cli/commands/codegen.py b/strawberry/cli/commands/codegen.py index f13f06a404..a1f62ab847 100644 --- a/strawberry/cli/commands/codegen.py +++ b/strawberry/cli/commands/codegen.py @@ -4,7 +4,7 @@ import importlib import inspect from pathlib import Path # noqa: TC003 -from typing import Optional, Union, cast +from typing import cast import rich import typer @@ -22,9 +22,9 @@ def _is_codegen_plugin(obj: object) -> bool: ) -def _import_plugin(plugin: str) -> Optional[type[QueryCodegenPlugin]]: +def _import_plugin(plugin: str) -> type[QueryCodegenPlugin] | None: module_name = plugin - symbol_name: Optional[str] = None + symbol_name: str | None = None if ":" in plugin: module_name, symbol_name = plugin.split(":", 1) @@ -61,7 +61,7 @@ def _import_plugin(plugin: str) -> Optional[type[QueryCodegenPlugin]]: @functools.lru_cache def _load_plugin( plugin_path: str, -) -> type[Union[QueryCodegenPlugin, ConsolePlugin]]: +) -> type[QueryCodegenPlugin | ConsolePlugin]: # try to import plugin_name from current folder # then try to import from strawberry.codegen.plugins @@ -79,7 +79,7 @@ def _load_plugin( def _load_plugins( plugin_ids: list[str], query: Path -) -> list[Union[QueryCodegenPlugin, ConsolePlugin]]: +) -> list[QueryCodegenPlugin | ConsolePlugin]: plugins = [] for ptype_id in plugin_ids: ptype = _load_plugin(ptype_id) @@ -91,7 +91,7 @@ def _load_plugins( @app.command(help="Generate code from a query") def codegen( - query: Optional[list[Path]] = typer.Argument( + query: list[Path] | None = typer.Argument( default=None, exists=True, dir_okay=False ), schema: str = typer.Option(..., help="Python path to the schema file"), @@ -120,7 +120,7 @@ def codegen( "-p", "--plugins", ), - cli_plugin: Optional[str] = None, + cli_plugin: str | None = None, ) -> None: if not query: return diff --git a/strawberry/cli/commands/schema_codegen.py b/strawberry/cli/commands/schema_codegen.py index 16a7d4194a..af252ba413 100644 --- a/strawberry/cli/commands/schema_codegen.py +++ b/strawberry/cli/commands/schema_codegen.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional import typer @@ -10,7 +9,7 @@ @app.command(help="Generate code from a query") def schema_codegen( schema: Path = typer.Argument(exists=True), - output: Optional[Path] = typer.Option( + output: Path | None = typer.Option( None, "-o", "--output", diff --git a/strawberry/cli/commands/upgrade/_run_codemod.py b/strawberry/cli/commands/upgrade/_run_codemod.py index c6fc2228d3..0f14cb14da 100644 --- a/strawberry/cli/commands/upgrade/_run_codemod.py +++ b/strawberry/cli/commands/upgrade/_run_codemod.py @@ -4,7 +4,7 @@ import os from importlib.metadata import version from multiprocessing import Pool, cpu_count -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, TypeAlias from libcst.codemod._cli import ExecutionConfig, ExecutionResult, _execute_transform from rich.progress import Progress @@ -16,7 +16,7 @@ from libcst.codemod import Codemod -ProgressType = Union[type[Progress], type[FakeProgress]] +ProgressType: TypeAlias = type[Progress] | type[FakeProgress] def _get_libcst_version() -> tuple[int, int, int]: diff --git a/strawberry/codegen/plugins/print_operation.py b/strawberry/codegen/plugins/print_operation.py index 5a37c87086..ed7db9c685 100644 --- a/strawberry/codegen/plugins/print_operation.py +++ b/strawberry/codegen/plugins/print_operation.py @@ -1,7 +1,7 @@ from __future__ import annotations import textwrap -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from strawberry.codegen import CodegenFile, QueryCodegenPlugin from strawberry.codegen.types import ( @@ -93,7 +93,7 @@ def _print_operation_variables(self, operation: GraphQLOperation) -> str: return f"({variables})" def _print_graphql_type( - self, type: GraphQLType, parent_type: Optional[GraphQLType] = None + self, type: GraphQLType, parent_type: GraphQLType | None = None ) -> str: if isinstance(type, GraphQLOptional): return self._print_graphql_type(type.of_type, type) diff --git a/strawberry/codegen/plugins/python.py b/strawberry/codegen/plugins/python.py index 4af4cb0cbf..c5f1bc7b90 100644 --- a/strawberry/codegen/plugins/python.py +++ b/strawberry/codegen/plugins/python.py @@ -3,7 +3,7 @@ import textwrap from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar from strawberry.codegen import CodegenFile, QueryCodegenPlugin from strawberry.codegen.types import ( @@ -31,7 +31,7 @@ @dataclass class PythonType: type: str - module: Optional[str] = None + module: str | None = None class PythonPlugin(QueryCodegenPlugin): diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index 9ab7d1f97c..0c7226c91c 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -8,8 +8,6 @@ from typing import ( TYPE_CHECKING, Any, - Optional, - Union, cast, ) from typing_extensions import Protocol @@ -121,7 +119,7 @@ def write(self, folder: Path) -> None: class HasSelectionSet(Protocol): - selection_set: Optional[SelectionSetNode] + selection_set: SelectionSetNode | None class QueryCodegenPlugin: @@ -233,7 +231,7 @@ class QueryCodegenPluginManager: def __init__( self, plugins: list[QueryCodegenPlugin], - console_plugin: Optional[ConsolePlugin] = None, + console_plugin: ConsolePlugin | None = None, ) -> None: self.plugins = plugins self.console_plugin = console_plugin @@ -297,7 +295,7 @@ def __init__( self, schema: Schema, plugins: list[QueryCodegenPlugin], - console_plugin: Optional[ConsolePlugin] = None, + console_plugin: ConsolePlugin | None = None, ) -> None: self.schema = schema self.plugin_manager = QueryCodegenPluginManager(plugins, console_plugin) @@ -389,7 +387,7 @@ def _convert_selection(self, selection: SelectionNode) -> GraphQLSelection: raise ValueError(f"Unsupported type: {type(selection)}") # pragma: no cover def _convert_selection_set( - self, selection_set: Optional[SelectionSetNode] + self, selection_set: SelectionSetNode | None ) -> list[GraphQLSelection]: if selection_set is None: return [] @@ -488,9 +486,9 @@ def _convert_operation( def _convert_variable_definitions( self, - variable_definitions: Optional[Iterable[VariableDefinitionNode]], + variable_definitions: Iterable[VariableDefinitionNode] | None, operation_name: str, - ) -> tuple[list[GraphQLVariable], Optional[GraphQLObjectType]]: + ) -> tuple[list[GraphQLVariable], GraphQLObjectType | None]: if not variable_definitions: return [], None @@ -521,7 +519,7 @@ def _get_operations(self, ast: DocumentNode) -> list[OperationDefinitionNode]: def _get_field_type( self, - field_type: Union[StrawberryType, type], + field_type: StrawberryType | type, ) -> GraphQLType: if isinstance(field_type, StrawberryOptional): return GraphQLOptional(self._get_field_type(field_type.of_type)) @@ -551,7 +549,7 @@ def _get_field_type( raise ValueError(f"Unsupported type: {field_type}") # pragma: no cover def _collect_type_from_strawberry_type( - self, strawberry_type: Union[type, StrawberryType] + self, strawberry_type: type | StrawberryType ) -> GraphQLType: type_: GraphQLType @@ -590,9 +588,9 @@ def _collect_type_from_strawberry_type( return type_ def _collect_type_from_variable( - self, variable_type: TypeNode, parent_type: Optional[TypeNode] = None + self, variable_type: TypeNode, parent_type: TypeNode | None = None ) -> GraphQLType: - type_: Optional[GraphQLType] = None + type_: GraphQLType | None = None if isinstance(variable_type, ListTypeNode): type_ = GraphQLList( @@ -631,11 +629,9 @@ def _field_from_selection( ) def _unwrap_type( - self, type_: Union[type, StrawberryType] - ) -> tuple[ - Union[type, StrawberryType], Optional[Callable[[GraphQLType], GraphQLType]] - ]: - wrapper: Optional[Callable[[GraphQLType], GraphQLType]] = None + self, type_: type | StrawberryType + ) -> tuple[type | StrawberryType, Callable[[GraphQLType], GraphQLType] | None]: + wrapper: Callable[[GraphQLType], GraphQLType] | None = None if isinstance(type_, StrawberryOptional): type_, previous_wrapper = self._unwrap_type(type_.of_type) @@ -732,7 +728,7 @@ def _collect_types_with_inline_fragments( selection: HasSelectionSet, parent_type: StrawberryObjectDefinition, class_name: str, - ) -> Union[GraphQLObjectType, GraphQLUnion]: + ) -> GraphQLObjectType | GraphQLUnion: sub_types = self._collect_types_using_fragments( selection, parent_type, class_name ) @@ -767,7 +763,7 @@ def _collect_types( ) current_type = graph_ql_object_type_factory(class_name) - fields: list[Union[GraphQLFragmentSpread, GraphQLField]] = [] + fields: list[GraphQLFragmentSpread | GraphQLField] = [] for sub_selection in selection_set.selections: if isinstance(sub_selection, FragmentSpreadNode): @@ -838,7 +834,7 @@ def _collect_types_using_fragments( list(common_fields), graphql_typename=type_condition_name, ) - fields: list[Union[GraphQLFragmentSpread, GraphQLField]] = [] + fields: list[GraphQLFragmentSpread | GraphQLField] = [] for sub_selection in fragment.selection_set.selections: if isinstance(sub_selection, FragmentSpreadNode): @@ -893,7 +889,7 @@ def _collect_types_using_fragments( return sub_types def _collect_scalar( - self, scalar_definition: ScalarDefinition, python_type: Optional[type] + self, scalar_definition: ScalarDefinition, python_type: type | None ) -> GraphQLScalar: graphql_scalar = GraphQLScalar(scalar_definition.name, python_type=python_type) diff --git a/strawberry/codegen/types.py b/strawberry/codegen/types.py index 2279dd16c4..837a0ab650 100644 --- a/strawberry/codegen/types.py +++ b/strawberry/codegen/types.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, TypeAlias if TYPE_CHECKING: from collections.abc import Mapping @@ -30,9 +30,9 @@ class GraphQLUnion: @dataclass class GraphQLField: name: str - alias: Optional[str] + alias: str | None type: GraphQLType - default_value: Optional[GraphQLArgumentValue] = None + default_value: GraphQLArgumentValue | None = None @dataclass @@ -44,7 +44,7 @@ class GraphQLFragmentSpread: class GraphQLObjectType: name: str fields: list[GraphQLField] = field(default_factory=list) - graphql_typename: Optional[str] = None + graphql_typename: str | None = None # Subtype of GraphQLObjectType. @@ -54,7 +54,7 @@ class GraphQLObjectType: class GraphQLFragmentType(GraphQLObjectType): name: str fields: list[GraphQLField] = field(default_factory=list) - graphql_typename: Optional[str] = None + graphql_typename: str | None = None on: str = "" def __post_init__(self) -> None: @@ -74,23 +74,23 @@ class GraphQLEnum: @dataclass class GraphQLScalar: name: str - python_type: Optional[type] + python_type: type | None -GraphQLType = Union[ - GraphQLObjectType, - GraphQLEnum, - GraphQLScalar, - GraphQLOptional, - GraphQLList, - GraphQLUnion, -] +GraphQLType: TypeAlias = ( + GraphQLObjectType + | GraphQLEnum + | GraphQLScalar + | GraphQLOptional + | GraphQLList + | GraphQLUnion +) @dataclass class GraphQLFieldSelection: field: str - alias: Optional[str] + alias: str | None selections: list[GraphQLSelection] directives: list[GraphQLDirective] arguments: list[GraphQLArgument] @@ -102,9 +102,9 @@ class GraphQLInlineFragment: selections: list[GraphQLSelection] -GraphQLSelection = Union[ - GraphQLFieldSelection, GraphQLInlineFragment, GraphQLFragmentSpread -] +GraphQLSelection: TypeAlias = ( + GraphQLFieldSelection | GraphQLInlineFragment | GraphQLFragmentSpread +) @dataclass @@ -125,7 +125,7 @@ class GraphQLFloatValue: @dataclass class GraphQLEnumValue: name: str - enum_type: Optional[str] = None + enum_type: str | None = None @dataclass @@ -155,17 +155,17 @@ class GraphQLVariableReference: value: str -GraphQLArgumentValue = Union[ - GraphQLStringValue, - GraphQLNullValue, - GraphQLIntValue, - GraphQLVariableReference, - GraphQLFloatValue, - GraphQLListValue, - GraphQLEnumValue, - GraphQLBoolValue, - GraphQLObjectValue, -] +GraphQLArgumentValue: TypeAlias = ( + GraphQLStringValue + | GraphQLNullValue + | GraphQLIntValue + | GraphQLVariableReference + | GraphQLFloatValue + | GraphQLListValue + | GraphQLEnumValue + | GraphQLBoolValue + | GraphQLObjectValue +) @dataclass @@ -194,7 +194,7 @@ class GraphQLOperation: directives: list[GraphQLDirective] variables: list[GraphQLVariable] type: GraphQLObjectType - variables_type: Optional[GraphQLObjectType] + variables_type: GraphQLObjectType | None __all__ = [ diff --git a/strawberry/codemods/annotated_unions.py b/strawberry/codemods/annotated_unions.py index e60ad601dd..30cbc6864d 100644 --- a/strawberry/codemods/annotated_unions.py +++ b/strawberry/codemods/annotated_unions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import libcst as cst import libcst.matchers as m @@ -50,7 +50,7 @@ def __init__( super().__init__(context) - def visit_Module(self, node: cst.Module) -> Optional[bool]: # noqa: N802 + def visit_Module(self, node: cst.Module) -> bool | None: # noqa: N802 self._is_using_named_import = False return super().visit_Module(node) diff --git a/strawberry/dataloader.py b/strawberry/dataloader.py index 4ccf310f82..c37bb51231 100644 --- a/strawberry/dataloader.py +++ b/strawberry/dataloader.py @@ -9,9 +9,7 @@ TYPE_CHECKING, Any, Generic, - Optional, TypeVar, - Union, overload, ) @@ -54,7 +52,7 @@ def __len__(self) -> int: class AbstractCache(Generic[K, T], ABC): @abstractmethod - def get(self, key: K) -> Union[Future[T], None]: + def get(self, key: K) -> Future[T] | None: pass @abstractmethod @@ -71,13 +69,13 @@ def clear(self) -> None: class DefaultCache(AbstractCache[K, T]): - def __init__(self, cache_key_fn: Optional[Callable[[K], Hashable]] = None) -> None: + def __init__(self, cache_key_fn: Callable[[K], Hashable] | None = None) -> None: self.cache_key_fn: Callable[[K], Hashable] = ( cache_key_fn if cache_key_fn is not None else lambda x: x ) self.cache_map: dict[Hashable, Future[T]] = {} - def get(self, key: K) -> Union[Future[T], None]: + def get(self, key: K) -> Future[T] | None: return self.cache_map.get(self.cache_key_fn(key)) def set(self, key: K, value: Future[T]) -> None: @@ -91,7 +89,7 @@ def clear(self) -> None: class DataLoader(Generic[K, T]): - batch: Optional[Batch[K, T]] = None + batch: Batch[K, T] | None = None cache: bool = False cache_map: AbstractCache[K, T] @@ -99,12 +97,12 @@ class DataLoader(Generic[K, T]): def __init__( self, # any BaseException is rethrown in 'load', so should be excluded from the T type - load_fn: Callable[[list[K]], Awaitable[Sequence[Union[T, BaseException]]]], - max_batch_size: Optional[int] = None, + load_fn: Callable[[list[K]], Awaitable[Sequence[T | BaseException]]], + max_batch_size: int | None = None, cache: bool = True, - loop: Optional[AbstractEventLoop] = None, - cache_map: Optional[AbstractCache[K, T]] = None, - cache_key_fn: Optional[Callable[[K], Hashable]] = None, + loop: AbstractEventLoop | None = None, + cache_map: AbstractCache[K, T] | None = None, + cache_key_fn: Callable[[K], Hashable] | None = None, ) -> None: ... # fallback if load_fn is untyped and there's no other info for inference @@ -112,21 +110,21 @@ def __init__( def __init__( self: DataLoader[K, Any], load_fn: Callable[[list[K]], Awaitable[list[Any]]], - max_batch_size: Optional[int] = None, + max_batch_size: int | None = None, cache: bool = True, - loop: Optional[AbstractEventLoop] = None, - cache_map: Optional[AbstractCache[K, T]] = None, - cache_key_fn: Optional[Callable[[K], Hashable]] = None, + loop: AbstractEventLoop | None = None, + cache_map: AbstractCache[K, T] | None = None, + cache_key_fn: Callable[[K], Hashable] | None = None, ) -> None: ... def __init__( self, - load_fn: Callable[[list[K]], Awaitable[Sequence[Union[T, BaseException]]]], - max_batch_size: Optional[int] = None, + load_fn: Callable[[list[K]], Awaitable[Sequence[T | BaseException]]], + max_batch_size: int | None = None, cache: bool = True, - loop: Optional[AbstractEventLoop] = None, - cache_map: Optional[AbstractCache[K, T]] = None, - cache_key_fn: Optional[Callable[[K], Hashable]] = None, + loop: AbstractEventLoop | None = None, + cache_map: AbstractCache[K, T] | None = None, + cache_key_fn: Callable[[K], Hashable] | None = None, ): self.load_fn = load_fn self.max_batch_size = max_batch_size diff --git a/strawberry/directive.py b/strawberry/directive.py index 95b134561f..94005ceb24 100644 --- a/strawberry/directive.py +++ b/strawberry/directive.py @@ -7,7 +7,6 @@ Annotated, Any, Generic, - Optional, TypeVar, ) @@ -82,17 +81,17 @@ class StrawberryDirectiveResolver(StrawberryResolver[T]): ) @cached_property - def value_parameter(self) -> Optional[inspect.Parameter]: + def value_parameter(self) -> inspect.Parameter | None: return self.reserved_parameters.get(VALUE_PARAMSPEC) @dataclasses.dataclass class StrawberryDirective(Generic[T]): python_name: str - graphql_name: Optional[str] + graphql_name: str | None resolver: StrawberryDirectiveResolver[T] locations: list[DirectiveLocation] - description: Optional[str] = None + description: str | None = None @cached_property def arguments(self) -> list[StrawberryArgument]: @@ -102,8 +101,8 @@ def arguments(self) -> list[StrawberryArgument]: def directive( *, locations: list[DirectiveLocation], - description: Optional[str] = None, - name: Optional[str] = None, + description: str | None = None, + name: str | None = None, ) -> Callable[[Callable[..., T]], StrawberryDirective[T]]: """Decorator to create a GraphQL operation directive. diff --git a/strawberry/django/test/client.py b/strawberry/django/test/client.py index b9d5c994fb..d4785df1df 100644 --- a/strawberry/django/test/client.py +++ b/strawberry/django/test/client.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from strawberry.test import BaseGraphQLTestClient @@ -7,8 +7,8 @@ class GraphQLTestClient(BaseGraphQLTestClient): def request( self, body: dict[str, object], - headers: Optional[dict[str, object]] = None, - files: Optional[dict[str, object]] = None, + headers: dict[str, object] | None = None, + files: dict[str, object] | None = None, ) -> Any: if files: return self._client.post( diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 013bdb9f06..8e74889fbd 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -5,9 +5,7 @@ from typing import ( TYPE_CHECKING, Any, - Optional, TypeGuard, - Union, ) from asgiref.sync import markcoroutinefunction @@ -47,7 +45,7 @@ # TODO: remove this and unify temporal responses class TemporalHttpResponse(JsonResponse): - status_code: Optional[int] = None # pyright: ignore + status_code: int | None = None # pyright: ignore def __init__(self) -> None: super().__init__({}) @@ -70,8 +68,8 @@ class BaseView: def __init__( self, schema: BaseSchema, - graphiql: Optional[str] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: str | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, multipart_uploads_enabled: bool = False, **kwargs: Any, @@ -94,7 +92,7 @@ def __init__( def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: HttpResponse, ) -> HttpResponseBase: data = self.encode_json(response_data) @@ -142,13 +140,13 @@ class GraphQLView( ], View, ): - graphiql: Optional[bool] = None - graphql_ide: Optional[GraphQL_IDE] = "graphiql" + graphiql: bool | None = None + graphql_ide: GraphQL_IDE | None = "graphiql" allow_queries_via_get = True schema: BaseSchema = None # type: ignore request_adapter_class = DjangoHTTPRequestAdapter - def get_root_value(self, request: HttpRequest) -> Optional[RootValue]: + def get_root_value(self, request: HttpRequest) -> RootValue | None: return None def get_context(self, request: HttpRequest, response: HttpResponse) -> Context: @@ -159,7 +157,7 @@ def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: def dispatch( self, request: HttpRequest, *args: Any, **kwargs: Any - ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: + ) -> HttpResponseNotAllowed | TemplateResponse | HttpResponseBase: try: return self.run(request=request) except HTTPException as e: @@ -190,8 +188,8 @@ class AsyncGraphQLView( ], View, ): - graphiql: Optional[bool] = None - graphql_ide: Optional[GraphQL_IDE] = "graphiql" + graphiql: bool | None = None + graphql_ide: GraphQL_IDE | None = "graphiql" allow_queries_via_get = True schema: BaseSchema = None # type: ignore request_adapter_class = AsyncDjangoHTTPRequestAdapter @@ -206,7 +204,7 @@ def as_view(cls, **initkwargs: Any) -> Callable[..., HttpResponse]: # noqa: N80 return view - async def get_root_value(self, request: HttpRequest) -> Optional[RootValue]: + async def get_root_value(self, request: HttpRequest) -> RootValue | None: return None async def get_context( @@ -219,7 +217,7 @@ async def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: async def dispatch( # pyright: ignore self, request: HttpRequest, *args: Any, **kwargs: Any - ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: + ) -> HttpResponseNotAllowed | TemplateResponse | HttpResponseBase: try: return await self.run(request=request) except HTTPException as e: @@ -239,11 +237,11 @@ async def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: def is_websocket_request(self, request: HttpRequest) -> TypeGuard[HttpRequest]: return False - async def pick_websocket_subprotocol(self, request: HttpRequest) -> Optional[str]: + async def pick_websocket_subprotocol(self, request: HttpRequest) -> str | None: raise NotImplementedError async def create_websocket_response( - self, request: HttpRequest, subprotocol: Optional[str] + self, request: HttpRequest, subprotocol: str | None ) -> TemporalHttpResponse: raise NotImplementedError diff --git a/strawberry/exceptions/__init__.py b/strawberry/exceptions/__init__.py index 0155877e80..de701c5c12 100644 --- a/strawberry/exceptions/__init__.py +++ b/strawberry/exceptions/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from graphql import GraphQLError @@ -72,20 +72,20 @@ def __init__(self, annotation: GraphQLInputObjectType) -> None: class MissingTypesForGenericError(Exception): """Raised when a generic types was used without passing any type.""" - def __init__(self, annotation: Union[StrawberryType, type]) -> None: + def __init__(self, annotation: StrawberryType | type) -> None: message = f'The type "{annotation!r}" is generic, but no type has been passed' super().__init__(message) class UnsupportedTypeError(StrawberryException): - def __init__(self, annotation: Union[StrawberryType, type]) -> None: + def __init__(self, annotation: StrawberryType | type) -> None: message = f"{annotation} conversion is not supported" super().__init__(message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: return None diff --git a/strawberry/exceptions/conflicting_arguments.py b/strawberry/exceptions/conflicting_arguments.py index 8315363606..27c6b8087b 100644 --- a/strawberry/exceptions/conflicting_arguments.py +++ b/strawberry/exceptions/conflicting_arguments.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -43,7 +43,7 @@ def argument_names_str(self) -> str: ) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.function is None: return None # pragma: no cover diff --git a/strawberry/exceptions/duplicated_type_name.py b/strawberry/exceptions/duplicated_type_name.py index be98c79f9c..662610e9e0 100644 --- a/strawberry/exceptions/duplicated_type_name.py +++ b/strawberry/exceptions/duplicated_type_name.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -17,8 +17,8 @@ class DuplicatedTypeName(StrawberryException): def __init__( self, - first_cls: Optional[type], - second_cls: Optional[type], + first_cls: type | None, + second_cls: type | None, duplicated_type_name: str, ) -> None: self.first_cls = first_cls @@ -66,7 +66,7 @@ def __rich_body__(self) -> RenderableType: ) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.first_cls is None: return None # pragma: no cover diff --git a/strawberry/exceptions/exception.py b/strawberry/exceptions/exception.py index 7bfc8e2443..eb5eaac127 100644 --- a/strawberry/exceptions/exception.py +++ b/strawberry/exceptions/exception.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from strawberry.utils.str_converters import to_kebab_case @@ -40,7 +40,7 @@ def documentation_url(self) -> str: @cached_property @abstractmethod - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: return None @property @@ -61,7 +61,7 @@ def __rich_footer__(self) -> RenderableType: f"[link={self.documentation_url}]{self.documentation_url}" ).strip() - def __rich__(self) -> Optional[RenderableType]: + def __rich__(self) -> RenderableType | None: from rich.box import SIMPLE from rich.console import Group from rich.panel import Panel diff --git a/strawberry/exceptions/handler.py b/strawberry/exceptions/handler.py index 435cd89bb5..11ba0e670f 100644 --- a/strawberry/exceptions/handler.py +++ b/strawberry/exceptions/handler.py @@ -3,7 +3,7 @@ import threading from collections.abc import Callable from types import TracebackType -from typing import Any, Optional, cast +from typing import Any, cast from .exception import StrawberryException, UnableToFindExceptionSource @@ -11,7 +11,7 @@ ExceptionHandler = Callable[ - [type[BaseException], BaseException, Optional[TracebackType]], None + [type[BaseException], BaseException, TracebackType | None], None ] @@ -32,7 +32,7 @@ def _get_handler(exception_type: type[BaseException]) -> ExceptionHandler: def _handler( exception_type: type[BaseException], exception: BaseException, - traceback: Optional[TracebackType], + traceback: TracebackType | None, ) -> None: try: rich.print(exception) @@ -50,7 +50,7 @@ def _handler( def strawberry_exception_handler( exception_type: type[BaseException], exception: BaseException, - traceback: Optional[TracebackType], + traceback: TracebackType | None, ) -> None: _get_handler(exception_type)(exception_type, exception, traceback) @@ -58,9 +58,9 @@ def strawberry_exception_handler( def strawberry_threading_exception_handler( args: tuple[ type[BaseException], - Optional[BaseException], - Optional[TracebackType], - Optional[threading.Thread], + BaseException | None, + TracebackType | None, + threading.Thread | None, ], ) -> None: (exception_type, exception, traceback, _) = args diff --git a/strawberry/exceptions/invalid_argument_type.py b/strawberry/exceptions/invalid_argument_type.py index 092b8ad6d0..48be8e6835 100644 --- a/strawberry/exceptions/invalid_argument_type.py +++ b/strawberry/exceptions/invalid_argument_type.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from strawberry.types.base import get_object_definition @@ -55,7 +55,7 @@ def __init__( ) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.function is None: return None # pragma: no cover diff --git a/strawberry/exceptions/invalid_superclass_interface.py b/strawberry/exceptions/invalid_superclass_interface.py index 55b74b4659..c0c36d37e5 100644 --- a/strawberry/exceptions/invalid_superclass_interface.py +++ b/strawberry/exceptions/invalid_superclass_interface.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -36,6 +36,6 @@ def __init__( super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: source_finder = SourceFinder() return source_finder.find_class_from_object(self.cls) diff --git a/strawberry/exceptions/invalid_union_type.py b/strawberry/exceptions/invalid_union_type.py index 25139cfb33..5af48f4145 100644 --- a/strawberry/exceptions/invalid_union_type.py +++ b/strawberry/exceptions/invalid_union_type.py @@ -3,7 +3,7 @@ from functools import cached_property from inspect import getframeinfo, stack from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from strawberry.exceptions.utils.source_finder import SourceFinder @@ -24,7 +24,7 @@ def __init__( self, union_name: str, invalid_type: object, - union_definition: Optional[StrawberryUnion] = None, + union_definition: StrawberryUnion | None = None, ) -> None: from strawberry.types.base import StrawberryList from strawberry.types.scalar import ScalarWrapper @@ -58,7 +58,7 @@ def __init__( self.annotation_message = "invalid type here" @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: source_finder = SourceFinder() if self.union_definition: @@ -100,7 +100,7 @@ def __init__(self, union: StrawberryUnion, other: object) -> None: self.annotation_message = "invalid type here" @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: source_finder = SourceFinder() return source_finder.find_union_merge(self.union, self.other, frame=self.frame) diff --git a/strawberry/exceptions/missing_arguments_annotations.py b/strawberry/exceptions/missing_arguments_annotations.py index a5fd5cfd48..27eec5d916 100644 --- a/strawberry/exceptions/missing_arguments_annotations.py +++ b/strawberry/exceptions/missing_arguments_annotations.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -52,7 +52,7 @@ def missing_arguments_str(self) -> str: return f'arguments "{head}" and "{arguments[-1]}"' @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.function is None: return None # pragma: no cover diff --git a/strawberry/exceptions/missing_dependencies.py b/strawberry/exceptions/missing_dependencies.py index fa72f17ea3..d279286c9c 100644 --- a/strawberry/exceptions/missing_dependencies.py +++ b/strawberry/exceptions/missing_dependencies.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - class MissingOptionalDependenciesError(Exception): """Some optional dependencies that are required for a particular task are missing.""" @@ -9,8 +7,8 @@ class MissingOptionalDependenciesError(Exception): def __init__( self, *, - packages: Optional[list[str]] = None, - extras: Optional[list[str]] = None, + packages: list[str] | None = None, + extras: list[str] | None = None, ) -> None: """Initialize the error. diff --git a/strawberry/exceptions/missing_field_annotation.py b/strawberry/exceptions/missing_field_annotation.py index 594c575393..ba56d8423a 100644 --- a/strawberry/exceptions/missing_field_annotation.py +++ b/strawberry/exceptions/missing_field_annotation.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -32,7 +32,7 @@ def __init__(self, field_name: str, cls: type) -> None: super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.cls is None: return None # pragma: no cover diff --git a/strawberry/exceptions/missing_return_annotation.py b/strawberry/exceptions/missing_return_annotation.py index 4b863c84f7..d19f40df89 100644 --- a/strawberry/exceptions/missing_return_annotation.py +++ b/strawberry/exceptions/missing_return_annotation.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -37,7 +37,7 @@ def __init__( self.annotation_message = "resolver missing annotation" @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.function is None: return None # pragma: no cover diff --git a/strawberry/exceptions/object_is_not_a_class.py b/strawberry/exceptions/object_is_not_a_class.py index 0537e9a39a..f92f025249 100644 --- a/strawberry/exceptions/object_is_not_a_class.py +++ b/strawberry/exceptions/object_is_not_a_class.py @@ -2,7 +2,7 @@ from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -56,7 +56,7 @@ def type(cls, obj: object) -> ObjectIsNotClassError: return cls(obj, cls.MethodType.TYPE) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.function is None: return None # pragma: no cover diff --git a/strawberry/exceptions/object_is_not_an_enum.py b/strawberry/exceptions/object_is_not_an_enum.py index e3817dcb94..b6c6e29b8a 100644 --- a/strawberry/exceptions/object_is_not_an_enum.py +++ b/strawberry/exceptions/object_is_not_an_enum.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -31,7 +31,7 @@ def __init__(self, cls: type[Enum]) -> None: super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.cls is None: return None # pragma: no cover diff --git a/strawberry/exceptions/permission_fail_silently_requires_optional.py b/strawberry/exceptions/permission_fail_silently_requires_optional.py index 16bb1494cf..97b98b1838 100644 --- a/strawberry/exceptions/permission_fail_silently_requires_optional.py +++ b/strawberry/exceptions/permission_fail_silently_requires_optional.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -32,7 +32,7 @@ def __init__(self, field: StrawberryField) -> None: super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: origin = self.field.origin source_finder = SourceFinder() diff --git a/strawberry/exceptions/private_strawberry_field.py b/strawberry/exceptions/private_strawberry_field.py index 918cb64223..844fec26af 100644 --- a/strawberry/exceptions/private_strawberry_field.py +++ b/strawberry/exceptions/private_strawberry_field.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .exception import StrawberryException from .utils.source_finder import SourceFinder @@ -32,7 +32,7 @@ def __init__(self, field_name: str, cls: type) -> None: super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.cls is None: return None # pragma: no cover diff --git a/strawberry/exceptions/scalar_already_registered.py b/strawberry/exceptions/scalar_already_registered.py index c950b5d369..b74e1e4c59 100644 --- a/strawberry/exceptions/scalar_already_registered.py +++ b/strawberry/exceptions/scalar_already_registered.py @@ -2,7 +2,7 @@ from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from strawberry.exceptions.utils.source_finder import SourceFinder @@ -45,7 +45,7 @@ def __init__( super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if not all( (self.scalar_definition._source_file, self.scalar_definition._source_line) ): diff --git a/strawberry/exceptions/syntax.py b/strawberry/exceptions/syntax.py index 0403e07a2c..f5c171b39f 100644 --- a/strawberry/exceptions/syntax.py +++ b/strawberry/exceptions/syntax.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pygments.lexers import PythonLexer from rich.segment import Segment @@ -15,9 +15,9 @@ def __init__( self, code: str, line_range: tuple[int, int], - highlight_lines: Optional[set[int]] = None, + highlight_lines: set[int] | None = None, line_offset: int = 0, - line_annotations: Optional[dict[int, str]] = None, + line_annotations: dict[int, str] | None = None, ) -> None: self.line_offset = line_offset self.line_annotations = line_annotations or {} diff --git a/strawberry/exceptions/unresolved_field_type.py b/strawberry/exceptions/unresolved_field_type.py index 70eff59ed4..8468031492 100644 --- a/strawberry/exceptions/unresolved_field_type.py +++ b/strawberry/exceptions/unresolved_field_type.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from strawberry.exceptions.utils.source_finder import SourceFinder @@ -40,7 +40,7 @@ def __init__( super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: source_finder = SourceFinder() # field could be attached to the class or not diff --git a/strawberry/exceptions/utils/source_finder.py b/strawberry/exceptions/utils/source_finder.py index 550ded9acf..352b35f17e 100644 --- a/strawberry/exceptions/utils/source_finder.py +++ b/strawberry/exceptions/utils/source_finder.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from strawberry.exceptions.exception_source import ExceptionSource @@ -30,7 +30,7 @@ class LibCSTSourceFinder: def __init__(self) -> None: self.cst = importlib.import_module("libcst") - def find_source(self, module: str) -> Optional[SourcePath]: + def find_source(self, module: str) -> SourcePath | None: # TODO: support for pyodide source_module = sys.modules.get(module) @@ -73,11 +73,11 @@ def _find(self, source: str, matcher: Any) -> Sequence[CSTNode]: def _find_definition_by_qualname( self, qualname: str, nodes: Sequence[CSTNode] - ) -> Optional[CSTNode]: + ) -> CSTNode | None: from libcst import ClassDef, FunctionDef for definition in nodes: - parent: Optional[CSTNode] = definition + parent: CSTNode | None = definition stack = [] while parent: @@ -101,7 +101,7 @@ def _find_definition_by_qualname( def _find_function_definition( self, source: SourcePath, function: Callable[..., Any] - ) -> Optional[FunctionDef]: + ) -> FunctionDef | None: import libcst.matchers as m matcher = m.FunctionDef(name=m.Name(value=function.__name__)) @@ -115,7 +115,7 @@ def _find_function_definition( def _find_class_definition( self, source: SourcePath, cls: type[Any] - ) -> Optional[CSTNode]: + ) -> CSTNode | None: import libcst.matchers as m matcher = m.ClassDef(name=m.Name(value=cls.__name__)) @@ -123,7 +123,7 @@ def _find_class_definition( class_defs = self._find(source.code, matcher) return self._find_definition_by_qualname(cls.__qualname__, class_defs) - def find_class(self, cls: type[Any]) -> Optional[ExceptionSource]: + def find_class(self, cls: type[Any]) -> ExceptionSource | None: source = self.find_source(cls.__module__) if source is None: @@ -149,7 +149,7 @@ def find_class(self, cls: type[Any]) -> Optional[ExceptionSource]: def find_class_attribute( self, cls: type[Any], attribute_name: str - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: source = self.find_source(cls.__module__) if source is None: @@ -191,7 +191,7 @@ def find_class_attribute( error_column_end=attribute_position.end.column, ) - def find_function(self, function: Callable[..., Any]) -> Optional[ExceptionSource]: + def find_function(self, function: Callable[..., Any]) -> ExceptionSource | None: source = self.find_source(function.__module__) if source is None: @@ -225,7 +225,7 @@ def find_function(self, function: Callable[..., Any]) -> Optional[ExceptionSourc def find_argument( self, function: Callable[..., Any], argument_name: str - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: source = self.find_source(function.__module__) if source is None: @@ -263,7 +263,7 @@ def find_argument( def find_union_call( self, path: Path, union_name: str, invalid_type: object - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: import libcst.matchers as m source = path.read_text() @@ -339,7 +339,7 @@ def find_union_call( def find_union_merge( self, union: StrawberryUnion, other: object, frame: Traceback - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: import libcst.matchers as m path = Path(frame.filename) @@ -375,7 +375,7 @@ def find_union_merge( def find_annotated_union( self, union_definition: StrawberryUnion, invalid_type: object - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: if union_definition._source_file is None: return None @@ -503,7 +503,7 @@ def find_annotated_union( def find_scalar_call( self, scalar_definition: ScalarDefinition - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: if scalar_definition._source_file is None: return None # pragma: no cover @@ -571,7 +571,7 @@ def _create_scalar_exception_source( call_node: Any, scalar_definition: ScalarDefinition, is_newtype: bool, - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: """Helper method to create ExceptionSource for scalar calls.""" import libcst.matchers as m @@ -615,33 +615,33 @@ def _create_scalar_exception_source( class SourceFinder: @cached_property - def cst(self) -> Optional[LibCSTSourceFinder]: + def cst(self) -> LibCSTSourceFinder | None: try: return LibCSTSourceFinder() except ImportError: return None # pragma: no cover - def find_class_from_object(self, cls: type[Any]) -> Optional[ExceptionSource]: + def find_class_from_object(self, cls: type[Any]) -> ExceptionSource | None: return self.cst.find_class(cls) if self.cst else None def find_class_attribute_from_object( self, cls: type[Any], attribute_name: str - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: return self.cst.find_class_attribute(cls, attribute_name) if self.cst else None def find_function_from_object( self, function: Callable[..., Any] - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: return self.cst.find_function(function) if self.cst else None def find_argument_from_object( self, function: Callable[..., Any], argument_name: str - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: return self.cst.find_argument(function, argument_name) if self.cst else None def find_union_call( self, path: Path, union_name: str, invalid_type: object - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: return ( self.cst.find_union_call(path, union_name, invalid_type) if self.cst @@ -650,17 +650,17 @@ def find_union_call( def find_union_merge( self, union: StrawberryUnion, other: object, frame: Traceback - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: return self.cst.find_union_merge(union, other, frame) if self.cst else None def find_scalar_call( self, scalar_definition: ScalarDefinition - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: return self.cst.find_scalar_call(scalar_definition) if self.cst else None def find_annotated_union( self, union_definition: StrawberryUnion, invalid_type: object - ) -> Optional[ExceptionSource]: + ) -> ExceptionSource | None: return ( self.cst.find_annotated_union(union_definition, invalid_type) if self.cst diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 109c32c14e..c8d00f9aad 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -25,12 +25,12 @@ class CompatModelField: type_: Any outer_type_: Any default: Any - default_factory: Optional[Callable[[], Any]] + default_factory: Callable[[], Any] | None required: bool - alias: Optional[str] + alias: str | None allow_none: bool has_alias: bool - description: Optional[str] + description: str | None _missing_type: Any is_v1: bool @@ -44,8 +44,8 @@ def has_default(self) -> bool: ATTR_TO_TYPE_MAP = { - "NoneStr": Optional[str], - "NoneBytes": Optional[bytes], + "NoneStr": Optional[str], # noqa: UP045 + "NoneBytes": Optional[bytes], # noqa: UP045 "StrBytes": None, "NoneStrBytes": None, "StrictStr": str, diff --git a/strawberry/experimental/pydantic/conversion.py b/strawberry/experimental/pydantic/conversion.py index 3b05169fef..ed5cbd86ed 100644 --- a/strawberry/experimental/pydantic/conversion.py +++ b/strawberry/experimental/pydantic/conversion.py @@ -2,7 +2,7 @@ import copy import dataclasses -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, cast from strawberry.types.base import ( StrawberryList, @@ -18,7 +18,7 @@ def _convert_from_pydantic_to_strawberry_type( - type_: Union[StrawberryType, type], + type_: StrawberryType | type, data_from_model=None, # noqa: ANN001 extra=None, # noqa: ANN001 ) -> Any: diff --git a/strawberry/experimental/pydantic/conversion_types.py b/strawberry/experimental/pydantic/conversion_types.py index 747c67f351..c86bf99105 100644 --- a/strawberry/experimental/pydantic/conversion_types.py +++ b/strawberry/experimental/pydantic/conversion_types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from typing_extensions import Protocol from pydantic import BaseModel @@ -22,7 +22,7 @@ def __init__(self, **kwargs: Any) -> None: ... @staticmethod def from_pydantic( - instance: PydanticModel, extra: Optional[dict[str, Any]] = None + instance: PydanticModel, extra: dict[str, Any] | None = None ) -> StrawberryTypeFromPydantic[PydanticModel]: ... def to_pydantic(self, **kwargs: Any) -> PydanticModel: ... diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index adcde94040..37ef329425 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -6,7 +6,6 @@ TYPE_CHECKING, Any, Optional, - Union, cast, ) @@ -35,13 +34,13 @@ from strawberry.types.base import WithStrawberryObjectDefinition -def get_type_for_field(field: CompatModelField) -> Union[type[Union[None, list]], Any]: +def get_type_for_field(field: CompatModelField) -> type[None | list] | Any: type_ = field.outer_type_ type_ = normalize_type(type_) return field_type_to_type(type_) -def field_type_to_type(type_: type) -> Union[Any, list[Any], None]: +def field_type_to_type(type_: type) -> Any | list[Any] | None: error_class: Any = str strawberry_type: Any = error_class @@ -55,21 +54,21 @@ def field_type_to_type(type_: type) -> Union[Any, list[Any], None]: else: strawberry_type = list[error_class] - strawberry_type = Optional[strawberry_type] + strawberry_type = Optional[strawberry_type] # noqa: UP045 elif lenient_issubclass(type_, BaseModel): strawberry_type = get_strawberry_type_from_model(type_) - return Optional[strawberry_type] + return Optional[strawberry_type] # noqa: UP045 - return Optional[list[strawberry_type]] + return Optional[list[strawberry_type]] # noqa: UP045 def error_type( model: type[BaseModel], *, - fields: Optional[list[str]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + fields: list[str] | None = None, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), all_fields: bool = False, ) -> Callable[..., type]: def wrap(cls: type) -> type: diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index c7531d0c09..6a317fe41f 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -50,7 +50,7 @@ def replace_types_recursively( if isinstance(replaced_type, TypingGenericAlias): return TypingGenericAlias(origin, converted) if isinstance(replaced_type, UnionType): - return Union[converted] + return Union[converted] # noqa: UP007 # TODO: investigate if we could move the check for annotated to the top if origin is Annotated and converted: diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index 42555826e4..e529caa7eb 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -49,7 +49,7 @@ def get_type_for_field(field: CompatModelField, is_input: bool, compat: Pydantic # only pydantic v1 has this Optional logic should_add_optional: bool = field.allow_none if should_add_optional: - return Optional[replaced_type] + return Optional[replaced_type] # noqa: UP045 return replaced_type @@ -118,12 +118,12 @@ def _build_dataclass_creation_fields( def type( model: builtins.type[PydanticModel], *, - fields: Optional[list[str]] = None, - name: Optional[str] = None, + fields: list[str] | None = None, + name: str | None = None, is_input: bool = False, is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + description: str | None = None, + directives: Sequence[object] | None = (), all_fields: bool = False, include_computed: bool = False, use_pydantic_alias: bool = True, @@ -271,7 +271,7 @@ def is_type_of(cls: builtins.type, obj: Any, _info: GraphQLResolveInfo) -> bool: cls._pydantic_type = model def from_pydantic_default( - instance: PydanticModel, extra: Optional[dict[str, Any]] = None + instance: PydanticModel, extra: dict[str, Any] | None = None ) -> StrawberryTypeFromPydantic[PydanticModel]: ret = convert_pydantic_model_to_strawberry_class( cls=cls, model_instance=instance, extra=extra @@ -302,11 +302,11 @@ def to_pydantic_default(self: Any, **kwargs: Any) -> PydanticModel: def input( model: builtins.type[PydanticModel], *, - fields: Optional[list[str]] = None, - name: Optional[str] = None, + fields: list[str] | None = None, + name: str | None = None, is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + description: str | None = None, + directives: Sequence[object] | None = (), all_fields: bool = False, use_pydantic_alias: bool = True, ) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: @@ -332,11 +332,11 @@ def input( def interface( model: builtins.type[PydanticModel], *, - fields: Optional[list[str]] = None, - name: Optional[str] = None, + fields: list[str] | None = None, + name: str | None = None, is_input: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + description: str | None = None, + directives: Sequence[object] | None = (), all_fields: bool = False, use_pydantic_alias: bool = True, ) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index adaef2ea13..226589f53c 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -5,7 +5,6 @@ TYPE_CHECKING, Any, NamedTuple, - Union, cast, ) @@ -69,7 +68,7 @@ def to_tuple(self) -> tuple[str, type, dataclasses.Field]: def get_default_factory_for_field( field: CompatModelField, compat: PydanticCompat, -) -> Union[NoArgAnyCallable, dataclasses._MISSING_TYPE]: +) -> NoArgAnyCallable | dataclasses._MISSING_TYPE: """Gets the default factory for a pydantic field. Handles mutable defaults when making the dataclass by diff --git a/strawberry/ext/mypy_plugin.py b/strawberry/ext/mypy_plugin.py index e9431c6ce1..f22dcd307c 100644 --- a/strawberry/ext/mypy_plugin.py +++ b/strawberry/ext/mypy_plugin.py @@ -1,14 +1,11 @@ from __future__ import annotations import re -import typing import warnings from decimal import Decimal from typing import ( TYPE_CHECKING, Any, - Optional, - Union, cast, ) @@ -56,7 +53,7 @@ except ImportError: TypeVarDef = TypeVarType -PYDANTIC_VERSION: Optional[tuple[int, ...]] = None +PYDANTIC_VERSION: tuple[int, ...] | None = None # To be compatible with user who don't use pydantic try: @@ -242,7 +239,7 @@ def enum_hook(ctx: DynamicClassDefContext) -> None: ) return - enum_type: Optional[Type] + enum_type: Type | None try: enum_type = _get_type_for_expr(first_argument, ctx.api) @@ -290,7 +287,7 @@ def scalar_hook(ctx: DynamicClassDefContext) -> None: ) return - scalar_type: Optional[Type] + scalar_type: Type | None # TODO: add proper support for NewType @@ -315,12 +312,12 @@ def scalar_hook(ctx: DynamicClassDefContext) -> None: def add_static_method_to_class( - api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface], + api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, cls: ClassDef, name: str, args: list[Argument], return_type: Type, - tvar_def: Optional[TypeVarType] = None, + tvar_def: TypeVarType | None = None, ) -> None: """Adds a static method. @@ -528,7 +525,7 @@ def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None: class StrawberryPlugin(Plugin): def get_dynamic_class_hook( self, fullname: str - ) -> Optional[Callable[[DynamicClassDefContext], None]]: + ) -> Callable[[DynamicClassDefContext], None] | None: # TODO: investigate why we need this instead of `strawberry.union.union` on CI # we have the same issue in the other hooks if self._is_strawberry_union(fullname): @@ -545,7 +542,7 @@ def get_dynamic_class_hook( return None - def get_type_analyze_hook(self, fullname: str) -> Union[Callable[..., Type], None]: + def get_type_analyze_hook(self, fullname: str) -> Callable[..., Type] | None: if self._is_strawberry_lazy_type(fullname): return lazy_type_analyze_callback @@ -553,7 +550,7 @@ def get_type_analyze_hook(self, fullname: str) -> Union[Callable[..., Type], Non def get_class_decorator_hook( self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + ) -> Callable[[ClassDefContext], None] | None: if self._is_strawberry_pydantic_decorator(fullname): return strawberry_pydantic_class_callback @@ -614,7 +611,7 @@ def _is_strawberry_pydantic_decorator(self, fullname: str) -> bool: ) -def plugin(version: str) -> typing.Type[StrawberryPlugin]: +def plugin(version: str) -> type[StrawberryPlugin]: match = VERSION_RE.match(version) if match: MypyVersion.VERSION = Decimal(".".join(match.groups())) diff --git a/strawberry/extensions/context.py b/strawberry/extensions/context.py index 19cced1cba..d8629061c8 100644 --- a/strawberry/extensions/context.py +++ b/strawberry/extensions/context.py @@ -9,8 +9,6 @@ TYPE_CHECKING, Any, NamedTuple, - Optional, - Union, ) from strawberry.extensions import SchemaExtension @@ -27,10 +25,8 @@ class WrappedHook(NamedTuple): extension: SchemaExtension hook: Callable[ ..., - Union[ - contextlib.AbstractAsyncContextManager[None], - contextlib.AbstractContextManager[None], - ], + contextlib.AbstractAsyncContextManager[None] + | contextlib.AbstractContextManager[None], ] is_async: bool @@ -64,12 +60,12 @@ def __init__(self, extensions: list[SchemaExtension]) -> None: if hook: self.hooks.append(hook) - def get_hook(self, extension: SchemaExtension) -> Optional[WrappedHook]: + def get_hook(self, extension: SchemaExtension) -> WrappedHook | None: on_start = getattr(extension, self.LEGACY_ENTER, None) on_end = getattr(extension, self.LEGACY_EXIT, None) is_legacy = on_start is not None or on_end is not None - hook_fn: Optional[Hook] = getattr(type(extension), self.HOOK_NAME) + hook_fn: Hook | None = getattr(type(extension), self.HOOK_NAME) hook_fn = hook_fn if hook_fn is not self.default_hook else None if is_legacy and hook_fn is not None: raise ValueError( @@ -110,8 +106,8 @@ def get_hook(self, extension: SchemaExtension) -> Optional[WrappedHook]: @staticmethod def from_legacy( extension: SchemaExtension, - on_start: Optional[Callable[[], None]] = None, - on_end: Optional[Callable[[], None]] = None, + on_start: Callable[[], None] | None = None, + on_end: Callable[[], None] | None = None, ) -> WrappedHook: if iscoroutinefunction(on_start) or iscoroutinefunction(on_end): @@ -175,9 +171,9 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: self.exit_stack.__exit__(exc_type, exc_val, exc_tb) @@ -194,9 +190,9 @@ async def __aenter__(self) -> None: async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: await self.async_exit_stack.__aexit__(exc_type, exc_val, exc_tb) diff --git a/strawberry/extensions/field_extension.py b/strawberry/extensions/field_extension.py index a01d386bb0..4271d73e1b 100644 --- a/strawberry/extensions/field_extension.py +++ b/strawberry/extensions/field_extension.py @@ -3,7 +3,7 @@ import itertools from collections.abc import Awaitable, Callable from functools import cached_property -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from typing import TypeAlias @@ -69,7 +69,7 @@ def _get_async_resolvers( def build_field_extension_resolvers( field: StrawberryField, -) -> list[Union[SyncExtensionResolver, AsyncExtensionResolver]]: +) -> list[SyncExtensionResolver | AsyncExtensionResolver]: """Builds a list of resolvers for a field with extensions. Verifies that all of the field extensions for a given field support diff --git a/strawberry/extensions/max_aliases.py b/strawberry/extensions/max_aliases.py index c77d524222..ce7e414878 100644 --- a/strawberry/extensions/max_aliases.py +++ b/strawberry/extensions/max_aliases.py @@ -1,5 +1,3 @@ -from typing import Union - from graphql import ( ExecutableDefinitionNode, FieldNode, @@ -64,7 +62,7 @@ def __init__(self, validation_context: ValidationContext) -> None: def count_fields_with_alias( - selection_set_owner: Union[ExecutableDefinitionNode, FieldNode, InlineFragmentNode], + selection_set_owner: ExecutableDefinitionNode | FieldNode | InlineFragmentNode, ) -> int: if selection_set_owner.selection_set is None: return 0 diff --git a/strawberry/extensions/parser_cache.py b/strawberry/extensions/parser_cache.py index e0a2ef6164..9c2766bfbf 100644 --- a/strawberry/extensions/parser_cache.py +++ b/strawberry/extensions/parser_cache.py @@ -1,6 +1,5 @@ from collections.abc import Iterator from functools import lru_cache -from typing import Optional from graphql.language.parser import parse @@ -25,7 +24,7 @@ class ParserCache(SchemaExtension): ``` """ - def __init__(self, maxsize: Optional[int] = None) -> None: + def __init__(self, maxsize: int | None = None) -> None: """Initialize the ParserCache. Args: diff --git a/strawberry/extensions/query_depth_limiter.py b/strawberry/extensions/query_depth_limiter.py index 994ff72a9b..b617648508 100644 --- a/strawberry/extensions/query_depth_limiter.py +++ b/strawberry/extensions/query_depth_limiter.py @@ -32,8 +32,7 @@ from dataclasses import dataclass from typing import ( TYPE_CHECKING, - Optional, - Union, + TypeAlias, ) from graphql import GraphQLError @@ -61,12 +60,17 @@ if TYPE_CHECKING: from collections.abc import Iterable -IgnoreType = Union[Callable[[str], bool], re.Pattern, str] +IgnoreType: TypeAlias = Callable[[str], bool] | re.Pattern | str -FieldArgumentType = Union[ - bool, int, float, str, list["FieldArgumentType"], dict[str, "FieldArgumentType"] -] -FieldArgumentsType = dict[str, FieldArgumentType] +FieldArgumentType: TypeAlias = ( + bool + | int + | float + | str + | list["FieldArgumentType"] + | dict[str, "FieldArgumentType"] +) +FieldArgumentsType: TypeAlias = dict[str, FieldArgumentType] @dataclass @@ -99,8 +103,8 @@ class QueryDepthLimiter(AddValidationRules): def __init__( self, max_depth: int, - callback: Optional[Callable[[dict[str, int]], None]] = None, - should_ignore: Optional[ShouldIgnoreType] = None, + callback: Callable[[dict[str, int]], None] | None = None, + should_ignore: ShouldIgnoreType | None = None, ) -> None: """Initialize the QueryDepthLimiter. @@ -122,8 +126,8 @@ def __init__( def create_validator( max_depth: int, - should_ignore: Optional[ShouldIgnoreType], - callback: Optional[Callable[[dict[str, int]], None]] = None, + should_ignore: ShouldIgnoreType | None, + callback: Callable[[dict[str, int]], None] | None = None, ) -> type[ValidationRule]: class DepthLimitValidator(ValidationRule): def __init__(self, validation_context: ValidationContext) -> None: @@ -218,7 +222,7 @@ def determine_depth( max_depth: int, context: ValidationContext, operation_name: str, - should_ignore: Optional[ShouldIgnoreType], + should_ignore: ShouldIgnoreType | None, ) -> int: if depth_so_far > max_depth: context.report_error( @@ -288,7 +292,7 @@ def determine_depth( raise TypeError(f"Depth crawler cannot handle: {node.kind}") # pragma: no cover -def is_ignored(node: FieldNode, ignore: Optional[list[IgnoreType]] = None) -> bool: +def is_ignored(node: FieldNode, ignore: list[IgnoreType] | None = None) -> bool: if ignore is None: return False diff --git a/strawberry/extensions/runner.py b/strawberry/extensions/runner.py index 3f307807f8..fbe3b40cf9 100644 --- a/strawberry/extensions/runner.py +++ b/strawberry/extensions/runner.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from strawberry.extensions.context import ( ExecutingContextManager, @@ -23,7 +23,7 @@ class SchemaExtensionsRunner: def __init__( self, execution_context: ExecutionContext, - extensions: Optional[list[SchemaExtension]] = None, + extensions: list[SchemaExtension] | None = None, ) -> None: self.execution_context = execution_context self.extensions = extensions or [] diff --git a/strawberry/extensions/tracing/apollo.py b/strawberry/extensions/tracing/apollo.py index 3c5af96344..559e4c12e1 100644 --- a/strawberry/extensions/tracing/apollo.py +++ b/strawberry/extensions/tracing/apollo.py @@ -4,7 +4,7 @@ import time from datetime import datetime, timezone from inspect import isawaitable -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from strawberry.extensions import SchemaExtension from strawberry.extensions.utils import get_path_from_info @@ -38,7 +38,7 @@ class ApolloResolverStats: field_name: str return_type: Any start_offset: int - duration: Optional[int] = None + duration: int | None = None def to_json(self) -> dict[str, Any]: return { diff --git a/strawberry/extensions/tracing/datadog.py b/strawberry/extensions/tracing/datadog.py index 1e48ddf00d..20aa1ee779 100644 --- a/strawberry/extensions/tracing/datadog.py +++ b/strawberry/extensions/tracing/datadog.py @@ -3,7 +3,7 @@ import hashlib from functools import cached_property from inspect import isawaitable -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import ddtrace from packaging import version @@ -30,7 +30,7 @@ class DatadogTracingExtension(SchemaExtension): def __init__( self, *, - execution_context: Optional[ExecutionContext] = None, + execution_context: ExecutionContext | None = None, ) -> None: if execution_context: self.execution_context = execution_context diff --git a/strawberry/extensions/tracing/opentelemetry.py b/strawberry/extensions/tracing/opentelemetry.py index a4439e8b14..e2f219cc82 100644 --- a/strawberry/extensions/tracing/opentelemetry.py +++ b/strawberry/extensions/tracing/opentelemetry.py @@ -6,8 +6,6 @@ from typing import ( TYPE_CHECKING, Any, - Optional, - Union, ) from opentelemetry import trace @@ -33,16 +31,16 @@ class OpenTelemetryExtension(SchemaExtension): - _arg_filter: Optional[ArgFilter] + _arg_filter: ArgFilter | None _span_holder: dict[LifecycleStep, Span] _tracer: Tracer def __init__( self, *, - execution_context: Optional[ExecutionContext] = None, - arg_filter: Optional[ArgFilter] = None, - tracer_provider: Optional[trace.TracerProvider] = None, + execution_context: ExecutionContext | None = None, + arg_filter: ArgFilter | None = None, + tracer_provider: trace.TracerProvider | None = None, ) -> None: self._arg_filter = arg_filter self._tracer = trace.get_tracer("strawberry", tracer_provider=tracer_provider) @@ -129,7 +127,7 @@ def convert_to_allowed_types(self, value: Any) -> Any: return bytes(value) # Convert bytearray and memoryview to bytes return str(value) - def convert_set_to_allowed_types(self, value: Union[set, frozenset]) -> str: + def convert_set_to_allowed_types(self, value: set | frozenset) -> str: return ( "{" + ", ".join(str(self.convert_to_allowed_types(x)) for x in value) + "}" ) diff --git a/strawberry/extensions/utils.py b/strawberry/extensions/utils.py index e47ec681e2..b49058df7b 100644 --- a/strawberry/extensions/utils.py +++ b/strawberry/extensions/utils.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING if TYPE_CHECKING: from graphql import GraphQLResolveInfo -def is_introspection_key(key: Union[str, int]) -> bool: +def is_introspection_key(key: str | int) -> bool: # from: https://spec.graphql.org/June2018/#sec-Schema # > All types and directives defined within a schema must not have a name which # > begins with "__" (two underscores), as this is used exclusively diff --git a/strawberry/extensions/validation_cache.py b/strawberry/extensions/validation_cache.py index 0de1640547..46b1ae0c60 100644 --- a/strawberry/extensions/validation_cache.py +++ b/strawberry/extensions/validation_cache.py @@ -1,6 +1,5 @@ from collections.abc import Iterator from functools import lru_cache -from typing import Optional from strawberry.extensions.base_extension import SchemaExtension @@ -22,7 +21,7 @@ class ValidationCache(SchemaExtension): ``` """ - def __init__(self, maxsize: Optional[int] = None) -> None: + def __init__(self, maxsize: int | None = None) -> None: """Initialize the ValidationCache. Args: diff --git a/strawberry/fastapi/context.py b/strawberry/fastapi/context.py index 1c79711698..584fd4a6da 100644 --- a/strawberry/fastapi/context.py +++ b/strawberry/fastapi/context.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Union from starlette.background import BackgroundTasks from starlette.requests import Request @@ -7,17 +7,17 @@ CustomContext = Union["BaseContext", dict[str, Any]] MergedContext = Union[ - "BaseContext", dict[str, Union[Any, BackgroundTasks, Request, Response, WebSocket]] + "BaseContext", dict[str, Any | BackgroundTasks | Request | Response | WebSocket] ] class BaseContext: - connection_params: Optional[Any] = None + connection_params: Any | None = None def __init__(self) -> None: - self.request: Optional[Union[Request, WebSocket]] = None - self.background_tasks: Optional[BackgroundTasks] = None - self.response: Optional[Response] = None + self.request: Request | WebSocket | None = None + self.background_tasks: BackgroundTasks | None = None + self.response: Response | None = None __all__ = ["BaseContext"] diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index cea96ded0f..5d598cfaa2 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -6,9 +6,7 @@ from typing import ( TYPE_CHECKING, Any, - Optional, TypeGuard, - Union, cast, ) @@ -71,16 +69,16 @@ async def __get_root_value() -> None: @staticmethod def __get_context_getter( custom_getter: Callable[ - ..., Union[Optional[CustomContext], Awaitable[Optional[CustomContext]]] + ..., CustomContext | None | Awaitable[CustomContext | None] ], ) -> Callable[..., Awaitable[CustomContext]]: async def dependency( - custom_context: Optional[CustomContext], + custom_context: CustomContext | None, background_tasks: BackgroundTasks, connection: HTTPConnection, response: Response = None, # type: ignore ) -> MergedContext: - request = cast("Union[Request, WebSocket]", connection) + request = cast("Request | WebSocket", connection) if isinstance(custom_context, BaseContext): custom_context.request = request custom_context.background_tasks = background_tasks @@ -122,35 +120,34 @@ def __init__( self, schema: BaseSchema, path: str = "", - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, - root_value_getter: Optional[Callable[[], RootValue]] = None, - context_getter: Optional[ - Callable[..., Union[Optional[Context], Awaitable[Optional[Context]]]] - ] = None, + root_value_getter: Callable[[], RootValue] | None = None, + context_getter: Callable[..., Context | None | Awaitable[Context | None]] + | None = None, subscription_protocols: Sequence[str] = ( GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL, ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), prefix: str = "", - tags: Optional[list[Union[str, Enum]]] = None, - dependencies: Optional[Sequence[params.Depends]] = None, + tags: list[str | Enum] | None = None, + dependencies: Sequence[params.Depends] | None = None, default_response_class: type[Response] = Default(JSONResponse), - responses: Optional[dict[Union[int, str], dict[str, Any]]] = None, - callbacks: Optional[list[BaseRoute]] = None, - routes: Optional[list[BaseRoute]] = None, + responses: dict[int | str, dict[str, Any]] | None = None, + callbacks: list[BaseRoute] | None = None, + routes: list[BaseRoute] | None = None, redirect_slashes: bool = True, - default: Optional[ASGIApp] = None, - dependency_overrides_provider: Optional[Any] = None, + default: ASGIApp | None = None, + dependency_overrides_provider: Any | None = None, route_class: type[APIRoute] = APIRoute, - on_startup: Optional[Sequence[Callable[[], Any]]] = None, - on_shutdown: Optional[Sequence[Callable[[], Any]]] = None, - lifespan: Optional[Lifespan[Any]] = None, - deprecated: Optional[bool] = None, + on_startup: Sequence[Callable[[], Any]] | None = None, + on_shutdown: Sequence[Callable[[], Any]] | None = None, + lifespan: Lifespan[Any] | None = None, + deprecated: bool | None = None, include_in_schema: bool = True, generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id @@ -265,13 +262,13 @@ async def render_graphql_ide(self, request: Request) -> HTMLResponse: return HTMLResponse(self.graphql_ide_html) async def get_context( - self, request: Union[Request, WebSocket], response: Union[Response, WebSocket] + self, request: Request | WebSocket, response: Response | WebSocket ) -> Context: # pragma: no cover raise ValueError("`get_context` is not used by FastAPI GraphQL Router") async def get_root_value( - self, request: Union[Request, WebSocket] - ) -> Optional[RootValue]: # pragma: no cover + self, request: Request | WebSocket + ) -> RootValue | None: # pragma: no cover raise ValueError("`get_root_value` is not used by FastAPI GraphQL Router") async def get_sub_response(self, request: Request) -> Response: @@ -279,7 +276,7 @@ async def get_sub_response(self, request: Request) -> Response: def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: Response, ) -> Response: response = Response( @@ -309,18 +306,18 @@ async def create_streaming_response( ) def is_websocket_request( - self, request: Union[Request, WebSocket] + self, request: Request | WebSocket ) -> TypeGuard[WebSocket]: return request.scope["type"] == "websocket" - async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]: + async def pick_websocket_subprotocol(self, request: WebSocket) -> str | None: protocols = request["subprotocols"] intersection = set(protocols) & set(self.protocols) sorted_intersection = sorted(intersection, key=protocols.index) return next(iter(sorted_intersection), None) async def create_websocket_response( - self, request: WebSocket, subprotocol: Optional[str] + self, request: WebSocket, subprotocol: str | None ) -> WebSocket: await request.accept(subprotocol=subprotocol) return request diff --git a/strawberry/federation/argument.py b/strawberry/federation/argument.py index 9c42fad6cc..123b655f05 100644 --- a/strawberry/federation/argument.py +++ b/strawberry/federation/argument.py @@ -1,16 +1,15 @@ from collections.abc import Iterable -from typing import Optional from strawberry.types.arguments import StrawberryArgumentAnnotation def argument( - description: Optional[str] = None, - name: Optional[str] = None, - deprecation_reason: Optional[str] = None, + description: str | None = None, + name: str | None = None, + deprecation_reason: str | None = None, directives: Iterable[object] = (), inaccessible: bool = False, - tags: Optional[Iterable[str]] = (), + tags: Iterable[str] | None = (), ) -> StrawberryArgumentAnnotation: from strawberry.federation.schema_directives import Inaccessible, Tag diff --git a/strawberry/federation/enum.py b/strawberry/federation/enum.py index 0ced896e9c..821448eaf9 100644 --- a/strawberry/federation/enum.py +++ b/strawberry/federation/enum.py @@ -3,8 +3,6 @@ from typing import ( TYPE_CHECKING, Any, - Optional, - Union, overload, ) @@ -19,8 +17,8 @@ def enum_value( value: Any, - name: Optional[str] = None, - deprecation_reason: Optional[str] = None, + name: str | None = None, + deprecation_reason: str | None = None, directives: Iterable[object] = (), inaccessible: bool = False, tags: Iterable[str] = (), @@ -47,14 +45,14 @@ def enum_value( def enum( _cls: EnumType, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), ) -> EnumType: ... @@ -62,29 +60,29 @@ def enum( def enum( _cls: None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), ) -> Callable[[EnumType], EnumType]: ... def enum( - _cls: Optional[EnumType] = None, + _cls: EnumType | None = None, *, name=None, description=None, directives=(), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), -) -> Union[EnumType, Callable[[EnumType], EnumType]]: + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), +) -> EnumType | Callable[[EnumType], EnumType]: """Registers the enum in the GraphQL type system. If name is passed, the name of the GraphQL type will be diff --git a/strawberry/federation/field.py b/strawberry/federation/field.py index ca4ea18477..815fca2640 100644 --- a/strawberry/federation/field.py +++ b/strawberry/federation/field.py @@ -4,9 +4,7 @@ from typing import ( TYPE_CHECKING, Any, - Optional, TypeVar, - Union, overload, ) @@ -38,28 +36,28 @@ def field( *, resolver: _RESOLVER_TYPE_ASYNC[T], - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - provides: Optional[list[str]] = None, - override: Optional[Union[Override, str]] = None, - requires: Optional[list[str]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + provides: list[str] | None = None, + override: Override | str | None = None, + requires: list[str] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), shareable: bool = False, init: Literal[False] = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> T: ... @@ -67,56 +65,56 @@ def field( def field( *, resolver: _RESOLVER_TYPE_SYNC[T], - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - provides: Optional[list[str]] = None, - override: Optional[Union[Override, str]] = None, - requires: Optional[list[str]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + provides: list[str] | None = None, + override: Override | str | None = None, + requires: list[str] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), shareable: bool = False, init: Literal[False] = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> T: ... @overload def field( *, - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - provides: Optional[list[str]] = None, - override: Optional[Union[Override, str]] = None, - requires: Optional[list[str]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + provides: list[str] | None = None, + override: Override | str | None = None, + requires: list[str] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), shareable: bool = False, init: Literal[True] = True, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> Any: ... @@ -124,27 +122,27 @@ def field( def field( resolver: _RESOLVER_TYPE_ASYNC[T], *, - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - provides: Optional[list[str]] = None, - override: Optional[Union[Override, str]] = None, - requires: Optional[list[str]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + provides: list[str] | None = None, + override: Override | str | None = None, + requires: list[str] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), shareable: bool = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> StrawberryField: ... @@ -152,54 +150,54 @@ def field( def field( resolver: _RESOLVER_TYPE_SYNC[T], *, - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - provides: Optional[list[str]] = None, - override: Optional[Union[Override, str]] = None, - requires: Optional[list[str]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + provides: list[str] | None = None, + override: Override | str | None = None, + requires: list[str] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), shareable: bool = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> StrawberryField: ... def field( - resolver: Optional[_RESOLVER_TYPE[Any]] = None, + resolver: _RESOLVER_TYPE[Any] | None = None, *, - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, authenticated: bool = False, external: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - provides: Optional[list[str]] = None, - override: Optional[Union[Override, str]] = None, - requires: Optional[list[str]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + provides: list[str] | None = None, + override: Override | str | None = None, + requires: list[str] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), shareable: bool = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, # This init parameter is used by PyRight to determine whether this field # is added in the constructor or not. It is not used to change # any behavior at the moment. diff --git a/strawberry/federation/object_type.py b/strawberry/federation/object_type.py index 45926ba717..5a3843d878 100644 --- a/strawberry/federation/object_type.py +++ b/strawberry/federation/object_type.py @@ -2,7 +2,6 @@ from collections.abc import Callable, Iterable, Sequence from typing import ( TYPE_CHECKING, - Optional, TypeVar, Union, overload, @@ -24,19 +23,19 @@ def _impl_type( - cls: Optional[T], + cls: T | None, *, - name: Optional[str] = None, - description: Optional[str] = None, - one_of: Optional[bool] = None, + name: str | None = None, + description: str | None = None, + one_of: bool | None = None, directives: Iterable[object] = (), authenticated: bool = False, keys: Iterable[Union["Key", str]] = (), extend: bool = False, shareable: bool = False, inaccessible: bool = UNSET, - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, tags: Iterable[str] = (), is_input: bool = False, is_interface: bool = False, @@ -105,15 +104,15 @@ def _impl_type( def type( cls: T, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, extend: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, shareable: bool = False, tags: Iterable[str] = (), ) -> T: ... @@ -127,32 +126,32 @@ def type( ) def type( *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, extend: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, shareable: bool = False, tags: Iterable[str] = (), ) -> Callable[[T], T]: ... def type( - cls: Optional[T] = None, + cls: T | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, extend: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, shareable: bool = False, tags: Iterable[str] = (), ): @@ -181,9 +180,9 @@ def type( def input( cls: T, *, - name: Optional[str] = None, - one_of: Optional[bool] = None, - description: Optional[str] = None, + name: str | None = None, + one_of: bool | None = None, + description: str | None = None, directives: Sequence[object] = (), inaccessible: bool = UNSET, tags: Iterable[str] = (), @@ -198,9 +197,9 @@ def input( ) def input( *, - name: Optional[str] = None, - description: Optional[str] = None, - one_of: Optional[bool] = None, + name: str | None = None, + description: str | None = None, + one_of: bool | None = None, directives: Sequence[object] = (), inaccessible: bool = UNSET, tags: Iterable[str] = (), @@ -208,11 +207,11 @@ def input( def input( - cls: Optional[T] = None, + cls: T | None = None, *, - name: Optional[str] = None, - one_of: Optional[bool] = None, - description: Optional[str] = None, + name: str | None = None, + one_of: bool | None = None, + description: str | None = None, directives: Sequence[object] = (), inaccessible: bool = UNSET, tags: Iterable[str] = (), @@ -238,14 +237,14 @@ def input( def interface( cls: T, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, tags: Iterable[str] = (), ) -> T: ... @@ -258,29 +257,29 @@ def interface( ) def interface( *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, tags: Iterable[str] = (), ) -> Callable[[T], T]: ... def interface( - cls: Optional[T] = None, + cls: T | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, tags: Iterable[str] = (), ): return _impl_type( @@ -307,14 +306,14 @@ def interface( def interface_object( cls: T, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, tags: Iterable[str] = (), ) -> T: ... @@ -327,29 +326,29 @@ def interface_object( ) def interface_object( *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, tags: Iterable[str] = (), ) -> Callable[[T], T]: ... def interface_object( - cls: Optional[T] = None, + cls: T | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = UNSET, keys: Iterable[Union["Key", str]] = (), - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, tags: Iterable[str] = (), ): return _impl_type( diff --git a/strawberry/federation/scalar.py b/strawberry/federation/scalar.py index 9203469e0f..95c3f0d459 100644 --- a/strawberry/federation/scalar.py +++ b/strawberry/federation/scalar.py @@ -2,15 +2,13 @@ from typing import ( Any, NewType, - Optional, TypeVar, - Union, overload, ) from strawberry.types.scalar import ScalarWrapper, _process_scalar -_T = TypeVar("_T", bound=Union[type, NewType]) +_T = TypeVar("_T", bound=type | NewType) def identity(x: _T) -> _T: # pragma: no cover @@ -20,18 +18,18 @@ def identity(x: _T) -> _T: # pragma: no cover @overload def scalar( *, - name: Optional[str] = None, - description: Optional[str] = None, - specified_by_url: Optional[str] = None, + name: str | None = None, + description: str | None = None, + specified_by_url: str | None = None, serialize: Callable = identity, - parse_value: Optional[Callable] = None, - parse_literal: Optional[Callable] = None, + parse_value: Callable | None = None, + parse_literal: Callable | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), ) -> Callable[[_T], _T]: ... @@ -39,36 +37,36 @@ def scalar( def scalar( cls: _T, *, - name: Optional[str] = None, - description: Optional[str] = None, - specified_by_url: Optional[str] = None, + name: str | None = None, + description: str | None = None, + specified_by_url: str | None = None, serialize: Callable = identity, - parse_value: Optional[Callable] = None, - parse_literal: Optional[Callable] = None, + parse_value: Callable | None = None, + parse_literal: Callable | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), ) -> _T: ... def scalar( - cls: Optional[_T] = None, + cls: _T | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, - specified_by_url: Optional[str] = None, + name: str | None = None, + description: str | None = None, + specified_by_url: str | None = None, serialize: Callable = identity, - parse_value: Optional[Callable] = None, - parse_literal: Optional[Callable] = None, + parse_value: Callable | None = None, + parse_literal: Callable | None = None, directives: Iterable[object] = (), authenticated: bool = False, inaccessible: bool = False, - policy: Optional[list[list[str]]] = None, - requires_scopes: Optional[list[list[str]]] = None, - tags: Optional[Iterable[str]] = (), + policy: list[list[str]] | None = None, + requires_scopes: list[list[str]] | None = None, + tags: Iterable[str] | None = (), ) -> Any: """Annotates a class or type as a GraphQL custom scalar. diff --git a/strawberry/federation/schema.py b/strawberry/federation/schema.py index 72df2401a2..4bb2359fa3 100644 --- a/strawberry/federation/schema.py +++ b/strawberry/federation/schema.py @@ -44,18 +44,17 @@ class Schema(BaseSchema): def __init__( self, - query: Optional[type] = None, - mutation: Optional[type] = None, - subscription: Optional[type] = None, + query: type | None = None, + mutation: type | None = None, + subscription: type | None = None, # TODO: we should update directives' type in the main schema directives: Iterable[type] = (), types: Iterable[type] = (), extensions: Iterable[Union[type["SchemaExtension"], "SchemaExtension"]] = (), - execution_context_class: Optional[type["GraphQLExecutionContext"]] = None, + execution_context_class: type["GraphQLExecutionContext"] | None = None, config: Optional["StrawberryConfig"] = None, - scalar_overrides: Optional[ - dict[object, Union[type, "ScalarWrapper", "ScalarDefinition"]] - ] = None, + scalar_overrides: dict[object, Union[type, "ScalarWrapper", "ScalarDefinition"]] + | None = None, schema_directives: Iterable[object] = (), enable_federation_2: bool = False, ) -> None: @@ -85,9 +84,9 @@ def __init__( def _get_federation_query_type( self, - query: Optional[type[WithStrawberryObjectDefinition]], - mutation: Optional[type[WithStrawberryObjectDefinition]], - subscription: Optional[type[WithStrawberryObjectDefinition]], + query: type[WithStrawberryObjectDefinition] | None, + mutation: type[WithStrawberryObjectDefinition] | None, + subscription: type[WithStrawberryObjectDefinition] | None, additional_types: Iterable[type[WithStrawberryObjectDefinition]], ) -> type: """Returns a new query type that includes the _service field. @@ -124,7 +123,7 @@ def service() -> Service: if entity_type: self.entities_resolver.__annotations__["return"] = list[ - Optional[entity_type] # type: ignore + entity_type | None # type: ignore ] entities_field = strawberry.field( @@ -250,7 +249,7 @@ def _add_link_for_composed_directive( directive_by_url[import_url].add(f"@{name}") def _add_link_directives( - self, additional_directives: Optional[list[object]] = None + self, additional_directives: list[object] | None = None ) -> None: from .schema_directives import FederationDirective, Link @@ -312,11 +311,11 @@ def _warn_for_federation_directives(self) -> None: def _get_entity_type( - query: Optional[type[WithStrawberryObjectDefinition]], - mutation: Optional[type[WithStrawberryObjectDefinition]], - subscription: Optional[type[WithStrawberryObjectDefinition]], + query: type[WithStrawberryObjectDefinition] | None, + mutation: type[WithStrawberryObjectDefinition] | None, + subscription: type[WithStrawberryObjectDefinition] | None, additional_types: Iterable[type[WithStrawberryObjectDefinition]], -) -> Optional[StrawberryUnion]: +) -> StrawberryUnion | None: # recursively iterate over the schema to find all types annotated with @key # if no types are annotated with @key, then the _Entity union and Query._entities # field should not be added to the schema diff --git a/strawberry/federation/schema_directive.py b/strawberry/federation/schema_directive.py index 1140ca3310..2803d0b886 100644 --- a/strawberry/federation/schema_directive.py +++ b/strawberry/federation/schema_directive.py @@ -1,6 +1,6 @@ import dataclasses from collections.abc import Callable -from typing import Optional, TypeVar +from typing import TypeVar from typing_extensions import dataclass_transform from strawberry.directive import directive_field @@ -12,12 +12,12 @@ @dataclasses.dataclass class ComposeOptions: - import_url: Optional[str] + import_url: str | None @dataclasses.dataclass class StrawberryFederationSchemaDirective(StrawberrySchemaDirective): - compose_options: Optional[ComposeOptions] = None + compose_options: ComposeOptions | None = None T = TypeVar("T", bound=type) @@ -31,12 +31,12 @@ class StrawberryFederationSchemaDirective(StrawberrySchemaDirective): def schema_directive( *, locations: list[Location], - description: Optional[str] = None, - name: Optional[str] = None, + description: str | None = None, + name: str | None = None, repeatable: bool = False, print_definition: bool = True, compose: bool = False, - import_url: Optional[str] = None, + import_url: str | None = None, ) -> Callable[[T], T]: def _wrap(cls: T) -> T: cls = _wrap_dataclass(cls) # type: ignore diff --git a/strawberry/federation/schema_directives.py b/strawberry/federation/schema_directives.py index c249f7d212..a1b4b22e8e 100644 --- a/strawberry/federation/schema_directives.py +++ b/strawberry/federation/schema_directives.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar from strawberry import directive_field from strawberry.schema_directive import Location, schema_directive @@ -59,7 +59,7 @@ class Provides(FederationDirective): ) class Key(FederationDirective): fields: FieldSet - resolvable: Optional[bool] = True + resolvable: bool | None = True imported_from: ClassVar[ImportedFrom] = ImportedFrom( name="key", url="https://specs.apollo.dev/federation/v2.7" ) @@ -81,17 +81,17 @@ class Shareable(FederationDirective): locations=[Location.SCHEMA], name="link", repeatable=True, print_definition=False ) class Link: - url: Optional[str] - as_: Optional[str] = directive_field(name="as") - for_: Optional[LinkPurpose] = directive_field(name="for") - import_: Optional[list[Optional[LinkImport]]] = directive_field(name="import") + url: str | None + as_: str | None = directive_field(name="as") + for_: LinkPurpose | None = directive_field(name="for") + import_: list[LinkImport | None] | None = directive_field(name="import") def __init__( self, - url: Optional[str] = UNSET, - as_: Optional[str] = UNSET, - for_: Optional[LinkPurpose] = UNSET, - import_: Optional[list[Optional[LinkImport]]] = UNSET, + url: str | None = UNSET, + as_: str | None = UNSET, + for_: LinkPurpose | None = UNSET, + import_: list[LinkImport | None] | None = UNSET, ) -> None: self.url = url self.as_ = as_ @@ -128,7 +128,7 @@ class Tag(FederationDirective): ) class Override(FederationDirective): override_from: str = directive_field(name="from") - label: Optional[str] = UNSET + label: str | None = UNSET imported_from: ClassVar[ImportedFrom] = ImportedFrom( name="override", url="https://specs.apollo.dev/federation/v2.7" ) diff --git a/strawberry/federation/union.py b/strawberry/federation/union.py index b3d1c8a2ce..4ce06989d6 100644 --- a/strawberry/federation/union.py +++ b/strawberry/federation/union.py @@ -1,5 +1,5 @@ from collections.abc import Collection, Iterable -from typing import Any, Optional +from typing import Any from strawberry.types.union import StrawberryUnion from strawberry.types.union import union as base_union @@ -7,12 +7,12 @@ def union( name: str, - types: Optional[Collection[type[Any]]] = None, + types: Collection[type[Any]] | None = None, *, - description: Optional[str] = None, + description: str | None = None, directives: Iterable[object] = (), inaccessible: bool = False, - tags: Optional[Iterable[str]] = (), + tags: Iterable[str] | None = (), ) -> StrawberryUnion: """Creates a new named Union type. diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index 73a737a3ac..798d09ea15 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -4,9 +4,7 @@ from typing import ( TYPE_CHECKING, ClassVar, - Optional, TypeGuard, - Union, ) from lia import AsyncFlaskHTTPRequestAdapter, FlaskHTTPRequestAdapter, HTTPException @@ -25,13 +23,13 @@ class BaseGraphQLView: - graphql_ide: Optional[GraphQL_IDE] + graphql_ide: GraphQL_IDE | None def __init__( self, schema: BaseSchema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, multipart_uploads_enabled: bool = False, ) -> None: @@ -52,7 +50,7 @@ def __init__( def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: Response, ) -> Response: sub_response.set_data(self.encode_json(response_data)) # type: ignore @@ -72,7 +70,7 @@ class GraphQLView( def get_context(self, request: Request, response: Response) -> Context: return {"request": request, "response": response} # type: ignore - def get_root_value(self, request: Request) -> Optional[RootValue]: + def get_root_value(self, request: Request) -> RootValue | None: return None def get_sub_response(self, request: Request) -> Response: @@ -105,7 +103,7 @@ class AsyncGraphQLView( async def get_context(self, request: Request, response: Response) -> Context: return {"request": request, "response": response} # type: ignore - async def get_root_value(self, request: Request) -> Optional[RootValue]: + async def get_root_value(self, request: Request) -> RootValue | None: return None async def get_sub_response(self, request: Request) -> Response: @@ -127,11 +125,11 @@ async def render_graphql_ide(self, request: Request) -> Response: def is_websocket_request(self, request: Request) -> TypeGuard[Request]: return False - async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: + async def pick_websocket_subprotocol(self, request: Request) -> str | None: raise NotImplementedError async def create_websocket_response( - self, request: Request, subprotocol: Optional[str] + self, request: Request, subprotocol: str | None ) -> Response: raise NotImplementedError diff --git a/strawberry/http/__init__.py b/strawberry/http/__init__.py index 10fc766d70..b0d39b963a 100644 --- a/strawberry/http/__init__.py +++ b/strawberry/http/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Literal, Optional +from typing import Any, Literal from typing_extensions import TypedDict from strawberry.schema._graphql_core import ( @@ -11,14 +11,14 @@ class GraphQLHTTPResponse(TypedDict, total=False): - data: Optional[dict[str, object]] - errors: Optional[list[object]] - extensions: Optional[dict[str, object]] - hasNext: Optional[bool] - completed: Optional[list[Any]] - pending: Optional[list[Any]] - initial: Optional[list[Any]] - incremental: Optional[list[Any]] + data: dict[str, object] | None + errors: list[object] | None + extensions: dict[str, object] | None + hasNext: bool | None + completed: list[Any] | None + pending: list[Any] | None + initial: list[Any] | None + incremental: list[Any] | None def process_result(result: ResultType) -> GraphQLHTTPResponse: @@ -39,10 +39,10 @@ def process_result(result: ResultType) -> GraphQLHTTPResponse: class GraphQLRequestData: # query is optional here as it can be added by an extensions # (for example an extension for persisted queries) - query: Optional[str] - variables: Optional[dict[str, Any]] - operation_name: Optional[str] - extensions: Optional[dict[str, Any]] + query: str | None + variables: dict[str, Any] | None + operation_name: str | None + extensions: dict[str, Any] | None protocol: Literal["http", "multipart-subscription"] = "http" diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 1a88c70766..a9fc22df6c 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -8,9 +8,7 @@ Any, Generic, Literal, - Optional, TypeGuard, - Union, cast, overload, ) @@ -86,9 +84,9 @@ class AsyncBaseHTTPView( ], ): schema: BaseSchema - graphql_ide: Optional[GraphQL_IDE] + graphql_ide: GraphQL_IDE | None keep_alive = False - keep_alive_interval: Optional[float] = None + keep_alive_interval: float | None = None connection_init_wait_timeout: timedelta = timedelta(minutes=1) request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter] websocket_adapter_class: Callable[ @@ -116,19 +114,19 @@ async def get_sub_response(self, request: Request) -> SubResponse: ... @abc.abstractmethod async def get_context( self, - request: Union[Request, WebSocketRequest], - response: Union[SubResponse, WebSocketResponse], + request: Request | WebSocketRequest, + response: SubResponse | WebSocketResponse, ) -> Context: ... @abc.abstractmethod async def get_root_value( - self, request: Union[Request, WebSocketRequest] - ) -> Optional[RootValue]: ... + self, request: Request | WebSocketRequest + ) -> RootValue | None: ... @abc.abstractmethod def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: SubResponse, ) -> Response: ... @@ -146,26 +144,26 @@ async def create_streaming_response( @abc.abstractmethod def is_websocket_request( - self, request: Union[Request, WebSocketRequest] + self, request: Request | WebSocketRequest ) -> TypeGuard[WebSocketRequest]: ... @abc.abstractmethod async def pick_websocket_subprotocol( self, request: WebSocketRequest - ) -> Optional[str]: ... + ) -> str | None: ... @abc.abstractmethod async def create_websocket_response( - self, request: WebSocketRequest, subprotocol: Optional[str] + self, request: WebSocketRequest, subprotocol: str | None ) -> WebSocketResponse: ... async def execute_operation( self, request: Request, context: Context, - root_value: Optional[RootValue], + root_value: RootValue | None, sub_response: SubResponse, - ) -> Union[ExecutionResult, list[ExecutionResult], SubscriptionExecutionResult]: + ) -> ExecutionResult | list[ExecutionResult] | SubscriptionExecutionResult: request_adapter = self.request_adapter_class(request) try: @@ -222,7 +220,7 @@ async def execute_single( request_adapter: AsyncHTTPRequestAdapter, sub_response: SubResponse, context: Context, - root_value: Optional[RootValue], + root_value: RootValue | None, request_data: GraphQLRequestData, ) -> ExecutionResult: allowed_operation_types = OperationType.from_http(request_adapter.method) @@ -282,7 +280,7 @@ async def run( self, request: Request, context: Context = UNSET, - root_value: Optional[RootValue] = UNSET, + root_value: RootValue | None = UNSET, ) -> Response: ... @overload @@ -290,15 +288,15 @@ async def run( self, request: WebSocketRequest, context: Context = UNSET, - root_value: Optional[RootValue] = UNSET, + root_value: RootValue | None = UNSET, ) -> WebSocketResponse: ... async def run( self, - request: Union[Request, WebSocketRequest], + request: Request | WebSocketRequest, context: Context = UNSET, - root_value: Optional[RootValue] = UNSET, - ) -> Union[Response, WebSocketResponse]: + root_value: RootValue | None = UNSET, + ) -> Response | WebSocketResponse: root_value = ( await self.get_root_value(request) if root_value is UNSET else root_value ) @@ -446,7 +444,7 @@ async def stream() -> AsyncGenerator[str, None]: }, ) - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]] + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse] if isinstance(result, list): response_data = [] @@ -617,7 +615,7 @@ async def parse_multipart_subscriptions( async def parse_http_body( self, request: AsyncHTTPRequestAdapter - ) -> Union[GraphQLRequestData, list[GraphQLRequestData]]: + ) -> GraphQLRequestData | list[GraphQLRequestData]: headers = {key.lower(): value for key, value in request.headers.items()} content_type, _ = parse_content_type(request.content_type or "") accept = headers.get("accept", "") @@ -686,7 +684,7 @@ async def process_result( async def on_ws_connect( self, context: Context - ) -> Union[UnsetType, None, dict[str, object]]: + ) -> UnsetType | None | dict[str, object]: return UNSET diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 03421e64f7..4486ecf172 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping -from typing import Any, Generic, Optional, Union +from typing import Any, Generic from typing_extensions import Protocol from lia import HTTPException @@ -15,7 +15,7 @@ class BaseRequestProtocol(Protocol): @property - def query_params(self) -> Mapping[str, Optional[Union[str, list[str]]]]: ... + def query_params(self) -> Mapping[str, str | list[str] | None]: ... @property def method(self) -> HTTPMethod: ... @@ -25,7 +25,7 @@ def headers(self) -> Mapping[str, str]: ... class BaseView(Generic[Request]): - graphql_ide: Optional[GraphQL_IDE] + graphql_ide: GraphQL_IDE | None multipart_uploads_enabled: bool = False schema: BaseSchema @@ -42,13 +42,13 @@ def should_render_graphql_ide(self, request: BaseRequestProtocol) -> bool: def is_request_allowed(self, request: BaseRequestProtocol) -> bool: return request.method in ("GET", "POST") - def parse_json(self, data: Union[str, bytes]) -> Any: + def parse_json(self, data: str | bytes) -> Any: try: return self.decode_json(data) except json.JSONDecodeError as e: raise HTTPException(400, "Unable to parse request body as JSON") from e - def decode_json(self, data: Union[str, bytes]) -> object: + def decode_json(self, data: str | bytes) -> object: return json.loads(data) def encode_json(self, data: object) -> str: diff --git a/strawberry/http/ides.py b/strawberry/http/ides.py index 6fb3b7ab26..16bb8722aa 100644 --- a/strawberry/http/ides.py +++ b/strawberry/http/ides.py @@ -1,11 +1,11 @@ import pathlib -from typing import Literal, Optional +from typing import Literal GraphQL_IDE = Literal["graphiql", "apollo-sandbox", "pathfinder"] def get_graphql_ide_html( - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphql_ide: GraphQL_IDE | None = "graphiql", ) -> str: here = pathlib.Path(__file__).parents[1] diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index 9953719ca1..1d4529ed66 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -4,8 +4,6 @@ from typing import ( Generic, Literal, - Optional, - Union, ) from graphql import GraphQLError @@ -39,8 +37,8 @@ class SyncBaseHTTPView( Generic[Request, Response, SubResponse, Context, RootValue], ): schema: BaseSchema - graphiql: Optional[bool] - graphql_ide: Optional[GraphQL_IDE] + graphiql: bool | None + graphql_ide: GraphQL_IDE | None request_adapter_class: Callable[[Request], SyncHTTPRequestAdapter] # Methods that need to be implemented by individual frameworks @@ -56,12 +54,12 @@ def get_sub_response(self, request: Request) -> SubResponse: ... def get_context(self, request: Request, response: SubResponse) -> Context: ... @abc.abstractmethod - def get_root_value(self, request: Request) -> Optional[RootValue]: ... + def get_root_value(self, request: Request) -> RootValue | None: ... @abc.abstractmethod def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: SubResponse, ) -> Response: ... @@ -72,9 +70,9 @@ def execute_operation( self, request: Request, context: Context, - root_value: Optional[RootValue], + root_value: RootValue | None, sub_response: SubResponse, - ) -> Union[ExecutionResult, list[ExecutionResult]]: + ) -> ExecutionResult | list[ExecutionResult]: request_adapter = self.request_adapter_class(request) try: @@ -119,7 +117,7 @@ def execute_single( request_adapter: SyncHTTPRequestAdapter, sub_response: SubResponse, context: Context, - root_value: Optional[RootValue], + root_value: RootValue | None, request_data: GraphQLRequestData, ) -> ExecutionResult: allowed_operation_types = OperationType.from_http(request_adapter.method) @@ -159,7 +157,7 @@ def parse_multipart(self, request: SyncHTTPRequestAdapter) -> dict[str, str]: def parse_http_body( self, request: SyncHTTPRequestAdapter - ) -> Union[GraphQLRequestData, list[GraphQLRequestData]]: + ) -> GraphQLRequestData | list[GraphQLRequestData]: headers = {key.lower(): value for key, value in request.headers.items()} content_type, params = parse_content_type(request.content_type or "") accept = headers.get("accept", "") @@ -233,7 +231,7 @@ def run( self, request: Request, context: Context = UNSET, - root_value: Optional[RootValue] = UNSET, + root_value: RootValue | None = UNSET, ) -> Response: request_adapter = self.request_adapter_class(request) @@ -260,7 +258,7 @@ def run( sub_response=sub_response, ) - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]] + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse] if isinstance(result, list): response_data = [] diff --git a/strawberry/http/types.py b/strawberry/http/types.py index b49ac2bb3f..427e29f0e4 100644 --- a/strawberry/http/types.py +++ b/strawberry/http/types.py @@ -1,12 +1,12 @@ from collections.abc import Mapping -from typing import Any, Literal, Optional +from typing import Any, Literal from typing_extensions import TypedDict HTTPMethod = Literal[ "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "TRACE" ] -QueryParams = Mapping[str, Optional[str]] +QueryParams = Mapping[str, str | None] class FormData(TypedDict): diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 10bebfbf10..8fc89459cd 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -9,10 +9,9 @@ TYPE_CHECKING, Any, ClassVar, - Optional, + TypeAlias, TypedDict, TypeGuard, - Union, ) from lia import HTTPException, LitestarRequestAdapter @@ -66,9 +65,9 @@ class BaseContext(Struct, kw_only=True): - request: Optional[Request] = None - websocket: Optional[WebSocket] = None - response: Optional[Response] = None + request: Request | None = None + websocket: WebSocket | None = None + response: Response | None = None class HTTPContextType: @@ -93,9 +92,9 @@ class WebSocketContextDict(TypedDict): socket: WebSocket -MergedContext = Union[ - BaseContext, WebSocketContextDict, HTTPContextDict, dict[str, Any] -] +MergedContext: TypeAlias = ( + BaseContext | WebSocketContextDict | HTTPContextDict | dict[str, Any] +) async def _none_custom_context_getter() -> None: @@ -107,7 +106,7 @@ async def _none_root_value_getter() -> None: async def _context_getter_ws( - custom_context: Optional[Any], socket: WebSocket + custom_context: Any | None, socket: WebSocket ) -> MergedContext: if isinstance(custom_context, BaseContext): custom_context.websocket = socket @@ -129,7 +128,7 @@ def _response_getter() -> Response: async def _context_getter_http( - custom_context: Optional[Any], + custom_context: Any | None, response: Response, request: Request[Any, Any, Any], ) -> MergedContext: @@ -150,9 +149,9 @@ async def _context_getter_http( class GraphQLResource(Struct): - data: Optional[dict[str, object]] - errors: Optional[list[object]] - extensions: Optional[dict[str, object]] + data: dict[str, object] | None + errors: list[object] | None + extensions: dict[str, object] | None class LitestarWebSocketAdapter(AsyncWebSocketAdapter): @@ -217,7 +216,7 @@ class GraphQLController( allow_queries_via_get: bool = True graphiql_allowed_accept: frozenset[str] = frozenset({"text/html", "*/*"}) - graphql_ide: Optional[GraphQL_IDE] = "graphiql" + graphql_ide: GraphQL_IDE | None = "graphiql" connection_init_wait_timeout: timedelta = timedelta(minutes=1) protocols: Sequence[str] = ( GRAPHQL_TRANSPORT_WS_PROTOCOL, @@ -227,18 +226,18 @@ class GraphQLController( keep_alive_interval: float = 1 def is_websocket_request( - self, request: Union[Request, WebSocket] + self, request: Request | WebSocket ) -> TypeGuard[WebSocket]: return isinstance(request, WebSocket) - async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]: + async def pick_websocket_subprotocol(self, request: WebSocket) -> str | None: subprotocols = request.scope["subprotocols"] intersection = set(subprotocols) & set(self.protocols) sorted_intersection = sorted(intersection, key=subprotocols.index) return next(iter(sorted_intersection), None) async def create_websocket_response( - self, request: WebSocket, subprotocol: Optional[str] + self, request: WebSocket, subprotocol: str | None ) -> WebSocket: await request.accept(subprotocols=subprotocol) return request @@ -248,7 +247,7 @@ async def execute_request( request: Request[Any, Any, Any], context: Any, root_value: Any, - ) -> Response[Union[GraphQLResource, str]]: + ) -> Response[GraphQLResource | str]: try: return await self.run( request, @@ -269,7 +268,7 @@ async def render_graphql_ide( def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: Response[bytes], ) -> Response[bytes]: response = Response( @@ -310,7 +309,7 @@ async def handle_http_get( context: Any, root_value: Any, response: Response, - ) -> Response[Union[GraphQLResource, str]]: + ) -> Response[GraphQLResource | str]: self.temporal_response = response return await self.execute_request( @@ -326,7 +325,7 @@ async def handle_http_post( context: Any, root_value: Any, response: Response, - ) -> Response[Union[GraphQLResource, str]]: + ) -> Response[GraphQLResource | str]: self.temporal_response = response return await self.execute_request( @@ -350,14 +349,14 @@ async def websocket_endpoint( async def get_context( self, - request: Union[Request[Any, Any, Any], WebSocket], - response: Union[Response, WebSocket], + request: Request[Any, Any, Any] | WebSocket, + response: Response | WebSocket, ) -> Context: # pragma: no cover msg = "`get_context` is not used by Litestar's controller" raise ValueError(msg) async def get_root_value( - self, request: Union[Request[Any, Any, Any], WebSocket] + self, request: Request[Any, Any, Any] | WebSocket ) -> RootValue | None: # pragma: no cover msg = "`get_root_value` is not used by Litestar's controller" raise ValueError(msg) @@ -369,15 +368,15 @@ async def get_sub_response(self, request: Request[Any, Any, Any]) -> Response: def make_graphql_controller( schema: BaseSchema, path: str = "", - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, # TODO: root typevar - root_value_getter: Optional[AnyCallable] = None, + root_value_getter: AnyCallable | None = None, # TODO: context typevar - context_getter: Optional[AnyCallable] = None, + context_getter: AnyCallable | None = None, subscription_protocols: Sequence[str] = ( GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL, @@ -397,7 +396,7 @@ def make_graphql_controller( schema_: BaseSchema = schema allow_queries_via_get_: bool = allow_queries_via_get - graphql_ide_: Optional[GraphQL_IDE] + graphql_ide_: GraphQL_IDE | None if graphiql is not None: warnings.warn( diff --git a/strawberry/permission.py b/strawberry/permission.py index 70622d94e4..7d10a023f5 100644 --- a/strawberry/permission.py +++ b/strawberry/permission.py @@ -7,8 +7,6 @@ from typing import ( TYPE_CHECKING, Any, - Optional, - Union, ) from strawberry.exceptions import StrawberryGraphQLError @@ -50,18 +48,18 @@ def has_permission(self, source, info, **kwargs): ``` """ - message: Optional[str] = None + message: str | None = None - error_extensions: Optional[GraphQLErrorExtensions] = None + error_extensions: GraphQLErrorExtensions | None = None error_class: type[GraphQLError] = StrawberryGraphQLError - _schema_directive: Optional[object] = None + _schema_directive: object | None = None @abc.abstractmethod def has_permission( self, source: Any, info: Info, **kwargs: Any - ) -> Union[bool, Awaitable[bool]]: + ) -> bool | Awaitable[bool]: """Check if the permission should be accepted. This method should be overridden by the subclasses. diff --git a/strawberry/printer/ast_from_value.py b/strawberry/printer/ast_from_value.py index d4e23294ca..12e56c0c65 100644 --- a/strawberry/printer/ast_from_value.py +++ b/strawberry/printer/ast_from_value.py @@ -3,7 +3,7 @@ import re from collections.abc import Mapping from math import isfinite -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from graphql.language import ( BooleanValueNode, @@ -44,9 +44,7 @@ _re_integer_string = re.compile("^-?(?:0|[1-9][0-9]*)$") -def ast_from_leaf_type( - serialized: object, type_: Optional[GraphQLInputType] -) -> ValueNode: +def ast_from_leaf_type(serialized: object, type_: GraphQLInputType | None) -> ValueNode: # Others serialize based on their corresponding Python scalar types. if isinstance(serialized, bool): return BooleanValueNode(value=serialized) @@ -86,7 +84,7 @@ def ast_from_leaf_type( ) # pragma: no cover -def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: +def ast_from_value(value: Any, type_: GraphQLInputType) -> ValueNode | None: # custom ast_from_value that allows to also serialize custom scalar that aren't # basic types, namely JSON scalar types diff --git a/strawberry/printer/printer.py b/strawberry/printer/printer.py index 3a520a036a..74152b8cd8 100644 --- a/strawberry/printer/printer.py +++ b/strawberry/printer/printer.py @@ -5,9 +5,7 @@ from typing import ( TYPE_CHECKING, Any, - Optional, TypeVar, - Union, cast, overload, ) @@ -79,7 +77,7 @@ def _serialize_dataclasses( @overload def _serialize_dataclasses( - value: Union[list[object], tuple[object]], + value: list[object] | tuple[object], *, name_converter: Callable[[str], str] | None = None, ) -> list[object]: ... @@ -191,7 +189,7 @@ def print_schema_directive( def print_field_directives( - field: Optional[StrawberryField], schema: BaseSchema, *, extras: PrintExtras + field: StrawberryField | None, schema: BaseSchema, *, extras: PrintExtras ) -> str: if not field: return "" @@ -364,7 +362,7 @@ def print_extends(type_: GraphQLObjectType, schema: BaseSchema) -> str: from strawberry.schema.schema_converter import GraphQLCoreConverter strawberry_type = cast( - "Optional[StrawberryObjectDefinition]", + "StrawberryObjectDefinition | None", type_.extensions and type_.extensions.get(GraphQLCoreConverter.DEFINITION_BACKREF), ) @@ -381,7 +379,7 @@ def print_type_directives( from strawberry.schema.schema_converter import GraphQLCoreConverter strawberry_type = cast( - "Optional[StrawberryObjectDefinition]", + "StrawberryObjectDefinition | None", type_.extensions and type_.extensions.get(GraphQLCoreConverter.DEFINITION_BACKREF), ) @@ -535,9 +533,7 @@ def _all_root_names_are_common_names(schema: BaseSchema) -> bool: ) -def print_schema_definition( - schema: BaseSchema, *, extras: PrintExtras -) -> Optional[str]: +def print_schema_definition(schema: BaseSchema, *, extras: PrintExtras) -> str | None: # TODO: add support for description if _all_root_names_are_common_names(schema) and not schema.schema_directives: @@ -559,9 +555,7 @@ def print_schema_definition( return f"schema{directives} {{\n" + "\n".join(operation_types) + "\n}" -def print_directive( - directive: GraphQLDirective, *, schema: BaseSchema -) -> Optional[str]: +def print_directive(directive: GraphQLDirective, *, schema: BaseSchema) -> str | None: strawberry_directive = directive.extensions.get("strawberry-definition") if strawberry_directive is None or ( diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 0450392268..8b898f055e 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -3,7 +3,7 @@ from collections.abc import AsyncGenerator, Callable, Mapping, Sequence from datetime import timedelta from json.decoder import JSONDecodeError -from typing import TYPE_CHECKING, ClassVar, Optional, TypeGuard, Union +from typing import TYPE_CHECKING, ClassVar, TypeGuard, Union from lia import HTTPException, QuartHTTPRequestAdapter @@ -82,8 +82,8 @@ class GraphQLView( def __init__( self, schema: "BaseSchema", - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = True, keep_alive_interval: float = 1, @@ -125,13 +125,11 @@ def create_response( return sub_response async def get_context( - self, request: Union[Request, Websocket], response: Response + self, request: Request | Websocket, response: Response ) -> Context: return {"request": request, "response": response} # type: ignore - async def get_root_value( - self, request: Union[Request, Websocket] - ) -> Optional[RootValue]: + async def get_root_value(self, request: Request | Websocket) -> RootValue | None: return None async def get_sub_response(self, request: Request) -> Response: @@ -165,18 +163,18 @@ async def create_streaming_response( ) def is_websocket_request( - self, request: Union[Request, Websocket] + self, request: Request | Websocket ) -> TypeGuard[Websocket]: return has_websocket_context() - async def pick_websocket_subprotocol(self, request: Websocket) -> Optional[str]: + async def pick_websocket_subprotocol(self, request: Websocket) -> str | None: protocols = request.requested_subprotocols intersection = set(protocols) & set(self.subscription_protocols) sorted_intersection = sorted(intersection, key=protocols.index) return next(iter(sorted_intersection), None) async def create_websocket_response( - self, request: Websocket, subprotocol: Optional[str] + self, request: Websocket, subprotocol: str | None ) -> Response: await request.accept(subprotocol=subprotocol) return Response() diff --git a/strawberry/relay/exceptions.py b/strawberry/relay/exceptions.py index f82aa759f6..2168a82954 100644 --- a/strawberry/relay/exceptions.py +++ b/strawberry/relay/exceptions.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast from strawberry.exceptions.exception import StrawberryException from strawberry.exceptions.utils.source_finder import SourceFinder @@ -32,7 +32,7 @@ def __init__(self, message: str, cls: type) -> None: super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.cls is None: return None # pragma: no cover @@ -63,7 +63,7 @@ def __init__(self, field_name: str, cls: type) -> None: super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.cls is None: return None # pragma: no cover @@ -97,7 +97,7 @@ def __init__(self, field_name: str, resolver: StrawberryResolver) -> None: super().__init__(self.message) @cached_property - def exception_source(self) -> Optional[ExceptionSource]: + def exception_source(self) -> ExceptionSource | None: if self.function is None: return None # pragma: no cover diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py index 9450e7709e..2b32a6269e 100644 --- a/strawberry/relay/fields.py +++ b/strawberry/relay/fields.py @@ -20,7 +20,6 @@ Any, ForwardRef, Optional, - Union, cast, get_args, get_origin, @@ -82,14 +81,14 @@ async def resolve_async( def get_node_resolver( self, field: StrawberryField - ) -> Callable[[Info, GlobalID], Union[Node, None, Awaitable[Union[Node, None]]]]: + ) -> Callable[[Info, GlobalID], Node | None | Awaitable[Node | None]]: type_ = field.type is_optional = isinstance(type_, StrawberryOptional) def resolver( info: Info, id: Annotated[GlobalID, argument(description="The ID of the object.")], - ) -> Union[Node, None, Awaitable[Union[Node, None]]]: + ) -> Node | None | Awaitable[Node | None]: node_type = id.resolve_type(info) resolved_node = node_type.resolve_node( id.node_id, @@ -115,7 +114,7 @@ async def resolve() -> Any: def get_node_list_resolver( self, field: StrawberryField - ) -> Callable[[Info, list[GlobalID]], Union[list[Node], Awaitable[list[Node]]]]: + ) -> Callable[[Info, list[GlobalID]], list[Node] | Awaitable[list[Node]]]: type_ = field.type assert isinstance(type_, StrawberryList) is_optional = isinstance(type_.of_type, StrawberryOptional) @@ -125,7 +124,7 @@ def resolver( ids: Annotated[ list[GlobalID], argument(description="The IDs of the objects.") ], - ) -> Union[list[Node], Awaitable[list[Node]]]: + ) -> list[Node] | Awaitable[list[Node]]: nodes_map: defaultdict[type[Node], list[str]] = defaultdict(list) # Store the index of the node in the list of nodes of the same type # so that we can return them in the same order while also supporting @@ -209,7 +208,7 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]: class ConnectionExtension(FieldExtension): connection_type: type[Connection[Node]] - def __init__(self, max_results: Optional[int] = None) -> None: + def __init__(self, max_results: int | None = None) -> None: self.max_results = max_results def apply(self, field: StrawberryField) -> None: @@ -218,7 +217,7 @@ def apply(self, field: StrawberryField) -> None: StrawberryArgument( python_name="before", graphql_name=None, - type_annotation=StrawberryAnnotation(Optional[str]), + type_annotation=StrawberryAnnotation(Optional[str]), # noqa: UP045 description=( "Returns the items in the list that come before the " "specified cursor." @@ -228,7 +227,7 @@ def apply(self, field: StrawberryField) -> None: StrawberryArgument( python_name="after", graphql_name=None, - type_annotation=StrawberryAnnotation(Optional[str]), + type_annotation=StrawberryAnnotation(Optional[str]), # noqa: UP045 description=( "Returns the items in the list that come after the " "specified cursor." @@ -238,14 +237,14 @@ def apply(self, field: StrawberryField) -> None: StrawberryArgument( python_name="first", graphql_name=None, - type_annotation=StrawberryAnnotation(Optional[int]), + type_annotation=StrawberryAnnotation(Optional[int]), # noqa: UP045 description="Returns the first n items from the list.", default=None, ), StrawberryArgument( python_name="last", graphql_name=None, - type_annotation=StrawberryAnnotation(Optional[int]), + type_annotation=StrawberryAnnotation(Optional[int]), # noqa: UP045 description=( "Returns the items in the list that come after the " "specified cursor." @@ -309,10 +308,10 @@ def resolve( source: Any, info: Info, *, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, + before: str | None = None, + after: str | None = None, + first: int | None = None, + last: int | None = None, **kwargs: Any, ) -> Any: assert self.connection_type is not None @@ -332,10 +331,10 @@ async def resolve_async( source: Any, info: Info, *, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, + before: str | None = None, + after: str | None = None, + first: int | None = None, + last: int | None = None, **kwargs: Any, ) -> Any: assert self.connection_type is not None @@ -381,20 +380,20 @@ def node(*args: Any, **kwargs: Any) -> StrawberryField: def connection( - graphql_type: Optional[ConnectionGraphQLType] = None, + graphql_type: ConnectionGraphQLType | None = None, *, - resolver: Optional[_RESOLVER_TYPE[Any]] = None, - name: Optional[str] = None, + resolver: _RESOLVER_TYPE[Any] | None = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), extensions: list[FieldExtension] | None = None, - max_results: Optional[int] = None, + max_results: int | None = None, # This init parameter is used by pyright to determine whether this field # is added in the constructor or not. It is not used to change # any behaviour at the moment. diff --git a/strawberry/relay/types.py b/strawberry/relay/types.py index 95764b5230..f15d51772c 100644 --- a/strawberry/relay/types.py +++ b/strawberry/relay/types.py @@ -20,7 +20,6 @@ ForwardRef, Generic, Literal, - Optional, TypeAlias, TypeVar, Union, @@ -60,12 +59,9 @@ _T = TypeVar("_T") -NodeIterableType: TypeAlias = Union[ - Iterator[_T], - Iterable[_T], - AsyncIterator[_T], - AsyncIterable[_T], -] +NodeIterableType: TypeAlias = ( + Iterator[_T] | Iterable[_T] | AsyncIterator[_T] | AsyncIterable[_T] +) NodeType = TypeVar("NodeType", bound="Node") PREFIX = "arrayconnection" @@ -114,7 +110,7 @@ def __str__(self) -> str: return to_base64(self.type_name, self.node_id) @classmethod - def from_id(cls, value: Union[str, ID]) -> Self: + def from_id(cls, value: str | ID) -> Self: """Create a new GlobalID from parsing the given value. Args: @@ -162,7 +158,7 @@ async def resolve_node( *, required: bool = ..., ensure_type: None = ..., - ) -> Optional[Node]: ... + ) -> Node | None: ... async def resolve_node(self, info, *, required=False, ensure_type=None) -> Any: """Resolve the type name and node id info to the node itself. @@ -272,7 +268,7 @@ def resolve_node_sync( *, required: bool = ..., ensure_type: None = ..., - ) -> Optional[Node]: ... + ) -> Node | None: ... def resolve_node_sync(self, info, *, required=False, ensure_type=None) -> Any: """Resolve the type name and node id info to the node itself. @@ -382,7 +378,7 @@ def resolve_nodes(cls, *, info, node_ids, required=False): ``` """ - _id_attr: ClassVar[Optional[str]] = None + _id_attr: ClassVar[str | None] = None @field(name="id", description="The Globally Unique ID of this object") @classmethod @@ -517,7 +513,7 @@ def resolve_nodes( info: Info, node_ids: Iterable[str], required: Literal[False] = ..., - ) -> AwaitableOrValue[Iterable[Optional[Self]]]: ... + ) -> AwaitableOrValue[Iterable[Self | None]]: ... @overload @classmethod @@ -527,10 +523,7 @@ def resolve_nodes( info: Info, node_ids: Iterable[str], required: bool, - ) -> Union[ - AwaitableOrValue[Iterable[Self]], - AwaitableOrValue[Iterable[Optional[Self]]], - ]: ... + ) -> AwaitableOrValue[Iterable[Self]] | AwaitableOrValue[Iterable[Self | None]]: ... @classmethod def resolve_nodes( @@ -581,7 +574,7 @@ def resolve_node( *, info: Info, required: Literal[False] = ..., - ) -> AwaitableOrValue[Optional[Self]]: ... + ) -> AwaitableOrValue[Self | None]: ... @overload @classmethod @@ -591,7 +584,7 @@ def resolve_node( *, info: Info, required: bool, - ) -> AwaitableOrValue[Optional[Self]]: ... + ) -> AwaitableOrValue[Self | None]: ... @classmethod def resolve_node( @@ -600,7 +593,7 @@ def resolve_node( *, info: Info, required: bool = False, - ) -> AwaitableOrValue[Optional[Self]]: + ) -> AwaitableOrValue[Self | None]: """Resolve a node given its id. This method is a convenience method that calls `resolve_nodes` for @@ -645,10 +638,10 @@ class PageInfo: has_previous_page: bool = field( description="When paginating backwards, are there more items?", ) - start_cursor: Optional[str] = field( + start_cursor: str | None = field( description="When paginating backwards, the cursor to continue.", ) - end_cursor: Optional[str] = field( + end_cursor: str | None = field( description="When paginating forwards, the cursor to continue.", ) @@ -719,11 +712,11 @@ def resolve_connection( nodes: NodeIterableType[NodeType], *, info: Info, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, - max_results: Optional[int] = None, + before: str | None = None, + after: str | None = None, + first: int | None = None, + last: int | None = None, + max_results: int | None = None, **kwargs: Any, ) -> AwaitableOrValue[Self]: """Resolve a connection from nodes. @@ -771,11 +764,11 @@ def resolve_connection( nodes: NodeIterableType[NodeType], *, info: Info, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, - max_results: Optional[int] = None, + before: str | None = None, + after: str | None = None, + first: int | None = None, + last: int | None = None, + max_results: int | None = None, **kwargs: Any, ) -> AwaitableOrValue[Self]: """Resolve a connection from the list of nodes. @@ -824,7 +817,7 @@ def resolve_connection( async def resolver() -> Self: try: iterator = cast( - "Union[AsyncIterator[NodeType], AsyncIterable[NodeType]]", + "AsyncIterator[NodeType] | AsyncIterable[NodeType]", cast("Sequence", nodes)[ slice_metadata.start : slice_metadata.overfetch ], @@ -890,7 +883,7 @@ async def resolver() -> Self: try: iterator = cast( - "Union[Iterator[NodeType], Iterable[NodeType]]", + "Iterator[NodeType] | Iterable[NodeType]", cast("Sequence", nodes)[ slice_metadata.start : slice_metadata.overfetch ], diff --git a/strawberry/relay/utils.py b/strawberry/relay/utils.py index 61b5a2450f..eb7c05ed9d 100644 --- a/strawberry/relay/utils.py +++ b/strawberry/relay/utils.py @@ -3,7 +3,7 @@ import base64 import dataclasses import sys -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from typing_extensions import Self, assert_never from strawberry.types.base import StrawberryObjectDefinition @@ -39,7 +39,7 @@ def from_base64(value: str) -> tuple[str, str]: return res[0], res[1] -def to_base64(type_: Union[str, type, StrawberryObjectDefinition], node_id: Any) -> str: +def to_base64(type_: str | type | StrawberryObjectDefinition, node_id: Any) -> str: """Encode the type name and node id to a base64 string. Args: diff --git a/strawberry/sanic/utils.py b/strawberry/sanic/utils.py index b3afb72f56..07d2a62e25 100644 --- a/strawberry/sanic/utils.py +++ b/strawberry/sanic/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from sanic.request import File, Request @@ -21,12 +21,12 @@ def convert_request_to_files_dict(request: Request) -> dict[str, Any]: Note that the dictionary entries are lists. """ - request_files = cast("Optional[dict[str, list[File]]]", request.files) + request_files = cast("dict[str, list[File]] | None", request.files) if not request_files: return {} - files_dict: dict[str, Union[File, list[File]]] = {} + files_dict: dict[str, File | list[File]] = {} for field_name, file_list in request_files.items(): assert len(file_list) == 1 diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index b6728b5d04..ef43363f26 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -5,9 +5,7 @@ from typing import ( TYPE_CHECKING, Any, - Optional, TypeGuard, - Union, ) from lia import HTTPException, SanicHTTPRequestAdapter @@ -65,11 +63,11 @@ class GraphQLView( def __init__( self, schema: BaseSchema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, - json_encoder: Optional[type[json.JSONEncoder]] = None, - json_dumps_params: Optional[dict[str, Any]] = None, + json_encoder: type[json.JSONEncoder] | None = None, + json_dumps_params: dict[str, Any] | None = None, multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema @@ -104,7 +102,7 @@ def __init__( else: self.graphql_ide = graphql_ide - async def get_root_value(self, request: Request) -> Optional[RootValue]: + async def get_root_value(self, request: Request) -> RootValue | None: return None async def get_context( @@ -120,7 +118,7 @@ async def get_sub_response(self, request: Request) -> TemporalResponse: def create_response( self, - response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse], sub_response: TemporalResponse, ) -> HTTPResponse: status_code = sub_response.status_code @@ -179,11 +177,11 @@ async def create_streaming_response( def is_websocket_request(self, request: Request) -> TypeGuard[Request]: return False - async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: + async def pick_websocket_subprotocol(self, request: Request) -> str | None: raise NotImplementedError async def create_websocket_response( - self, request: Request, subprotocol: Optional[str] + self, request: Request, subprotocol: str | None ) -> TemporalResponse: raise NotImplementedError diff --git a/strawberry/scalars.py b/strawberry/scalars.py index 4639e5ac59..13dc93771c 100644 --- a/strawberry/scalars.py +++ b/strawberry/scalars.py @@ -1,7 +1,7 @@ from __future__ import annotations import base64 -from typing import TYPE_CHECKING, Any, NewType, Union +from typing import TYPE_CHECKING, Any, NewType from strawberry.types.scalar import scalar @@ -59,7 +59,7 @@ def is_scalar( annotation: Any, - scalar_registry: Mapping[object, Union[ScalarWrapper, ScalarDefinition]], + scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], ) -> bool: if annotation in scalar_registry: return True diff --git a/strawberry/schema/_graphql_core.py b/strawberry/schema/_graphql_core.py index 5d8b921982..a2e9bd2d51 100644 --- a/strawberry/schema/_graphql_core.py +++ b/strawberry/schema/_graphql_core.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import TypeAlias, Union from graphql.execution import ExecutionContext as GraphQLExecutionContext from graphql.execution import ExecutionResult as OriginalGraphQLExecutionResult @@ -24,9 +24,9 @@ GraphQLStreamDirective, ) - GraphQLExecutionResult = Union[ - OriginalGraphQLExecutionResult, InitialIncrementalExecutionResult - ] + GraphQLExecutionResult: TypeAlias = ( + OriginalGraphQLExecutionResult | InitialIncrementalExecutionResult + ) except ImportError: GraphQLIncrementalExecutionResults = type(None) @@ -37,7 +37,7 @@ # TODO: give this a better name, maybe also a better place -ResultType = Union[ +ResultType = Union[ # noqa: UP007 OriginalGraphQLExecutionResult, GraphQLIncrementalExecutionResults, ExecutionResult, diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index ea9a44f5ee..89d0d6c18f 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from typing_extensions import Protocol from strawberry.utils.logging import StrawberryLogger @@ -34,33 +34,33 @@ class BaseSchema(Protocol): config: StrawberryConfig schema_converter: GraphQLCoreConverter query: type[WithStrawberryObjectDefinition] - mutation: Optional[type[WithStrawberryObjectDefinition]] - subscription: Optional[type[WithStrawberryObjectDefinition]] + mutation: type[WithStrawberryObjectDefinition] | None + subscription: type[WithStrawberryObjectDefinition] | None schema_directives: list[object] @abstractmethod async def execute( self, - query: Optional[str], - variable_values: Optional[dict[str, Any]] = None, - context_value: Optional[Any] = None, - root_value: Optional[Any] = None, - operation_name: Optional[str] = None, - allowed_operation_types: Optional[Iterable[OperationType]] = None, - operation_extensions: Optional[dict[str, Any]] = None, + query: str | None, + variable_values: dict[str, Any] | None = None, + context_value: Any | None = None, + root_value: Any | None = None, + operation_name: str | None = None, + allowed_operation_types: Iterable[OperationType] | None = None, + operation_extensions: dict[str, Any] | None = None, ) -> ExecutionResult: raise NotImplementedError @abstractmethod def execute_sync( self, - query: Optional[str], - variable_values: Optional[dict[str, Any]] = None, - context_value: Optional[Any] = None, - root_value: Optional[Any] = None, - operation_name: Optional[str] = None, - allowed_operation_types: Optional[Iterable[OperationType]] = None, - operation_extensions: Optional[dict[str, Any]] = None, + query: str | None, + variable_values: dict[str, Any] | None = None, + context_value: Any | None = None, + root_value: Any | None = None, + operation_name: str | None = None, + allowed_operation_types: Iterable[OperationType] | None = None, + operation_extensions: dict[str, Any] | None = None, ) -> ExecutionResult: raise NotImplementedError @@ -68,29 +68,28 @@ def execute_sync( async def subscribe( self, query: str, - variable_values: Optional[dict[str, Any]] = None, - context_value: Optional[Any] = None, - root_value: Optional[Any] = None, - operation_name: Optional[str] = None, - operation_extensions: Optional[dict[str, Any]] = None, + variable_values: dict[str, Any] | None = None, + context_value: Any | None = None, + root_value: Any | None = None, + operation_name: str | None = None, + operation_extensions: dict[str, Any] | None = None, ) -> SubscriptionResult: raise NotImplementedError @abstractmethod def get_type_by_name( self, name: str - ) -> Optional[ - Union[ - StrawberryObjectDefinition, - ScalarDefinition, - EnumDefinition, - StrawberryUnion, - ] - ]: + ) -> ( + StrawberryObjectDefinition + | ScalarDefinition + | EnumDefinition + | StrawberryUnion + | None + ): raise NotImplementedError @abstractmethod - def get_directive_by_name(self, graphql_name: str) -> Optional[StrawberryDirective]: + def get_directive_by_name(self, graphql_name: str) -> StrawberryDirective | None: raise NotImplementedError @abstractmethod @@ -108,7 +107,7 @@ def remove_field_suggestion(error: GraphQLError) -> None: def _process_errors( self, errors: list[GraphQLError], - execution_context: Optional[ExecutionContext] = None, + execution_context: ExecutionContext | None = None, ) -> None: if self.config.disable_field_suggestions: for error in errors: @@ -119,7 +118,7 @@ def _process_errors( def process_errors( self, errors: list[GraphQLError], - execution_context: Optional[ExecutionContext] = None, + execution_context: ExecutionContext | None = None, ) -> None: for error in errors: StrawberryLogger.error(error, execution_context) diff --git a/strawberry/schema/compat.py b/strawberry/schema/compat.py index 07bdb03e28..257ebe06b6 100644 --- a/strawberry/schema/compat.py +++ b/strawberry/schema/compat.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from strawberry.scalars import is_scalar as is_strawberry_scalar from strawberry.types.base import StrawberryType, has_object_definition @@ -16,35 +16,35 @@ from strawberry.types.scalar import ScalarDefinition, ScalarWrapper -def is_input_type(type_: Union[StrawberryType, type]) -> TypeGuard[type]: +def is_input_type(type_: StrawberryType | type) -> TypeGuard[type]: if not has_object_definition(type_): return False return type_.__strawberry_definition__.is_input -def is_interface_type(type_: Union[StrawberryType, type]) -> TypeGuard[type]: +def is_interface_type(type_: StrawberryType | type) -> TypeGuard[type]: if not has_object_definition(type_): return False return type_.__strawberry_definition__.is_interface def is_scalar( - type_: Union[StrawberryType, type], - scalar_registry: Mapping[object, Union[ScalarWrapper, ScalarDefinition]], + type_: StrawberryType | type, + scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], ) -> TypeGuard[type]: return is_strawberry_scalar(type_, scalar_registry) -def is_enum(type_: Union[StrawberryType, type]) -> TypeGuard[type]: +def is_enum(type_: StrawberryType | type) -> TypeGuard[type]: return hasattr(type_, "_enum_definition") -def is_schema_directive(type_: Union[StrawberryType, type]) -> TypeGuard[type]: +def is_schema_directive(type_: StrawberryType | type) -> TypeGuard[type]: return hasattr(type_, "__strawberry_directive__") # TODO: do we still need this? -def is_graphql_generic(type_: Union[StrawberryType, type]) -> bool: +def is_graphql_generic(type_: StrawberryType | type) -> bool: if has_object_definition(type_): return type_.__strawberry_definition__.is_graphql_generic diff --git a/strawberry/schema/config.py b/strawberry/schema/config.py index eb5c758f3c..e484b0988e 100644 --- a/strawberry/schema/config.py +++ b/strawberry/schema/config.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import InitVar, dataclass, field -from typing import TYPE_CHECKING, Any, Optional, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict from strawberry.types.info import Info @@ -26,7 +26,7 @@ class StrawberryConfig: info_class: type[Info] = Info enable_experimental_incremental_execution: bool = False _unsafe_disable_same_type_validation: bool = False - batching_config: Optional[BatchingConfig] = None + batching_config: BatchingConfig | None = None def __post_init__( self, diff --git a/strawberry/schema/exceptions.py b/strawberry/schema/exceptions.py index 5331cbae41..cd860a1a2b 100644 --- a/strawberry/schema/exceptions.py +++ b/strawberry/schema/exceptions.py @@ -1,12 +1,10 @@ -from typing import Optional - from strawberry.types.graphql import OperationType class CannotGetOperationTypeError(Exception): """Internal error raised when we cannot get the operation type from a GraphQL document.""" - def __init__(self, operation_name: Optional[str]) -> None: + def __init__(self, operation_name: str | None) -> None: self.operation_name = operation_name def as_http_error_reason(self) -> str: diff --git a/strawberry/schema/name_converter.py b/strawberry/schema/name_converter.py index afba5e1c7e..599fc2bf90 100644 --- a/strawberry/schema/name_converter.py +++ b/strawberry/schema/name_converter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, cast from typing_extensions import Protocol from strawberry.directive import StrawberryDirective @@ -26,7 +26,7 @@ class HasGraphQLName(Protocol): python_name: str - graphql_name: Optional[str] + graphql_name: str | None class NameConverter: @@ -41,7 +41,7 @@ def apply_naming_config(self, name: str) -> str: def from_type( self, - type_: Union[StrawberryType, StrawberryDirective], + type_: StrawberryType | StrawberryDirective, ) -> str: if isinstance(type_, (StrawberryDirective, StrawberrySchemaDirective)): return self.from_directive(type_) @@ -85,7 +85,7 @@ def from_enum_value(self, enum: EnumDefinition, enum_value: EnumValue) -> str: return enum_value.name def from_directive( - self, directive: Union[StrawberryDirective, StrawberrySchemaDirective] + self, directive: StrawberryDirective | StrawberrySchemaDirective ) -> str: name = self.get_graphql_name(directive) @@ -134,7 +134,7 @@ def from_union(self, union: StrawberryUnion) -> str: def from_generic( self, generic_type: StrawberryObjectDefinition, - types: list[Union[StrawberryType, type]], + types: list[StrawberryType | type], ) -> str: generic_type_name = generic_type.name @@ -146,7 +146,7 @@ def from_generic( return "".join(names) + generic_type_name - def get_name_from_type(self, type_: Union[StrawberryType, type]) -> str: + def get_name_from_type(self, type_: StrawberryType | type) -> str: type_ = eval_type(type_) if isinstance(type_, LazyType): diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 4c5a66f225..48259566f4 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -9,8 +9,6 @@ TYPE_CHECKING, Any, NamedTuple, - Optional, - Union, cast, ) @@ -97,13 +95,12 @@ from strawberry.types.union import StrawberryUnion SubscriptionResult: TypeAlias = AsyncGenerator[ - Union[PreExecutionError, ExecutionResult], None + PreExecutionError | ExecutionResult, None ] -OriginSubscriptionResult = Union[ - OriginalExecutionResult, - AsyncIterator[OriginalExecutionResult], -] +OriginSubscriptionResult: TypeAlias = ( + OriginalExecutionResult | AsyncIterator[OriginalExecutionResult] +) DEFAULT_ALLOWED_OPERATION_TYPES = { @@ -112,7 +109,7 @@ OperationType.SUBSCRIPTION, } ProcessErrors: TypeAlias = ( - "Callable[[list[GraphQLError], Optional[ExecutionContext]], None]" + "Callable[[list[GraphQLError], ExecutionContext | None], None]" ) @@ -149,7 +146,7 @@ def _run_validation(execution_context: ExecutionContext) -> None: ) -def _coerce_error(error: Union[GraphQLError, Exception]) -> GraphQLError: +def _coerce_error(error: GraphQLError | Exception) -> GraphQLError: if isinstance(error, GraphQLError): return error return GraphQLError(str(error), original_error=error) @@ -211,16 +208,16 @@ def __init__( # TODO: can we make sure we only allow to pass # something that has been decorated? query: type, - mutation: Optional[type] = None, - subscription: Optional[type] = None, + mutation: type | None = None, + subscription: type | None = None, directives: Iterable[StrawberryDirective] = (), - types: Iterable[Union[type, StrawberryType]] = (), - extensions: Iterable[Union[type[SchemaExtension], SchemaExtension]] = (), - execution_context_class: Optional[type[GraphQLExecutionContext]] = None, - config: Optional[StrawberryConfig] = None, - scalar_overrides: Optional[ - Mapping[object, Union[type, ScalarWrapper, ScalarDefinition]], - ] = None, + types: Iterable[type | StrawberryType] = (), + extensions: Iterable[type[SchemaExtension] | SchemaExtension] = (), + execution_context_class: type[GraphQLExecutionContext] | None = None, + config: StrawberryConfig | None = None, + scalar_overrides: ( + Mapping[object, type | ScalarWrapper | ScalarDefinition] | None + ) = None, schema_directives: Iterable[object] = (), ) -> None: """Default Schema to be used in a Strawberry application. @@ -397,7 +394,7 @@ def create_extensions_runner( ) def _get_custom_context_kwargs( - self, operation_extensions: Optional[dict[str, Any]] = None + self, operation_extensions: dict[str, Any] | None = None ) -> dict[str, Any]: if not IS_GQL_33: return {} @@ -416,13 +413,13 @@ def _get_middleware_manager( def _create_execution_context( self, - query: Optional[str], + query: str | None, allowed_operation_types: Iterable[OperationType], - variable_values: Optional[dict[str, Any]] = None, - context_value: Optional[Any] = None, - root_value: Optional[Any] = None, - operation_name: Optional[str] = None, - operation_extensions: Optional[dict[str, Any]] = None, + variable_values: dict[str, Any] | None = None, + context_value: Any | None = None, + root_value: Any | None = None, + operation_name: str | None = None, + operation_extensions: dict[str, Any] | None = None, ) -> ExecutionContext: return ExecutionContext( query=query, @@ -438,14 +435,13 @@ def _create_execution_context( @lru_cache def get_type_by_name( self, name: str - ) -> Optional[ - Union[ - StrawberryObjectDefinition, - ScalarDefinition, - EnumDefinition, - StrawberryUnion, - ] - ]: + ) -> ( + StrawberryObjectDefinition + | ScalarDefinition + | EnumDefinition + | StrawberryUnion + | None + ): # TODO: respect auto_camel_case if name in self.schema_converter.type_map: return self.schema_converter.type_map[name].definition @@ -454,7 +450,7 @@ def get_type_by_name( def get_field_for_type( self, field_name: str, type_name: str - ) -> Optional[StrawberryField]: + ) -> StrawberryField | None: type_ = self.get_type_by_name(type_name) if not type_: @@ -472,7 +468,7 @@ def get_field_for_type( ) @lru_cache - def get_directive_by_name(self, graphql_name: str) -> Optional[StrawberryDirective]: + def get_directive_by_name(self, graphql_name: str) -> StrawberryDirective | None: return next( ( directive @@ -489,7 +485,7 @@ def get_fields( async def _parse_and_validate_async( self, context: ExecutionContext, extensions_runner: SchemaExtensionsRunner - ) -> Optional[PreExecutionError]: + ) -> PreExecutionError | None: if not context.query: raise MissingQueryError @@ -552,13 +548,13 @@ async def _handle_execution_result( async def execute( self, - query: Optional[str], - variable_values: Optional[dict[str, Any]] = None, - context_value: Optional[Any] = None, - root_value: Optional[Any] = None, - operation_name: Optional[str] = None, - allowed_operation_types: Optional[Iterable[OperationType]] = None, - operation_extensions: Optional[dict[str, Any]] = None, + query: str | None, + variable_values: dict[str, Any] | None = None, + context_value: Any | None = None, + root_value: Any | None = None, + operation_name: str | None = None, + allowed_operation_types: Iterable[OperationType] | None = None, + operation_extensions: dict[str, Any] | None = None, ) -> ExecutionResult: if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES @@ -658,13 +654,13 @@ async def execute( def execute_sync( self, - query: Optional[str], - variable_values: Optional[dict[str, Any]] = None, - context_value: Optional[Any] = None, - root_value: Optional[Any] = None, - operation_name: Optional[str] = None, - allowed_operation_types: Optional[Iterable[OperationType]] = None, - operation_extensions: Optional[dict[str, Any]] = None, + query: str | None, + variable_values: dict[str, Any] | None = None, + context_value: Any | None = None, + root_value: Any | None = None, + operation_name: str | None = None, + allowed_operation_types: Iterable[OperationType] | None = None, + operation_extensions: dict[str, Any] | None = None, ) -> ExecutionResult: if allowed_operation_types is None: allowed_operation_types = DEFAULT_ALLOWED_OPERATION_TYPES @@ -804,7 +800,7 @@ async def _subscribe( extensions_runner: SchemaExtensionsRunner, middleware_manager: MiddlewareManager, execution_context_class: type[GraphQLExecutionContext] | None = None, - operation_extensions: Optional[dict[str, Any]] = None, + operation_extensions: dict[str, Any] | None = None, ) -> AsyncGenerator[ExecutionResult, None]: async with extensions_runner.operation(): if initial_error := await self._parse_and_validate_async( @@ -883,12 +879,12 @@ async def _subscribe( async def subscribe( self, - query: Optional[str], - variable_values: Optional[dict[str, Any]] = None, - context_value: Optional[Any] = None, - root_value: Optional[Any] = None, - operation_name: Optional[str] = None, - operation_extensions: Optional[dict[str, Any]] = None, + query: str | None, + variable_values: dict[str, Any] | None = None, + context_value: Any | None = None, + root_value: Any | None = None, + operation_name: str | None = None, + operation_extensions: dict[str, Any] | None = None, ) -> SubscriptionResult: execution_context = self._create_execution_context( query=query, diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index df8b0c99bd..6ce41c450c 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -9,9 +9,7 @@ Annotated, Any, Generic, - Optional, TypeVar, - Union, cast, ) from typing_extensions import Protocol @@ -101,7 +99,7 @@ FieldType = TypeVar( "FieldType", - bound=Union[GraphQLField, GraphQLInputField], + bound=GraphQLField | GraphQLInputField, covariant=True, ) @@ -111,7 +109,7 @@ def __call__( # pragma: no cover self, field: StrawberryField, *, - type_definition: Optional[StrawberryObjectDefinition] = None, + type_definition: StrawberryObjectDefinition | None = None, ) -> FieldType: ... @@ -181,7 +179,7 @@ def parse_value(self, input_value: str) -> Any: return self.wrapped_cls(super().parse_value(input_value)) def parse_literal( - self, value_node: ValueNode, _variables: Optional[dict[str, Any]] = None + self, value_node: ValueNode, _variables: dict[str, Any] | None = None ) -> Any: return self.wrapped_cls(super().parse_literal(value_node, _variables)) @@ -193,7 +191,7 @@ def get_arguments( info: Info, kwargs: Any, config: StrawberryConfig, - scalar_registry: Mapping[object, Union[ScalarWrapper, ScalarDefinition]], + scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], ) -> tuple[list[Any], dict[str, Any]]: # TODO: An extension might have changed the resolver arguments, # but we need them here since we are calling it. @@ -250,7 +248,7 @@ class GraphQLCoreConverter: def __init__( self, config: StrawberryConfig, - scalar_overrides: Mapping[object, Union[ScalarWrapper, ScalarDefinition]], + scalar_overrides: Mapping[object, ScalarWrapper | ScalarDefinition], get_fields: Callable[[StrawberryObjectDefinition], list[StrawberryField]], ) -> None: self.type_map: dict[str, ConcreteType] = {} @@ -260,8 +258,8 @@ def __init__( def _get_scalar_registry( self, - scalar_overrides: Mapping[object, Union[ScalarWrapper, ScalarDefinition]], - ) -> Mapping[object, Union[ScalarWrapper, ScalarDefinition]]: + scalar_overrides: Mapping[object, ScalarWrapper | ScalarDefinition], + ) -> Mapping[object, ScalarWrapper | ScalarDefinition]: scalar_registry = {**DEFAULT_SCALAR_REGISTRY} global_id_name = "GlobalID" if self.config.relay_use_legacy_global_id else "ID" @@ -408,7 +406,7 @@ def from_field( self, field: StrawberryField, *, - type_definition: Optional[StrawberryObjectDefinition] = None, + type_definition: StrawberryObjectDefinition | None = None, ) -> GraphQLField: # self.from_resolver needs to be called before accessing field.type because # in there a field extension might want to change the type during its apply @@ -446,7 +444,7 @@ def from_input_field( self, field: StrawberryField, *, - type_definition: Optional[StrawberryObjectDefinition] = None, + type_definition: StrawberryObjectDefinition | None = None, ) -> GraphQLInputField: field_type = cast( "GraphQLInputType", @@ -557,14 +555,14 @@ def from_interface( def _get_resolve_type() -> Callable[ [Any, GraphQLResolveInfo, GraphQLAbstractType], - Union[Awaitable[Optional[str]], str, None], + Awaitable[str | None] | str | None, ]: if interface.resolve_type: return interface.resolve_type def resolve_type( obj: Any, info: GraphQLResolveInfo, abstract_type: GraphQLAbstractType - ) -> Union[Awaitable[Optional[str]], str, None]: + ) -> Awaitable[str | None] | str | None: if isinstance(obj, interface.origin): type_definition = get_object_definition(obj, strict=True) @@ -579,7 +577,7 @@ def resolve_type( # all the types in the schema, but we should probably # optimize this - return_type: Optional[GraphQLType] = None + return_type: GraphQLType | None = None for possible_concrete_type in self.type_map.values(): possible_type = possible_concrete_type.definition @@ -640,7 +638,7 @@ def from_object(self, object_type: StrawberryObjectDefinition) -> GraphQLObjectT assert isinstance(graphql_object_type, GraphQLObjectType) # For mypy return graphql_object_type - def _get_is_type_of() -> Optional[Callable[[Any, GraphQLResolveInfo], bool]]: + def _get_is_type_of() -> Callable[[Any, GraphQLResolveInfo], bool] | None: if object_type.is_type_of: return object_type.is_type_of @@ -850,8 +848,8 @@ def from_scalar(self, scalar: type) -> GraphQLScalarType: return implementation def from_maybe_optional( - self, type_: Union[StrawberryType, type] - ) -> Union[GraphQLNullableType, GraphQLNonNull]: + self, type_: StrawberryType | type + ) -> GraphQLNullableType | GraphQLNonNull: NoneType = type(None) if type_ is None or type_ is NoneType: return self.from_type(type_) @@ -866,7 +864,7 @@ def from_maybe_optional( return self.from_type(type_.of_type) return GraphQLNonNull(self.from_type(type_)) - def from_type(self, type_: Union[StrawberryType, type]) -> GraphQLNullableType: + def from_type(self, type_: StrawberryType | type) -> GraphQLNullableType: if compat.is_graphql_generic(type_): raise MissingTypesForGenericError(type_) @@ -959,7 +957,7 @@ def from_union(self, union: StrawberryUnion) -> GraphQLUnionType: def _get_is_type_of( self, object_type: StrawberryObjectDefinition, - ) -> Optional[Callable[[Any, GraphQLResolveInfo], bool]]: + ) -> Callable[[Any, GraphQLResolveInfo], bool] | None: if object_type.is_type_of: return object_type.is_type_of diff --git a/strawberry/schema/types/concrete_type.py b/strawberry/schema/types/concrete_type.py index 6a421c26a3..4c8d5857f4 100644 --- a/strawberry/schema/types/concrete_type.py +++ b/strawberry/schema/types/concrete_type.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, TypeAlias from graphql import GraphQLField, GraphQLInputField, GraphQLType @@ -11,14 +11,14 @@ from strawberry.types.scalar import ScalarDefinition from strawberry.types.union import StrawberryUnion -Field = Union[GraphQLInputField, GraphQLField] +Field: TypeAlias = GraphQLInputField | GraphQLField @dataclasses.dataclass class ConcreteType: - definition: Union[ - StrawberryObjectDefinition, EnumDefinition, ScalarDefinition, StrawberryUnion - ] + definition: ( + StrawberryObjectDefinition | EnumDefinition | ScalarDefinition | StrawberryUnion + ) implementation: GraphQLType diff --git a/strawberry/schema_codegen/__init__.py b/strawberry/schema_codegen/__init__.py index 01eab3d533..964d40c216 100644 --- a/strawberry/schema_codegen/__init__.py +++ b/strawberry/schema_codegen/__init__.py @@ -4,7 +4,7 @@ import keyword from collections import defaultdict from graphlib import TopologicalSorter -from typing import TYPE_CHECKING, TypeAlias, Union +from typing import TYPE_CHECKING, TypeAlias from typing_extensions import Protocol import libcst as cst @@ -256,7 +256,7 @@ def _get_field( ) -ArgumentValue: TypeAlias = Union[str, bool, list["ArgumentValue"]] +ArgumentValue: TypeAlias = str | bool | list["ArgumentValue"] def _get_argument_value(argument_value: ConstValueNode) -> ArgumentValue: diff --git a/strawberry/schema_directive.py b/strawberry/schema_directive.py index a3db4f2776..03514b9d8c 100644 --- a/strawberry/schema_directive.py +++ b/strawberry/schema_directive.py @@ -1,7 +1,7 @@ import dataclasses from collections.abc import Callable from enum import Enum -from typing import Optional, TypeVar +from typing import TypeVar from typing_extensions import dataclass_transform from strawberry.types.field import StrawberryField, field @@ -28,13 +28,13 @@ class Location(Enum): @dataclasses.dataclass class StrawberrySchemaDirective: python_name: str - graphql_name: Optional[str] + graphql_name: str | None locations: list[Location] fields: list["StrawberryField"] - description: Optional[str] = None + description: str | None = None repeatable: bool = False print_definition: bool = True - origin: Optional[type] = None + origin: type | None = None T = TypeVar("T", bound=type) @@ -48,8 +48,8 @@ class StrawberrySchemaDirective: def schema_directive( *, locations: list[Location], - description: Optional[str] = None, - name: Optional[str] = None, + description: str | None = None, + name: str | None = None, repeatable: bool = False, print_definition: bool = True, ) -> Callable[[T], T]: diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 5c10cb319c..402577a5b9 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -7,7 +7,6 @@ TYPE_CHECKING, Any, Generic, - Optional, cast, ) @@ -52,7 +51,7 @@ def __init__( view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue], websocket: AsyncWebSocketAdapter, context: Context, - root_value: Optional[RootValue], + root_value: RootValue | None, schema: BaseSchema, connection_init_wait_timeout: timedelta, ) -> None: @@ -62,7 +61,7 @@ def __init__( self.root_value = root_value self.schema = schema self.connection_init_wait_timeout = connection_init_wait_timeout - self.connection_init_timeout_task: Optional[asyncio.Task] = None + self.connection_init_timeout_task: asyncio.Task | None = None self.connection_init_received = False self.connection_acknowledged = False self.connection_timed_out = False @@ -361,8 +360,8 @@ def __init__( id: str, operation_type: OperationType, query: str, - variables: Optional[dict[str, object]], - operation_name: Optional[str], + variables: dict[str, object] | None, + operation_name: str | None, ) -> None: self.handler = handler self.id = id @@ -371,7 +370,7 @@ def __init__( self.variables = variables self.operation_name = operation_name self.completed = False - self.task: Optional[asyncio.Task] = None + self.task: asyncio.Task | None = None async def send_operation_message(self, message: Message) -> None: if self.completed: diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/types.py b/strawberry/subscriptions/protocols/graphql_transport_ws/types.py index d28867cdd1..ea052437ba 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/types.py @@ -1,4 +1,4 @@ -from typing import Literal, TypedDict, Union +from typing import Literal, TypeAlias, TypedDict from typing_extensions import NotRequired from graphql import GraphQLFormattedError @@ -8,35 +8,35 @@ class ConnectionInitMessage(TypedDict): """Direction: Client -> Server.""" type: Literal["connection_init"] - payload: NotRequired[Union[dict[str, object], None]] + payload: NotRequired[dict[str, object] | None] class ConnectionAckMessage(TypedDict): """Direction: Server -> Client.""" type: Literal["connection_ack"] - payload: NotRequired[Union[dict[str, object], None]] + payload: NotRequired[dict[str, object] | None] class PingMessage(TypedDict): """Direction: bidirectional.""" type: Literal["ping"] - payload: NotRequired[Union[dict[str, object], None]] + payload: NotRequired[dict[str, object] | None] class PongMessage(TypedDict): """Direction: bidirectional.""" type: Literal["pong"] - payload: NotRequired[Union[dict[str, object], None]] + payload: NotRequired[dict[str, object] | None] class SubscribeMessagePayload(TypedDict): - operationName: NotRequired[Union[str, None]] + operationName: NotRequired[str | None] query: str - variables: NotRequired[Union[dict[str, object], None]] - extensions: NotRequired[Union[dict[str, object], None]] + variables: NotRequired[dict[str, object] | None] + extensions: NotRequired[dict[str, object] | None] class SubscribeMessage(TypedDict): @@ -49,7 +49,7 @@ class SubscribeMessage(TypedDict): class NextMessagePayload(TypedDict): errors: NotRequired[list[GraphQLFormattedError]] - data: NotRequired[Union[dict[str, object], None]] + data: NotRequired[dict[str, object] | None] extensions: NotRequired[dict[str, object]] @@ -76,16 +76,16 @@ class CompleteMessage(TypedDict): type: Literal["complete"] -Message = Union[ - ConnectionInitMessage, - ConnectionAckMessage, - PingMessage, - PongMessage, - SubscribeMessage, - NextMessage, - ErrorMessage, - CompleteMessage, -] +Message: TypeAlias = ( + ConnectionInitMessage + | ConnectionAckMessage + | PingMessage + | PongMessage + | SubscribeMessage + | NextMessage + | ErrorMessage + | CompleteMessage +) __all__ = [ diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 0403507fc9..21979ff23d 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -7,7 +7,6 @@ TYPE_CHECKING, Any, Generic, - Optional, cast, ) @@ -41,10 +40,10 @@ def __init__( view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue], websocket: AsyncWebSocketAdapter, context: Context, - root_value: Optional[RootValue], + root_value: RootValue | None, schema: BaseSchema, keep_alive: bool, - keep_alive_interval: Optional[float], + keep_alive_interval: float | None, ) -> None: self.view = view self.websocket = websocket @@ -53,7 +52,7 @@ def __init__( self.schema = schema self.keep_alive = keep_alive self.keep_alive_interval = keep_alive_interval - self.keep_alive_task: Optional[asyncio.Task] = None + self.keep_alive_task: asyncio.Task | None = None self.subscriptions: dict[str, AsyncGenerator] = {} self.tasks: dict[str, asyncio.Task] = {} @@ -155,8 +154,8 @@ async def handle_async_results( self, operation_id: str, query: str, - operation_name: Optional[str], - variables: Optional[dict[str, object]], + operation_name: str | None, + variables: dict[str, object] | None, ) -> None: try: result_source = await self.schema.subscribe( diff --git a/strawberry/subscriptions/protocols/graphql_ws/types.py b/strawberry/subscriptions/protocols/graphql_ws/types.py index 217cd2612a..937695abd2 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_ws/types.py @@ -1,4 +1,4 @@ -from typing import Literal, TypedDict, Union +from typing import Literal, TypeAlias, TypedDict from typing_extensions import NotRequired from graphql import GraphQLFormattedError @@ -69,18 +69,18 @@ class ConnectionKeepAliveMessage(TypedDict): type: Literal["ka"] -OperationMessage = Union[ - ConnectionInitMessage, - StartMessage, - StopMessage, - ConnectionTerminateMessage, - ConnectionErrorMessage, - ConnectionAckMessage, - DataMessage, - ErrorMessage, - CompleteMessage, - ConnectionKeepAliveMessage, -] +OperationMessage: TypeAlias = ( + ConnectionInitMessage + | StartMessage + | StopMessage + | ConnectionTerminateMessage + | ConnectionErrorMessage + | ConnectionAckMessage + | DataMessage + | ErrorMessage + | CompleteMessage + | ConnectionKeepAliveMessage +) __all__ = [ diff --git a/strawberry/test/client.py b/strawberry/test/client.py index f4176dc479..78d303be55 100644 --- a/strawberry/test/client.py +++ b/strawberry/test/client.py @@ -4,7 +4,7 @@ import warnings from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal from typing_extensions import TypedDict if TYPE_CHECKING: @@ -15,14 +15,14 @@ @dataclass class Response: - errors: Optional[list[GraphQLFormattedError]] - data: Optional[dict[str, object]] - extensions: Optional[dict[str, object]] + errors: list[GraphQLFormattedError] | None + data: dict[str, object] | None + extensions: dict[str, object] | None class Body(TypedDict, total=False): query: str - variables: Optional[dict[str, object]] + variables: dict[str, object] | None class BaseGraphQLTestClient(ABC): @@ -37,12 +37,12 @@ def __init__( def query( self, query: str, - variables: Optional[dict[str, Mapping]] = None, - headers: Optional[dict[str, object]] = None, - asserts_errors: Optional[bool] = None, - files: Optional[dict[str, object]] = None, - assert_no_errors: Optional[bool] = True, - ) -> Union[Coroutine[Any, Any, Response], Response]: + variables: dict[str, Mapping] | None = None, + headers: dict[str, object] | None = None, + asserts_errors: bool | None = None, + files: dict[str, object] | None = None, + assert_no_errors: bool | None = True, + ) -> Coroutine[Any, Any, Response] | Response: body = self._build_body(query, variables, files) resp = self.request(body, headers, files) @@ -74,16 +74,16 @@ def query( def request( self, body: dict[str, object], - headers: Optional[dict[str, object]] = None, - files: Optional[dict[str, object]] = None, + headers: dict[str, object] | None = None, + files: dict[str, object] | None = None, ) -> Any: raise NotImplementedError def _build_body( self, query: str, - variables: Optional[dict[str, Mapping]] = None, - files: Optional[dict[str, object]] = None, + variables: dict[str, Mapping] | None = None, + files: dict[str, object] | None = None, ) -> dict[str, object]: body: dict[str, object] = {"query": query} diff --git a/strawberry/tools/create_type.py b/strawberry/tools/create_type.py index ff46321ebd..33e1e4f4be 100644 --- a/strawberry/tools/create_type.py +++ b/strawberry/tools/create_type.py @@ -1,6 +1,5 @@ import types from collections.abc import Sequence -from typing import Optional import strawberry from strawberry.types.field import StrawberryField @@ -11,8 +10,8 @@ def create_type( fields: list[StrawberryField], is_input: bool = False, is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + description: str | None = None, + directives: Sequence[object] | None = (), extend: bool = False, ) -> type: """Create a Strawberry type from a list of StrawberryFields. diff --git a/strawberry/types/arguments.py b/strawberry/types/arguments.py index f781d20ea2..06bc1b5d4e 100644 --- a/strawberry/types/arguments.py +++ b/strawberry/types/arguments.py @@ -6,8 +6,6 @@ TYPE_CHECKING, Annotated, Any, - Optional, - Union, cast, get_args, get_origin, @@ -48,19 +46,19 @@ class StrawberryArgumentAnnotation: - description: Optional[str] - name: Optional[str] - deprecation_reason: Optional[str] + description: str | None + name: str | None + deprecation_reason: str | None directives: Iterable[object] metadata: Mapping[Any, Any] def __init__( self, - description: Optional[str] = None, - name: Optional[str] = None, - deprecation_reason: Optional[str] = None, + description: str | None = None, + name: str | None = None, + deprecation_reason: str | None = None, directives: Iterable[object] = (), - metadata: Optional[Mapping[Any, Any]] = None, + metadata: Mapping[Any, Any] | None = None, ) -> None: self.description = description self.name = name @@ -73,14 +71,14 @@ class StrawberryArgument: def __init__( self, python_name: str, - graphql_name: Optional[str], + graphql_name: str | None, type_annotation: StrawberryAnnotation, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, default: object = _deprecated_UNSET, - deprecation_reason: Optional[str] = None, + deprecation_reason: str | None = None, directives: Iterable[object] = (), - metadata: Optional[Mapping[Any, Any]] = None, + metadata: Mapping[Any, Any] | None = None, ) -> None: self.python_name = python_name self.graphql_name = graphql_name @@ -131,7 +129,7 @@ def __init__( ) @property - def type(self) -> Union[StrawberryType, type]: + def type(self) -> StrawberryType | type: return self.type_annotation.resolve() @property @@ -146,8 +144,8 @@ def is_maybe(self) -> bool: def _is_leaf_type( - type_: Union[StrawberryType, type], - scalar_registry: Mapping[object, Union[ScalarWrapper, ScalarDefinition]], + type_: StrawberryType | type, + scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], skip_classes: tuple[type, ...] = (), ) -> bool: if type_ in skip_classes: @@ -170,8 +168,8 @@ def _is_leaf_type( def _is_optional_leaf_type( - type_: Union[StrawberryType, type], - scalar_registry: Mapping[object, Union[ScalarWrapper, ScalarDefinition]], + type_: StrawberryType | type, + scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], skip_classes: tuple[type, ...] = (), ) -> bool: if type_ in skip_classes: @@ -185,8 +183,8 @@ def _is_optional_leaf_type( def convert_argument( value: object, - type_: Union[StrawberryType, type], - scalar_registry: Mapping[object, Union[ScalarWrapper, ScalarDefinition]], + type_: StrawberryType | type, + scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], config: StrawberryConfig, ) -> object: from strawberry.relay.types import GlobalID @@ -272,7 +270,7 @@ def convert_argument( def convert_arguments( value: dict[str, Any], arguments: list[StrawberryArgument], - scalar_registry: Mapping[object, Union[ScalarWrapper, ScalarDefinition]], + scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], config: StrawberryConfig, ) -> dict[str, Any]: """Converts a nested dictionary to a dictionary of actual types. @@ -304,11 +302,11 @@ def convert_arguments( def argument( - description: Optional[str] = None, - name: Optional[str] = None, - deprecation_reason: Optional[str] = None, + description: str | None = None, + name: str | None = None, + deprecation_reason: str | None = None, directives: Iterable[object] = (), - metadata: Optional[Mapping[Any, Any]] = None, + metadata: Mapping[Any, Any] | None = None, ) -> StrawberryArgumentAnnotation: """Function to add metadata to an argument, like a description or deprecation reason. diff --git a/strawberry/types/auto.py b/strawberry/types/auto.py index 2906a4667b..0a5a4f192a 100644 --- a/strawberry/types/auto.py +++ b/strawberry/types/auto.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Annotated, Any, Optional, Union, cast, get_args, get_origin +from typing import Annotated, Any, cast, get_args, get_origin from strawberry.annotation import StrawberryAnnotation from strawberry.types.base import StrawberryType @@ -23,7 +23,7 @@ class StrawberryAutoMeta(type): """ def __init__(cls, *args: str, **kwargs: Any) -> None: - cls._instance: Optional[StrawberryAuto] = None + cls._instance: StrawberryAuto | None = None super().__init__(*args, **kwargs) def __call__(cls, *args: str, **kwargs: Any) -> Any: @@ -34,7 +34,7 @@ def __call__(cls, *args: str, **kwargs: Any) -> Any: def __instancecheck__( cls, - instance: Union[StrawberryAuto, StrawberryAnnotation, StrawberryType, type], + instance: StrawberryAuto | StrawberryAnnotation | StrawberryType | type, ) -> bool: if isinstance(instance, StrawberryAnnotation): resolved = instance.raw_annotation diff --git a/strawberry/types/base.py b/strawberry/types/base.py index 207a4a319e..d603b6d08c 100644 --- a/strawberry/types/base.py +++ b/strawberry/types/base.py @@ -7,9 +7,7 @@ Any, ClassVar, Literal, - Optional, TypeVar, - Union, overload, ) from typing_extensions import Protocol, Self, deprecated @@ -51,9 +49,9 @@ def is_one_of(self) -> bool: def copy_with( self, type_var_map: Mapping[ - str, Union[StrawberryType, type[WithStrawberryObjectDefinition]] + str, StrawberryType | type[WithStrawberryObjectDefinition] ], - ) -> Union[StrawberryType, type[WithStrawberryObjectDefinition]]: + ) -> StrawberryType | type[WithStrawberryObjectDefinition]: raise NotImplementedError @property @@ -87,7 +85,7 @@ def __hash__(self) -> int: class StrawberryContainer(StrawberryType): def __init__( - self, of_type: Union[StrawberryType, type[WithStrawberryObjectDefinition], type] + self, of_type: StrawberryType | type[WithStrawberryObjectDefinition] | type ) -> None: self.of_type = of_type @@ -117,7 +115,7 @@ def type_params(self) -> list[TypeVar]: def copy_with( self, type_var_map: Mapping[ - str, Union[StrawberryType, type[WithStrawberryObjectDefinition]] + str, StrawberryType | type[WithStrawberryObjectDefinition] ], ) -> Self: of_type_copy = self.of_type @@ -155,7 +153,7 @@ class StrawberryList(StrawberryContainer): ... class StrawberryOptional(StrawberryContainer): def __init__( self, - of_type: Union[StrawberryType, type[WithStrawberryObjectDefinition], type], + of_type: StrawberryType | type[WithStrawberryObjectDefinition] | type, ) -> None: super().__init__(of_type) @@ -169,8 +167,8 @@ def __init__(self, type_var: TypeVar) -> None: self.type_var = type_var def copy_with( - self, type_var_map: Mapping[str, Union[StrawberryType, type]] - ) -> Union[StrawberryType, type]: + self, type_var_map: Mapping[str, StrawberryType | type] + ) -> StrawberryType | type: return type_var_map[self.type_var.__name__] @property @@ -229,14 +227,14 @@ def get_object_definition( obj: Any, *, strict: bool = False, -) -> Optional[StrawberryObjectDefinition]: ... +) -> StrawberryObjectDefinition | None: ... def get_object_definition( obj: Any, *, strict: bool = False, -) -> Optional[StrawberryObjectDefinition]: +) -> StrawberryObjectDefinition | None: definition = obj.__strawberry_definition__ if has_object_definition(obj) else None if strict and definition is None: raise TypeError(f"{obj!r} does not have a StrawberryObjectDefinition") @@ -256,20 +254,18 @@ class StrawberryObjectDefinition(StrawberryType): is_input: bool is_interface: bool origin: type[Any] - description: Optional[str] + description: str | None interfaces: list[StrawberryObjectDefinition] extend: bool - directives: Optional[Sequence[object]] - is_type_of: Optional[Callable[[Any, GraphQLResolveInfo], bool]] - resolve_type: Optional[ - Callable[[Any, GraphQLResolveInfo, GraphQLAbstractType], str] - ] + directives: Sequence[object] | None + is_type_of: Callable[[Any, GraphQLResolveInfo], bool] | None + resolve_type: Callable[[Any, GraphQLResolveInfo, GraphQLAbstractType], str] | None fields: list[StrawberryField] - concrete_of: Optional[StrawberryObjectDefinition] = None + concrete_of: StrawberryObjectDefinition | None = None """Concrete implementations of Generic TypeDefinitions fill this in""" - type_var_map: Mapping[str, Union[StrawberryType, type]] = dataclasses.field( + type_var_map: Mapping[str, StrawberryType | type] = dataclasses.field( default_factory=dict ) @@ -298,7 +294,7 @@ def resolve_generic(self, wrapped_cls: type) -> type: return self.copy_with(type_var_map) def copy_with( - self, type_var_map: Mapping[str, Union[StrawberryType, type]] + self, type_var_map: Mapping[str, StrawberryType | type] ) -> type[WithStrawberryObjectDefinition]: fields = [field.copy_with(type_var_map) for field in self.fields] @@ -334,7 +330,7 @@ def copy_with( return new_type - def get_field(self, python_name: str) -> Optional[StrawberryField]: + def get_field(self, python_name: str) -> StrawberryField | None: return next( (field for field in self.fields if field.python_name == python_name), None ) @@ -356,7 +352,7 @@ def is_specialized_generic(self) -> bool: ) @property - def specialized_type_var_map(self) -> Optional[dict[str, type]]: + def specialized_type_var_map(self) -> dict[str, type] | None: return get_specialized_type_var_map(self.origin) @property diff --git a/strawberry/types/enum.py b/strawberry/types/enum.py index a99a433c93..17f021b4ca 100644 --- a/strawberry/types/enum.py +++ b/strawberry/types/enum.py @@ -3,9 +3,7 @@ from enum import EnumMeta from typing import ( Any, - Optional, TypeVar, - Union, overload, ) @@ -17,9 +15,9 @@ class EnumValue: name: str value: Any - deprecation_reason: Optional[str] = None + deprecation_reason: str | None = None directives: Iterable[object] = () - description: Optional[str] = None + description: str | None = None @dataclasses.dataclass @@ -27,7 +25,7 @@ class EnumDefinition(StrawberryType): wrapped_cls: EnumMeta name: str values: list[EnumValue] - description: Optional[str] + description: str | None directives: Iterable[object] = () def __hash__(self) -> int: @@ -35,8 +33,8 @@ def __hash__(self) -> int: return hash(self.name) def copy_with( - self, type_var_map: Mapping[str, Union[StrawberryType, type]] - ) -> Union[StrawberryType, type]: + self, type_var_map: Mapping[str, StrawberryType | type] + ) -> StrawberryType | type: # enum don't support type parameters, so we can safely return self return self @@ -53,10 +51,10 @@ def origin(self) -> type: @dataclasses.dataclass class EnumValueDefinition: value: Any - graphql_name: Optional[str] = None - deprecation_reason: Optional[str] = None + graphql_name: str | None = None + deprecation_reason: str | None = None directives: Iterable[object] = () - description: Optional[str] = None + description: str | None = None def __int__(self) -> int: return self.value @@ -64,10 +62,10 @@ def __int__(self) -> int: def enum_value( value: Any, - name: Optional[str] = None, - deprecation_reason: Optional[str] = None, + name: str | None = None, + deprecation_reason: str | None = None, directives: Iterable[object] = (), - description: Optional[str] = None, + description: str | None = None, ) -> EnumValueDefinition: """Function to customise an enum value, for example to add a description or deprecation reason. @@ -108,8 +106,8 @@ class MyEnum(Enum): def _process_enum( cls: EnumType, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), ) -> EnumType: if not isinstance(cls, EnumMeta): @@ -165,8 +163,8 @@ def _process_enum( def enum( cls: EnumType, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), ) -> EnumType: ... @@ -175,19 +173,19 @@ def enum( def enum( cls: None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), ) -> Callable[[EnumType], EnumType]: ... def enum( - cls: Optional[EnumType] = None, + cls: EnumType | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, directives: Iterable[object] = (), -) -> Union[EnumType, Callable[[EnumType], EnumType]]: +) -> EnumType | Callable[[EnumType], EnumType]: """Annotates an Enum class a GraphQL enum. GraphQL enums only have names, while Python enums have names and values, diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index 2cd5c46392..aed95523ef 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -4,7 +4,6 @@ from typing import ( TYPE_CHECKING, Any, - Optional, runtime_checkable, ) from typing_extensions import Protocol, TypedDict, deprecated @@ -29,36 +28,36 @@ @dataclasses.dataclass class ExecutionContext: - query: Optional[str] + query: str | None schema: Schema allowed_operations: Iterable[OperationType] context: Any = None - variables: Optional[dict[str, Any]] = None + variables: dict[str, Any] | None = None parse_options: ParseOptions = dataclasses.field( default_factory=lambda: ParseOptions() ) - root_value: Optional[Any] = None + root_value: Any | None = None validation_rules: tuple[type[ASTValidationRule], ...] = dataclasses.field( default_factory=lambda: tuple(specified_rules) ) # The operation name that is provided by the request - provided_operation_name: dataclasses.InitVar[Optional[str]] = None + provided_operation_name: dataclasses.InitVar[str | None] = None # Values that get populated during the GraphQL execution so that they can be # accessed by extensions - graphql_document: Optional[DocumentNode] = None - pre_execution_errors: Optional[list[GraphQLError]] = None - result: Optional[GraphQLExecutionResult] = None + graphql_document: DocumentNode | None = None + pre_execution_errors: list[GraphQLError] | None = None + result: GraphQLExecutionResult | None = None extensions_results: dict[str, Any] = dataclasses.field(default_factory=dict) - operation_extensions: Optional[dict[str, Any]] = None + operation_extensions: dict[str, Any] | None = None def __post_init__(self, provided_operation_name: str | None) -> None: self._provided_operation_name = provided_operation_name @property - def operation_name(self) -> Optional[str]: + def operation_name(self) -> str | None: if self._provided_operation_name is not None: return self._provided_operation_name @@ -79,7 +78,7 @@ def operation_type(self) -> OperationType: return get_operation_type(graphql_document, self.operation_name) - def _get_first_operation(self) -> Optional[OperationDefinitionNode]: + def _get_first_operation(self) -> OperationDefinitionNode | None: graphql_document = self.graphql_document if not graphql_document: return None @@ -88,16 +87,16 @@ def _get_first_operation(self) -> Optional[OperationDefinitionNode]: @property @deprecated("Use 'pre_execution_errors' instead") - def errors(self) -> Optional[list[GraphQLError]]: + def errors(self) -> list[GraphQLError] | None: """Deprecated: Use pre_execution_errors instead.""" return self.pre_execution_errors @dataclasses.dataclass class ExecutionResult: - data: Optional[dict[str, Any]] - errors: Optional[list[GraphQLError]] - extensions: Optional[dict[str, Any]] = None + data: dict[str, Any] | None + errors: list[GraphQLError] | None + extensions: dict[str, Any] | None = None @dataclasses.dataclass diff --git a/strawberry/types/field.py b/strawberry/types/field.py index 5ce23534b3..ede4beecd5 100644 --- a/strawberry/types/field.py +++ b/strawberry/types/field.py @@ -9,7 +9,8 @@ from typing import ( TYPE_CHECKING, Any, - Optional, + NoReturn, + TypeAlias, TypeVar, Union, overload, @@ -39,24 +40,31 @@ T = TypeVar("T") -_RESOLVER_TYPE_SYNC = Union[ +_RESOLVER_TYPE_SYNC: TypeAlias = Union[ StrawberryResolver[T], Callable[..., T], "staticmethod[Any, T]", "classmethod[Any, Any, T]", ] -_RESOLVER_TYPE_ASYNC = Union[ - Callable[..., Coroutine[Any, Any, T]], - Callable[..., Awaitable[T]], -] +_RESOLVER_TYPE_ASYNC: TypeAlias = ( + Callable[..., Coroutine[Any, Any, T]] | Callable[..., Awaitable[T]] +) + +_RESOLVER_TYPE: TypeAlias = _RESOLVER_TYPE_SYNC[T] | _RESOLVER_TYPE_ASYNC[T] + -_RESOLVER_TYPE = Union[_RESOLVER_TYPE_SYNC[T], _RESOLVER_TYPE_ASYNC[T]] +class UNRESOLVED: + def __new__(cls) -> NoReturn: + raise TypeError("UNRESOLVED is a sentinel and cannot be instantiated.") -UNRESOLVED = object() + +FieldType: TypeAlias = ( + StrawberryType | type[WithStrawberryObjectDefinition | UNRESOLVED] +) -def _is_generic(resolver_type: Union[StrawberryType, type]) -> bool: +def _is_generic(resolver_type: StrawberryType | type) -> bool: """Returns True if `resolver_type` is generic else False.""" if isinstance(resolver_type, StrawberryType): return resolver_type.is_graphql_generic @@ -69,23 +77,23 @@ def _is_generic(resolver_type: Union[StrawberryType, type]) -> bool: class StrawberryField(dataclasses.Field): - type_annotation: Optional[StrawberryAnnotation] + type_annotation: StrawberryAnnotation | None default_resolver: Callable[[Any, str], object] = getattr def __init__( self, - python_name: Optional[str] = None, - graphql_name: Optional[str] = None, - type_annotation: Optional[StrawberryAnnotation] = None, - origin: Optional[Union[type, Callable, staticmethod, classmethod]] = None, + python_name: str | None = None, + graphql_name: str | None = None, + type_annotation: StrawberryAnnotation | None = None, + origin: type | Callable | staticmethod | classmethod | None = None, is_subscription: bool = False, - description: Optional[str] = None, - base_resolver: Optional[StrawberryResolver] = None, + description: str | None = None, + base_resolver: StrawberryResolver | None = None, permission_classes: list[type[BasePermission]] = (), # type: ignore default: object = dataclasses.MISSING, - default_factory: Union[Callable[[], Any], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - deprecation_reason: Optional[str] = None, + default_factory: Callable[[], Any] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + deprecation_reason: str | None = None, directives: Sequence[object] = (), extensions: list[FieldExtension] = (), # type: ignore ) -> None: @@ -117,11 +125,11 @@ def __init__( self.type_annotation = type_annotation - self.description: Optional[str] = description + self.description: str | None = description self.origin = origin - self._arguments: Optional[list[StrawberryArgument]] = None - self._base_resolver: Optional[StrawberryResolver] = None + self._arguments: list[StrawberryArgument] | None = None + self._base_resolver: StrawberryResolver | None = None if base_resolver is not None: self.base_resolver = base_resolver @@ -213,8 +221,8 @@ def __call__(self, resolver: _RESOLVER_TYPE) -> Self: return self def get_result( - self, source: Any, info: Optional[Info], args: list[Any], kwargs: Any - ) -> Union[Awaitable[Any], Any]: + self, source: Any, info: Info | None, args: list[Any], kwargs: Any + ) -> Awaitable[Any] | Any: """Calls the resolver defined for the StrawberryField. If the field doesn't have a resolver defined we default @@ -256,7 +264,7 @@ def is_graphql_generic(self) -> bool: else _is_generic(self.type) ) - def _python_name(self) -> Optional[str]: + def _python_name(self) -> str | None: if self.name: return self.name @@ -271,7 +279,7 @@ def _set_python_name(self, name: str) -> None: python_name: str = property(_python_name, _set_python_name) # type: ignore[assignment] @property - def base_resolver(self) -> Optional[StrawberryResolver]: + def base_resolver(self) -> StrawberryResolver | None: return self._base_resolver @base_resolver.setter @@ -295,13 +303,7 @@ def base_resolver(self, resolver: StrawberryResolver) -> None: _ = resolver.arguments @property - def type( - self, - ) -> Union[ # type: ignore [valid-type] - StrawberryType, - type[WithStrawberryObjectDefinition], - Literal[UNRESOLVED], - ]: + def type(self) -> FieldType: return self.resolve_type() @type.setter @@ -331,15 +333,11 @@ def type_params(self) -> list[TypeVar]: def resolve_type( self, *, - type_definition: Optional[StrawberryObjectDefinition] = None, - ) -> Union[ # type: ignore [valid-type] - StrawberryType, - type[WithStrawberryObjectDefinition], - Literal[UNRESOLVED], - ]: + type_definition: StrawberryObjectDefinition | None = None, + ) -> FieldType: # We return UNRESOLVED by default, which means this case will raise a # MissingReturnAnnotationError exception in _check_field_annotations - resolved = UNRESOLVED + resolved: FieldType = UNRESOLVED # type: ignore[assignment] # We are catching NameError because dataclasses tries to fetch the type # of the field from the class before the class is fully defined. @@ -360,13 +358,13 @@ def resolve_type( return resolved def copy_with( - self, type_var_map: Mapping[str, Union[StrawberryType, builtins.type]] + self, type_var_map: Mapping[str, StrawberryType | builtins.type] ) -> Self: new_field = copy.copy(self) - override_type: Optional[ - Union[StrawberryType, type[WithStrawberryObjectDefinition]] - ] = None + override_type: StrawberryType | type[WithStrawberryObjectDefinition] | None = ( + None + ) type_ = self.resolve_type() if has_object_definition(type_): type_definition = type_.__strawberry_definition__ @@ -407,18 +405,18 @@ def is_async(self) -> bool: def field( *, resolver: _RESOLVER_TYPE_ASYNC[T], - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, init: Literal[False] = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> T: ... @@ -426,36 +424,36 @@ def field( def field( *, resolver: _RESOLVER_TYPE_SYNC[T], - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, init: Literal[False] = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> T: ... @overload def field( *, - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, + description: str | None = None, init: Literal[True] = True, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> Any: ... @@ -463,17 +461,17 @@ def field( def field( resolver: _RESOLVER_TYPE_ASYNC[T], *, - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> StrawberryField: ... @@ -481,34 +479,34 @@ def field( def field( resolver: _RESOLVER_TYPE_SYNC[T], *, - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> StrawberryField: ... def field( - resolver: Optional[_RESOLVER_TYPE[Any]] = None, + resolver: _RESOLVER_TYPE[Any] | None = None, *, - name: Optional[str] = None, + name: str | None = None, is_subscription: bool = False, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, # This init parameter is used by PyRight to determine whether this field # is added in the constructor or not. It is not used to change # any behavior at the moment. diff --git a/strawberry/types/fields/resolver.py b/strawberry/types/fields/resolver.py index cde55b0808..095a123d7d 100644 --- a/strawberry/types/fields/resolver.py +++ b/strawberry/types/fields/resolver.py @@ -12,9 +12,7 @@ Any, Generic, NamedTuple, - Optional, TypeVar, - Union, cast, get_origin, ) @@ -63,7 +61,7 @@ def find( self, parameters: tuple[inspect.Parameter, ...], resolver: StrawberryResolver[Any], - ) -> Optional[inspect.Parameter]: + ) -> inspect.Parameter | None: """Finds the reserved parameter from ``parameters``.""" @@ -74,7 +72,7 @@ def find( self, parameters: tuple[inspect.Parameter, ...], resolver: StrawberryResolver[Any], - ) -> Optional[inspect.Parameter]: + ) -> inspect.Parameter | None: del resolver return next((p for p in parameters if p.name == self.name), None) @@ -86,7 +84,7 @@ def find( self, parameters: tuple[inspect.Parameter, ...], resolver: StrawberryResolver[Any], - ) -> Optional[inspect.Parameter]: + ) -> inspect.Parameter | None: del resolver if parameters: # Add compatibility for resolvers with no arguments first_parameter = parameters[0] @@ -108,7 +106,7 @@ def find( self, parameters: tuple[inspect.Parameter, ...], resolver: StrawberryResolver[Any], - ) -> Optional[inspect.Parameter]: + ) -> inspect.Parameter | None: # Go through all the types even after we've found one so we can # give a helpful error message if someone uses the type more than once. type_parameters = [] @@ -193,10 +191,10 @@ class StrawberryResolver(Generic[T]): def __init__( self, - func: Union[Callable[..., T], staticmethod, classmethod], + func: Callable[..., T] | staticmethod | classmethod, *, - description: Optional[str] = None, - type_override: Optional[Union[StrawberryType, type]] = None, + description: str | None = None, + type_override: StrawberryType | type | None = None, ) -> None: self.wrapped_func = func self._description = description @@ -220,7 +218,7 @@ def signature(self) -> inspect.Signature: @cached_property def strawberry_annotations( self, - ) -> dict[inspect.Parameter, Union[StrawberryAnnotation, None]]: + ) -> dict[inspect.Parameter, StrawberryAnnotation | None]: return { p: ( StrawberryAnnotation(p.annotation, namespace=self._namespace) @@ -233,7 +231,7 @@ def strawberry_annotations( @cached_property def reserved_parameters( self, - ) -> dict[ReservedParameterSpecification, Optional[inspect.Parameter]]: + ) -> dict[ReservedParameterSpecification, inspect.Parameter | None]: """Mapping of reserved parameter specification to parameter.""" parameters = tuple(self.signature.parameters.values()) return {spec: spec.find(parameters, self) for spec in self.RESERVED_PARAMSPEC} @@ -281,19 +279,19 @@ def arguments(self) -> list[StrawberryArgument]: return arguments @cached_property - def info_parameter(self) -> Optional[inspect.Parameter]: + def info_parameter(self) -> inspect.Parameter | None: return self.reserved_parameters.get(INFO_PARAMSPEC) @cached_property - def root_parameter(self) -> Optional[inspect.Parameter]: + def root_parameter(self) -> inspect.Parameter | None: return self.reserved_parameters.get(ROOT_PARAMSPEC) @cached_property - def self_parameter(self) -> Optional[inspect.Parameter]: + def self_parameter(self) -> inspect.Parameter | None: return self.reserved_parameters.get(SELF_PARAMSPEC) @cached_property - def parent_parameter(self) -> Optional[inspect.Parameter]: + def parent_parameter(self) -> inspect.Parameter | None: return self.reserved_parameters.get(PARENT_PARAMSPEC) @cached_property @@ -320,7 +318,7 @@ def annotations(self) -> dict[str, object]: } @cached_property - def type_annotation(self) -> Optional[StrawberryAnnotation]: + def type_annotation(self) -> StrawberryAnnotation | None: return_annotation = self.signature.return_annotation if return_annotation is inspect.Signature.empty: return None @@ -329,7 +327,7 @@ def type_annotation(self) -> Optional[StrawberryAnnotation]: ) @property - def type(self) -> Optional[Union[StrawberryType, type]]: + def type(self) -> StrawberryType | type | None: if self._type_override: return self._type_override if self.type_annotation is None: @@ -355,7 +353,7 @@ def is_async(self) -> bool: ) def copy_with( - self, type_var_map: Mapping[str, Union[StrawberryType, builtins.type]] + self, type_var_map: Mapping[str, StrawberryType | builtins.type] ) -> StrawberryResolver: type_override = None diff --git a/strawberry/types/info.py b/strawberry/types/info.py index e24e01d2bc..fb6175b65b 100644 --- a/strawberry/types/info.py +++ b/strawberry/types/info.py @@ -7,8 +7,6 @@ TYPE_CHECKING, Any, Generic, - Optional, - Union, ) from typing_extensions import TypeVar @@ -21,11 +19,7 @@ from strawberry.schema import Schema from strawberry.types.arguments import StrawberryArgument - from strawberry.types.base import ( - StrawberryType, - WithStrawberryObjectDefinition, - ) - from strawberry.types.field import StrawberryField + from strawberry.types.field import FieldType, StrawberryField from .nodes import Selection @@ -68,7 +62,7 @@ def hello(self, info: strawberry.Info[str, str]) -> str: _raw_info: GraphQLResolveInfo _field: StrawberryField - def __class_getitem__(cls, types: Union[type, tuple[type, ...]]) -> type[Info]: + def __class_getitem__(cls, types: type | tuple[type, ...]) -> type[Info]: """Workaround for when passing only one type. Python doesn't yet support directly passing only one type to a generic class @@ -131,7 +125,7 @@ def variable_values(self) -> dict[str, Any]: @property def return_type( self, - ) -> Optional[Union[type[WithStrawberryObjectDefinition], StrawberryType]]: + ) -> FieldType: """The return type of the current field being resolved.""" return self._field.type @@ -154,7 +148,7 @@ def path(self) -> Path: # TODO: parent_type as strawberry types # Helper functions - def get_argument_definition(self, name: str) -> Optional[StrawberryArgument]: + def get_argument_definition(self, name: str) -> StrawberryArgument | None: """Get the StrawberryArgument definition for the current field by name.""" try: return next(arg for arg in self._field.arguments if arg.python_name == name) diff --git a/strawberry/types/lazy_type.py b/strawberry/types/lazy_type.py index 5d0d271e20..502d132d91 100644 --- a/strawberry/types/lazy_type.py +++ b/strawberry/types/lazy_type.py @@ -8,7 +8,6 @@ Any, ForwardRef, Generic, - Optional, TypeVar, Union, cast, @@ -32,7 +31,7 @@ class LazyType(Generic[TypeName, Module]): type_name: str module: str - package: Optional[str] = None + package: str | None = None def __class_getitem__(cls, params: tuple[str, str]) -> "Self": warnings.warn( @@ -57,7 +56,7 @@ def __class_getitem__(cls, params: tuple[str, str]) -> "Self": return cls(type_name, module, package) def __or__(self, other: Other) -> object: - return Union[self, other] + return Union[self, other] # noqa: UP007 def resolve_type(self) -> type[Any]: module = importlib.import_module(self.module, self.package) diff --git a/strawberry/types/maybe.py b/strawberry/types/maybe.py index becafbeb45..6d2e431fed 100644 --- a/strawberry/types/maybe.py +++ b/strawberry/types/maybe.py @@ -1,5 +1,5 @@ import typing -from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar T = TypeVar("T") @@ -30,7 +30,7 @@ def __bool__(self) -> bool: if TYPE_CHECKING: - Maybe: TypeAlias = Union[Some[T], None] + Maybe: TypeAlias = Some[T] | None else: # we do this trick so we can inspect that at runtime class Maybe(Generic[T]): ... diff --git a/strawberry/types/mutation.py b/strawberry/types/mutation.py index c4a7ce196e..b8d6a3b0b6 100644 --- a/strawberry/types/mutation.py +++ b/strawberry/types/mutation.py @@ -5,8 +5,6 @@ TYPE_CHECKING, Any, Literal, - Optional, - Union, overload, ) @@ -34,17 +32,17 @@ def mutation( *, resolver: _RESOLVER_TYPE_ASYNC[T], - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, init: Literal[False] = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> T: ... @@ -52,34 +50,34 @@ def mutation( def mutation( *, resolver: _RESOLVER_TYPE_SYNC[T], - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, init: Literal[False] = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> T: ... @overload def mutation( *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, init: Literal[True] = True, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> Any: ... @@ -87,16 +85,16 @@ def mutation( def mutation( resolver: _RESOLVER_TYPE_ASYNC[T], *, - name: Optional[str] = None, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + name: str | None = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> StrawberryField: ... @@ -104,32 +102,32 @@ def mutation( def mutation( resolver: _RESOLVER_TYPE_SYNC[T], *, - name: Optional[str] = None, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + name: str | None = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> StrawberryField: ... def mutation( - resolver: Optional[_RESOLVER_TYPE[Any]] = None, + resolver: _RESOLVER_TYPE[Any] | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + name: str | None = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, # This init parameter is used by PyRight to determine whether this field # is added in the constructor or not. It is not used to change # any behavior at the moment. @@ -194,17 +192,17 @@ def create_post(self, title: str, content: str) -> Post: ... def subscription( *, resolver: _RESOLVER_TYPE_ASYNC[T], - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, init: Literal[False] = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> T: ... @@ -212,34 +210,34 @@ def subscription( def subscription( *, resolver: _RESOLVER_TYPE_SYNC[T], - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, init: Literal[False] = False, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> T: ... @overload def subscription( *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, init: Literal[True] = True, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> Any: ... @@ -247,16 +245,16 @@ def subscription( def subscription( resolver: _RESOLVER_TYPE_ASYNC[T], *, - name: Optional[str] = None, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + name: str | None = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> StrawberryField: ... @@ -264,32 +262,32 @@ def subscription( def subscription( resolver: _RESOLVER_TYPE_SYNC[T], *, - name: Optional[str] = None, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + name: str | None = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, ) -> StrawberryField: ... def subscription( - resolver: Optional[_RESOLVER_TYPE[Any]] = None, + resolver: _RESOLVER_TYPE[Any] | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, - permission_classes: Optional[list[type[BasePermission]]] = None, - deprecation_reason: Optional[str] = None, + name: str | None = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, default: Any = dataclasses.MISSING, - default_factory: Union[Callable[..., object], object] = dataclasses.MISSING, - metadata: Optional[Mapping[Any, Any]] = None, - directives: Optional[Sequence[object]] = (), - extensions: Optional[list[FieldExtension]] = None, - graphql_type: Optional[Any] = None, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] | None = (), + extensions: list[FieldExtension] | None = None, + graphql_type: Any | None = None, init: Literal[True, False, None] = None, ) -> Any: """Annotates a method or property as a GraphQL subscription. diff --git a/strawberry/types/nodes.py b/strawberry/types/nodes.py index 80da9092fd..c10b9cf50e 100644 --- a/strawberry/types/nodes.py +++ b/strawberry/types/nodes.py @@ -12,7 +12,7 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from graphql.language import FieldNode as GQLFieldNode from graphql.language import FragmentSpreadNode as GQLFragmentSpreadNode @@ -139,7 +139,7 @@ class SelectedField: directives: Directives arguments: Arguments selections: list[Selection] - alias: Optional[str] = None + alias: str | None = None @classmethod def from_node(cls, info: GraphQLResolveInfo, node: GQLFieldNode) -> SelectedField: diff --git a/strawberry/types/object_type.py b/strawberry/types/object_type.py index ec8c150731..3859a23270 100644 --- a/strawberry/types/object_type.py +++ b/strawberry/types/object_type.py @@ -5,9 +5,7 @@ from collections.abc import Callable, Sequence from typing import ( Any, - Optional, TypeVar, - Union, overload, ) from typing_extensions import dataclass_transform, get_annotations @@ -118,13 +116,13 @@ def _inject_default_for_maybe_annotations( def _process_type( cls: T, *, - name: Optional[str] = None, + name: str | None = None, is_input: bool = False, is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + description: str | None = None, + directives: Sequence[object] | None = (), extend: bool = False, - original_type_annotations: Optional[dict[str, Any]] = None, + original_type_annotations: dict[str, Any] | None = None, ) -> T: name = name or to_camel_case(cls.__name__) original_type_annotations = original_type_annotations or {} @@ -190,11 +188,11 @@ def _process_type( def type( cls: T, *, - name: Optional[str] = None, + name: str | None = None, is_input: bool = False, is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + description: str | None = None, + directives: Sequence[object] | None = (), extend: bool = False, ) -> T: ... @@ -205,25 +203,25 @@ def type( ) def type( *, - name: Optional[str] = None, + name: str | None = None, is_input: bool = False, is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + description: str | None = None, + directives: Sequence[object] | None = (), extend: bool = False, ) -> Callable[[T], T]: ... def type( - cls: Optional[T] = None, + cls: T | None = None, *, - name: Optional[str] = None, + name: str | None = None, is_input: bool = False, is_interface: bool = False, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + description: str | None = None, + directives: Sequence[object] | None = (), extend: bool = False, -) -> Union[T, Callable[[T], T]]: +) -> T | Callable[[T], T]: """Annotates a class as a GraphQL type. Similar to `dataclasses.dataclass`, but with additional functionality for @@ -312,10 +310,10 @@ def wrap(cls: T) -> T: def input( cls: T, *, - name: Optional[str] = None, - one_of: Optional[bool] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + name: str | None = None, + one_of: bool | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), ) -> T: ... @@ -325,20 +323,20 @@ def input( ) def input( *, - name: Optional[str] = None, - one_of: Optional[bool] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + name: str | None = None, + one_of: bool | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), ) -> Callable[[T], T]: ... def input( - cls: Optional[T] = None, + cls: T | None = None, *, - name: Optional[str] = None, - one_of: Optional[bool] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + name: str | None = None, + one_of: bool | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), ): """Annotates a class as a GraphQL Input type. @@ -391,9 +389,9 @@ class MyUserInput: def interface( cls: T, *, - name: Optional[str] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), ) -> T: ... @@ -403,9 +401,9 @@ def interface( ) def interface( *, - name: Optional[str] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), ) -> Callable[[T], T]: ... @@ -413,11 +411,11 @@ def interface( order_default=True, kw_only_default=True, field_specifiers=(field, StrawberryField) ) def interface( - cls: Optional[T] = None, + cls: T | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, - directives: Optional[Sequence[object]] = (), + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), ): """Annotates a class as a GraphQL Interface. diff --git a/strawberry/types/scalar.py b/strawberry/types/scalar.py index d1f83d17a4..43a9396074 100644 --- a/strawberry/types/scalar.py +++ b/strawberry/types/scalar.py @@ -8,7 +8,6 @@ NewType, Optional, TypeVar, - Union, overload, ) @@ -22,7 +21,7 @@ from graphql import GraphQLScalarType -_T = TypeVar("_T", bound=Union[type, NewType]) +_T = TypeVar("_T", bound=type | NewType) def identity(x: _T) -> _T: @@ -32,25 +31,25 @@ def identity(x: _T) -> _T: @dataclass class ScalarDefinition(StrawberryType): name: str - description: Optional[str] - specified_by_url: Optional[str] - serialize: Optional[Callable] - parse_value: Optional[Callable] - parse_literal: Optional[Callable] + description: str | None + specified_by_url: str | None + serialize: Callable | None + parse_value: Callable | None + parse_literal: Callable | None directives: Iterable[object] = () - origin: Optional[GraphQLScalarType | type] = None + origin: GraphQLScalarType | type | None = None # Optionally store the GraphQLScalarType instance so that we don't get # duplicates - implementation: Optional[GraphQLScalarType] = None + implementation: GraphQLScalarType | None = None # used for better error messages - _source_file: Optional[str] = None - _source_line: Optional[int] = None + _source_file: str | None = None + _source_line: int | None = None def copy_with( - self, type_var_map: Mapping[str, Union[StrawberryType, type]] - ) -> Union[StrawberryType, type]: + self, type_var_map: Mapping[str, StrawberryType | type] + ) -> StrawberryType | type: return super().copy_with(type_var_map) # type: ignore[safe-super] @property @@ -67,10 +66,10 @@ def __init__(self, wrap: Callable[[Any], Any]) -> None: def __call__(self, *args: str, **kwargs: Any) -> Any: return self.wrap(*args, **kwargs) - def __or__(self, other: Union[StrawberryType, type]) -> StrawberryType: + def __or__(self, other: StrawberryType | type) -> StrawberryType: if other is None: # Return the correct notation when using `StrawberryUnion | None`. - return Optional[self] + return Optional[self] # noqa: UP045 # Raise an error in any other case. # There is Work in progress to deal with more merging cases, see: @@ -81,12 +80,12 @@ def __or__(self, other: Union[StrawberryType, type]) -> StrawberryType: def _process_scalar( cls: _T, *, - name: Optional[str] = None, - description: Optional[str] = None, - specified_by_url: Optional[str] = None, - serialize: Optional[Callable] = None, - parse_value: Optional[Callable] = None, - parse_literal: Optional[Callable] = None, + name: str | None = None, + description: str | None = None, + specified_by_url: str | None = None, + serialize: Callable | None = None, + parse_value: Callable | None = None, + parse_literal: Callable | None = None, directives: Iterable[object] = (), ) -> ScalarWrapper: from strawberry.exceptions.handler import should_use_rich_exceptions @@ -122,12 +121,12 @@ def _process_scalar( @overload def scalar( *, - name: Optional[str] = None, - description: Optional[str] = None, - specified_by_url: Optional[str] = None, + name: str | None = None, + description: str | None = None, + specified_by_url: str | None = None, serialize: Callable = identity, - parse_value: Optional[Callable] = None, - parse_literal: Optional[Callable] = None, + parse_value: Callable | None = None, + parse_literal: Callable | None = None, directives: Iterable[object] = (), ) -> Callable[[_T], _T]: ... @@ -136,12 +135,12 @@ def scalar( def scalar( cls: _T, *, - name: Optional[str] = None, - description: Optional[str] = None, - specified_by_url: Optional[str] = None, + name: str | None = None, + description: str | None = None, + specified_by_url: str | None = None, serialize: Callable = identity, - parse_value: Optional[Callable] = None, - parse_literal: Optional[Callable] = None, + parse_value: Callable | None = None, + parse_literal: Callable | None = None, directives: Iterable[object] = (), ) -> _T: ... @@ -150,14 +149,14 @@ def scalar( # here or else it won't let us use any custom scalar to annotate attributes in # dataclasses/types. This should be properly solved when implementing StrawberryScalar def scalar( - cls: Optional[_T] = None, + cls: _T | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, - specified_by_url: Optional[str] = None, + name: str | None = None, + description: str | None = None, + specified_by_url: str | None = None, serialize: Callable = identity, - parse_value: Optional[Callable] = None, - parse_literal: Optional[Callable] = None, + parse_value: Callable | None = None, + parse_literal: Callable | None = None, directives: Iterable[object] = (), ) -> Any: """Annotates a class or type as a GraphQL custom scalar. diff --git a/strawberry/types/union.py b/strawberry/types/union.py index fbcf77ce52..184507a6c5 100644 --- a/strawberry/types/union.py +++ b/strawberry/types/union.py @@ -9,9 +9,7 @@ Annotated, Any, NoReturn, - Optional, TypeVar, - Union, cast, get_origin, ) @@ -48,14 +46,14 @@ class StrawberryUnion(StrawberryType): # used for better error messages - _source_file: Optional[str] = None - _source_line: Optional[int] = None + _source_file: str | None = None + _source_line: int | None = None def __init__( self, - name: Optional[str] = None, + name: str | None = None, type_annotations: tuple[StrawberryAnnotation, ...] = (), - description: Optional[str] = None, + description: str | None = None, directives: Iterable[object] = (), ) -> None: self.graphql_name = name @@ -64,7 +62,7 @@ def __init__( self.directives = directives self._source_file = None self._source_line = None - self.concrete_of: Optional[StrawberryUnion] = None + self.concrete_of: StrawberryUnion | None = None def __eq__(self, other: object) -> bool: if isinstance(other, StrawberryType): @@ -81,7 +79,7 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: return hash((self.graphql_name, self.type_annotations, self.description)) - def __or__(self, other: Union[StrawberryType, type]) -> StrawberryType: + def __or__(self, other: StrawberryType | type) -> StrawberryType: # TODO: this will be removed in future versions, you should # use Annotated[Union[...], strawberry.union(...)] instead @@ -131,7 +129,7 @@ def _is_generic(type_: object) -> bool: return any(map(_is_generic, self.types)) def copy_with( - self, type_var_map: Mapping[str, Union[StrawberryType, type]] + self, type_var_map: Mapping[str, StrawberryType | type] ) -> StrawberryType: if not self.is_graphql_generic: return self @@ -139,7 +137,7 @@ def copy_with( new_types = [] for type_ in self.types: - new_type: Union[StrawberryType, type] + new_type: StrawberryType | type if has_object_definition(type_): type_definition = type_.__strawberry_definition__ @@ -189,7 +187,7 @@ def _resolve_union_type( # Couldn't resolve using `is_type_of` raise WrongReturnTypeForUnion(info.field_name, str(type(root))) - return_type: Optional[GraphQLType] + return_type: GraphQLType | None # Iterate over all of our known types and find the first concrete # type that implements the type. We prioritise checking types named in the @@ -242,9 +240,9 @@ def is_valid_union_type(type_: object) -> bool: def union( name: str, - types: Optional[Collection[type[Any]]] = None, + types: Collection[type[Any]] | None = None, *, - description: Optional[str] = None, + description: str | None = None, directives: Iterable[object] = (), ) -> StrawberryUnion: """Creates a new named Union type. diff --git a/strawberry/utils/aio.py b/strawberry/utils/aio.py index 6ac5551f76..0d0c7ce4da 100644 --- a/strawberry/utils/aio.py +++ b/strawberry/utils/aio.py @@ -9,9 +9,7 @@ from contextlib import asynccontextmanager, suppress from typing import ( Any, - Optional, TypeVar, - Union, cast, ) @@ -35,7 +33,7 @@ async def aclosing(thing: _T) -> AsyncGenerator[_T, None]: async def aenumerate( - iterable: Union[AsyncIterator[_T], AsyncIterable[_T]], + iterable: AsyncIterator[_T] | AsyncIterable[_T], ) -> AsyncIterator[tuple[int, _T]]: """Async version of enumerate.""" i = 0 @@ -45,10 +43,10 @@ async def aenumerate( async def aislice( - aiterable: Union[AsyncIterator[_T], AsyncIterable[_T]], - start: Optional[int] = None, - stop: Optional[int] = None, - step: Optional[int] = None, + aiterable: AsyncIterator[_T] | AsyncIterable[_T], + start: int | None = None, + stop: int | None = None, + step: int | None = None, ) -> AsyncIterator[_T]: """Async version of itertools.islice.""" # This is based on diff --git a/strawberry/utils/await_maybe.py b/strawberry/utils/await_maybe.py index 6833d26d07..454d7e9720 100644 --- a/strawberry/utils/await_maybe.py +++ b/strawberry/utils/await_maybe.py @@ -1,11 +1,11 @@ import inspect from collections.abc import AsyncIterator, Awaitable, Iterator -from typing import TypeVar, Union +from typing import TypeAlias, TypeVar T = TypeVar("T") -AwaitableOrValue = Union[Awaitable[T], T] -AsyncIteratorOrIterator = Union[AsyncIterator[T], Iterator[T]] +AwaitableOrValue: TypeAlias = Awaitable[T] | T +AsyncIteratorOrIterator: TypeAlias = AsyncIterator[T] | Iterator[T] async def await_maybe(value: AwaitableOrValue[T]) -> T: diff --git a/strawberry/utils/deprecations.py b/strawberry/utils/deprecations.py index 51f85c317d..eee08676e3 100644 --- a/strawberry/utils/deprecations.py +++ b/strawberry/utils/deprecations.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Optional +from typing import Any class DEPRECATION_MESSAGES: # noqa: N801 @@ -19,7 +19,7 @@ def __init__(self, msg: str, alias: object, attr_name: str) -> None: def warn(self) -> None: warnings.warn(self.msg, stacklevel=2) - def __get__(self, obj: Optional[object], type: Optional[type] = None) -> Any: + def __get__(self, obj: object | None, type: type | None = None) -> Any: self.warn() return self.alias diff --git a/strawberry/utils/importer.py b/strawberry/utils/importer.py index 7179245c61..c7149838c2 100644 --- a/strawberry/utils/importer.py +++ b/strawberry/utils/importer.py @@ -1,9 +1,8 @@ import importlib -from typing import Optional def import_module_symbol( - selector: str, default_symbol_name: Optional[str] = None + selector: str, default_symbol_name: str | None = None ) -> object: if ":" in selector: module_name, symbol_name = selector.split(":", 1) diff --git a/strawberry/utils/inspect.py b/strawberry/utils/inspect.py index ed6a6cbf4c..ef9ccaa5bf 100644 --- a/strawberry/utils/inspect.py +++ b/strawberry/utils/inspect.py @@ -5,10 +5,8 @@ from typing import ( Any, Generic, - Optional, Protocol, TypeVar, - Union, get_args, get_origin, ) @@ -39,7 +37,7 @@ def get_func_args(func: Callable[[Any], Any]) -> list[str]: ] -def get_specialized_type_var_map(cls: type) -> Optional[dict[str, type]]: +def get_specialized_type_var_map(cls: type) -> dict[str, type] | None: """Get a type var map for specialized types. Consider the following: @@ -84,7 +82,7 @@ class IntBarFoo(IntBar, Foo[str]): ... """ from strawberry.types.base import has_object_definition - param_args: dict[TypeVar, Union[TypeVar, type]] = {} + param_args: dict[TypeVar, TypeVar | type] = {} types: list[type] = [cls] while types: diff --git a/strawberry/utils/logging.py b/strawberry/utils/logging.py index 0a701a7701..af154d7451 100644 --- a/strawberry/utils/logging.py +++ b/strawberry/utils/logging.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from typing import Final @@ -18,7 +18,7 @@ class StrawberryLogger: def error( cls, error: GraphQLError, - execution_context: Optional[ExecutionContext] = None, + execution_context: ExecutionContext | None = None, # https://www.python.org/dev/peps/pep-0484/#arbitrary-argument-lists-and-default-argument-values **logger_kwargs: Any, ) -> None: diff --git a/strawberry/utils/operation.py b/strawberry/utils/operation.py index d1acaa42d8..4741e38136 100644 --- a/strawberry/utils/operation.py +++ b/strawberry/utils/operation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from graphql.language import OperationDefinitionNode @@ -12,7 +12,7 @@ def get_first_operation( graphql_document: DocumentNode, -) -> Optional[OperationDefinitionNode]: +) -> OperationDefinitionNode | None: for definition in graphql_document.definitions: if isinstance(definition, OperationDefinitionNode): return definition @@ -21,9 +21,9 @@ def get_first_operation( def get_operation_type( - graphql_document: DocumentNode, operation_name: Optional[str] = None + graphql_document: DocumentNode, operation_name: str | None = None ) -> OperationType: - definition: Optional[OperationDefinitionNode] = None + definition: OperationDefinitionNode | None = None if operation_name is not None: for d in graphql_document.definitions: diff --git a/strawberry/utils/typing.py b/strawberry/utils/typing.py index db1dcf79ce..0b1b62ff38 100644 --- a/strawberry/utils/typing.py +++ b/strawberry/utils/typing.py @@ -11,7 +11,6 @@ ClassVar, ForwardRef, Generic, - Optional, TypeGuard, TypeVar, Union, @@ -109,7 +108,7 @@ def get_optional_annotation(annotation: type) -> type: # if we have multiple non none types we want to return a copy of this # type (normally a Union type). if len(non_none_types) > 1: - return Union[non_none_types] # type: ignore + return Union[non_none_types] # type: ignore # noqa: UP007 return non_none_types[0] @@ -145,7 +144,7 @@ def is_type_var(annotation: type) -> bool: return isinstance(annotation, TypeVar) -def is_classvar(cls: type, annotation: Union[ForwardRef, str]) -> bool: +def is_classvar(cls: type, annotation: ForwardRef | str) -> bool: """Returns True if the annotation is a ClassVar.""" # This code was copied from the dataclassses cpython implementation to check # if a field is annotated with ClassVar or not, taking future annotations @@ -173,7 +172,7 @@ def type_has_annotation(type_: object, annotation: type) -> bool: return False -def get_parameters(annotation: type) -> Union[tuple[object], tuple[()]]: +def get_parameters(annotation: type) -> tuple[object] | tuple[()]: if isinstance(annotation, _GenericAlias) or ( isinstance(annotation, type) and issubclass(annotation, Generic) @@ -184,9 +183,9 @@ def get_parameters(annotation: type) -> Union[tuple[object], tuple[()]]: def _get_namespace_from_ast( - expr: Union[ast.Expr, ast.expr], - globalns: Optional[dict] = None, - localns: Optional[dict] = None, + expr: ast.Expr | ast.expr, + globalns: dict | None = None, + localns: dict | None = None, ) -> dict[str, type]: from strawberry.types.lazy_type import StrawberryLazyReference @@ -249,8 +248,8 @@ def _get_namespace_from_ast( def eval_type( type_: Any, - globalns: Optional[dict] = None, - localns: Optional[dict] = None, + globalns: dict | None = None, + localns: dict | None = None, ) -> type: """Evaluates a type, resolving forward references.""" from strawberry.parent import StrawberryParent diff --git a/tests/a.py b/tests/a.py index fc55f1361c..bf100f196a 100644 --- a/tests/a.py +++ b/tests/a.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Annotated, Optional +from typing import TYPE_CHECKING, Annotated import strawberry @@ -25,7 +25,7 @@ async def optional_b(self) -> Annotated[B, strawberry.lazy("tests.b")] | None: return B(id=self.id) @strawberry.field - async def optional_b2(self) -> Optional[Annotated[B, strawberry.lazy("tests.b")]]: + async def optional_b2(self) -> Annotated[B, strawberry.lazy("tests.b")] | None: from tests.b import B return B(id=self.id) diff --git a/tests/asgi/test_async.py b/tests/asgi/test_async.py index 81fc9d19f4..414c4799ad 100644 --- a/tests/asgi/test_async.py +++ b/tests/asgi/test_async.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pytest @@ -19,7 +19,7 @@ def test_client() -> TestClient: @strawberry.type class Query: @strawberry.field - async def hello(self, name: Optional[str] = None) -> str: + async def hello(self, name: str | None = None) -> str: return f"Hello {name or 'world'}" async_schema = strawberry.Schema(Query) diff --git a/tests/b.py b/tests/b.py index 2e9b83e1bf..b959d8d63c 100644 --- a/tests/b.py +++ b/tests/b.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Annotated, Optional +from typing import TYPE_CHECKING, Annotated import strawberry @@ -37,7 +37,7 @@ async def optional_a( @strawberry.field async def optional_a2( self, - ) -> Optional[Annotated[A, strawberry.lazy("tests.a"), object()]]: + ) -> Annotated[A, strawberry.lazy("tests.a"), object()] | None: from tests.a import A return A(id=self.id) diff --git a/tests/benchmarks/schema.py b/tests/benchmarks/schema.py index decf6b3927..96c75c4614 100644 --- a/tests/benchmarks/schema.py +++ b/tests/benchmarks/schema.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import Annotated, Union +from typing import Annotated import strawberry @@ -151,9 +151,7 @@ class CommentEdge: cursor: str -SearchResult = Annotated[ - Union[User, Post, Comment], strawberry.union(name="SearchResult") -] +SearchResult = Annotated[User | Post | Comment, strawberry.union(name="SearchResult")] @strawberry.type diff --git a/tests/benchmarks/test_generic_input.py b/tests/benchmarks/test_generic_input.py index 7be84c567a..472df7f3dd 100644 --- a/tests/benchmarks/test_generic_input.py +++ b/tests/benchmarks/test_generic_input.py @@ -1,5 +1,5 @@ import asyncio -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar from pytest_codspeed.plugin import BenchmarkFixture @@ -12,15 +12,15 @@ class GraphQLFilter(Generic[T]): """EXTERNAL Filter for GraphQL queries""" - eq: Optional[T] = None - in_: Optional[list[T]] = None - nin: Optional[list[T]] = None - gt: Optional[T] = None - gte: Optional[T] = None - lt: Optional[T] = None - lte: Optional[T] = None - contains: Optional[T] = None - icontains: Optional[T] = None + eq: T | None = None + in_: list[T] | None = None + nin: list[T] | None = None + gt: T | None = None + gte: T | None = None + lt: T | None = None + lte: T | None = None + contains: T | None = None + icontains: T | None = None @strawberry.type @@ -35,7 +35,7 @@ class Book: @strawberry.field async def authors( self, - name: Optional[GraphQLFilter[str]] = None, + name: GraphQLFilter[str] | None = None, ) -> list[Author]: return [Author(name="F. Scott Fitzgerald")] diff --git a/tests/codegen/conftest.py b/tests/codegen/conftest.py index f1cb598958..2893766e79 100644 --- a/tests/codegen/conftest.py +++ b/tests/codegen/conftest.py @@ -7,9 +7,7 @@ Annotated, Generic, NewType, - Optional, TypeVar, - Union, ) from uuid import UUID @@ -52,7 +50,7 @@ class LifeContainer(Generic[LivingThing1, LivingThing2]): items2: list[LivingThing2] -PersonOrAnimal = Annotated[Union[Person, Animal], strawberry.union("PersonOrAnimal")] +PersonOrAnimal = Annotated[Person | Animal, strawberry.union("PersonOrAnimal")] @strawberry.interface @@ -77,7 +75,7 @@ class Image(Node): @strawberry.input class PersonInput: name: str - age: Optional[int] = strawberry.UNSET + age: int | None = strawberry.UNSET @strawberry.input @@ -85,9 +83,9 @@ class ExampleInput: id: strawberry.ID name: str age: int - person: Optional[PersonInput] + person: PersonInput | None people: list[PersonInput] - optional_people: Optional[list[PersonInput]] + optional_people: list[PersonInput] | None @strawberry.type @@ -101,27 +99,27 @@ class Query: datetime: datetime.datetime time: datetime.time decimal: decimal.Decimal - optional_int: Optional[int] + optional_int: int | None list_of_int: list[int] - list_of_optional_int: list[Optional[int]] - optional_list_of_optional_int: Optional[list[Optional[int]]] + list_of_optional_int: list[int | None] + optional_list_of_optional_int: list[int | None] | None person: Person - optional_person: Optional[Person] + optional_person: Person | None list_of_people: list[Person] - optional_list_of_people: Optional[list[Person]] + optional_list_of_people: list[Person] | None enum: Color json: JSON union: PersonOrAnimal - optional_union: Optional[PersonOrAnimal] + optional_union: PersonOrAnimal | None interface: Node lazy: Annotated["LaziestType", strawberry.lazy("tests.codegen.lazy_type")] @strawberry.field - def with_inputs(self, id: Optional[strawberry.ID], input: ExampleInput) -> bool: + def with_inputs(self, id: strawberry.ID | None, input: ExampleInput) -> bool: return True @strawberry.field - def get_person_or_animal(self) -> Union[Person, Animal]: + def get_person_or_animal(self) -> Person | Animal: """Randomly get a person or an animal.""" p_or_a = random.choice([Person, Animal])() # noqa: S311 p_or_a.name = "Howard" @@ -143,7 +141,7 @@ class BlogPostInput: pi: float = 3.14159 a_bool: bool = True an_int: int = 42 - an_optional_int: Optional[int] = None + an_optional_int: int | None = None @strawberry.input diff --git a/tests/experimental/pydantic/schema/test_1_and_2.py b/tests/experimental/pydantic/schema/test_1_and_2.py index a5f4b23224..2716bdc6ba 100644 --- a/tests/experimental/pydantic/schema/test_1_and_2.py +++ b/tests/experimental/pydantic/schema/test_1_and_2.py @@ -1,6 +1,5 @@ import sys import textwrap -from typing import Optional, Union import pytest @@ -19,7 +18,7 @@ def test_can_use_both_pydantic_1_and_2(): class UserModel(pydantic.BaseModel): age: int - name: Optional[str] + name: str | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -28,7 +27,7 @@ class User: class LegacyUserModel(pydantic_v1.BaseModel): age: int - name: Optional[str] + name: str | None int_field: pydantic.v1.NonNegativeInt = 1 @strawberry.experimental.pydantic.type(LegacyUserModel) @@ -40,7 +39,7 @@ class LegacyUser: @strawberry.type class Query: @strawberry.field - def user(self, id: strawberry.ID) -> Union[User, LegacyUser]: + def user(self, id: strawberry.ID) -> User | LegacyUser: if id == "legacy": return LegacyUser(age=1, name="legacy") diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index 91e6d7317e..0e63171e74 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -1,6 +1,5 @@ import textwrap from enum import Enum -from typing import Optional, Union import pydantic @@ -11,7 +10,7 @@ def test_basic_type_field_list(): class UserModel(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -50,7 +49,7 @@ def user(self) -> User: def test_all_fields(): class UserModel(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(UserModel, all_fields=True) class User: @@ -88,7 +87,7 @@ def user(self) -> User: def test_auto_fields(): class UserModel(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None other: float @strawberry.experimental.pydantic.type(UserModel) @@ -128,7 +127,7 @@ def user(self) -> User: def test_basic_alias_type(): class UserModel(pydantic.BaseModel): age_: int = pydantic.Field(..., alias="age") - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -328,7 +327,7 @@ class BranchB(pydantic.BaseModel): field_b: int class User(pydantic.BaseModel): - union_field: Union[BranchA, BranchB] + union_field: BranchA | BranchB @strawberry.experimental.pydantic.type(BranchA) class BranchAType: @@ -366,7 +365,7 @@ class BranchB(pydantic.BaseModel): field_b: int class User(pydantic.BaseModel): - union_field: Union[BranchA, BranchB] + union_field: BranchA | BranchB @strawberry.experimental.pydantic.type(BranchA) class BranchAType: @@ -477,7 +476,7 @@ def user(self) -> UserType: def test_basic_type_with_optional_and_default(): class UserModel(pydantic.BaseModel): age: int - password: Optional[str] = pydantic.Field(default="ABC") + password: str | None = pydantic.Field(default="ABC") @strawberry.experimental.pydantic.type(UserModel, all_fields=True) class User: diff --git a/tests/experimental/pydantic/schema/test_defaults.py b/tests/experimental/pydantic/schema/test_defaults.py index be761ae87c..8f5d26125d 100644 --- a/tests/experimental/pydantic/schema/test_defaults.py +++ b/tests/experimental/pydantic/schema/test_defaults.py @@ -1,5 +1,4 @@ import textwrap -from typing import Optional import pydantic @@ -12,7 +11,7 @@ def test_field_type_default(): class User(pydantic.BaseModel): name: str = "James" - nickname: Optional[str] = "Jim" + nickname: str | None = "Jim" @strawberry.experimental.pydantic.type(User, all_fields=True) class PydanticUser: ... @@ -55,7 +54,7 @@ def b(self) -> StrawberryUser: def test_pydantic_type_default_none(): class UserPydantic(pydantic.BaseModel): - name: Optional[str] = None + name: str | None = None @strawberry.experimental.pydantic.type(UserPydantic, all_fields=True) class User: ... @@ -82,7 +81,7 @@ class Query: def test_pydantic_type_no_default_but_optional(): class UserPydantic(pydantic.BaseModel): # pydantic automatically adds a default of None for Optional fields - name: Optional[str] + name: str | None @strawberry.experimental.pydantic.type(UserPydantic, all_fields=True) class User: ... @@ -151,7 +150,7 @@ def b(self, user: StrawberryUser) -> str: @needs_pydantic_v2 def test_v2_explicit_default(): class User(pydantic.BaseModel): - name: Optional[str] + name: str | None @strawberry.experimental.pydantic.type(User, all_fields=True) class PydanticUser: ... @@ -182,7 +181,7 @@ def a(self) -> PydanticUser: def test_v2_input_with_nonscalar_default(): class NonScalarType(pydantic.BaseModel): id: int = 10 - nullable_field: Optional[int] = None + nullable_field: int | None = None class Owning(pydantic.BaseModel): non_scalar_type: NonScalarType = NonScalarType() @@ -200,7 +199,7 @@ class OwningInput: ... class ExampleOutput: owning_id: int non_scalar_id: int - non_scalar_nullable_field: Optional[int] + non_scalar_nullable_field: int | None @strawberry.type class Query: diff --git a/tests/experimental/pydantic/schema/test_federation.py b/tests/experimental/pydantic/schema/test_federation.py index db94a8e336..6bb949bf66 100644 --- a/tests/experimental/pydantic/schema/test_federation.py +++ b/tests/experimental/pydantic/schema/test_federation.py @@ -1,5 +1,3 @@ -import typing - from pydantic import BaseModel import strawberry @@ -26,7 +24,7 @@ def resolve_reference(cls, upc) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/experimental/pydantic/schema/test_forward_reference.py b/tests/experimental/pydantic/schema/test_forward_reference.py index ebc94d4b37..2543bc7392 100644 --- a/tests/experimental/pydantic/schema/test_forward_reference.py +++ b/tests/experimental/pydantic/schema/test_forward_reference.py @@ -1,7 +1,6 @@ from __future__ import annotations import textwrap -from typing import Optional import pydantic @@ -13,7 +12,7 @@ def test_auto_fields(): class UserModel(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None other: float @strawberry.experimental.pydantic.type(UserModel) diff --git a/tests/experimental/pydantic/schema/test_mutation.py b/tests/experimental/pydantic/schema/test_mutation.py index d225d75523..43c032884f 100644 --- a/tests/experimental/pydantic/schema/test_mutation.py +++ b/tests/experimental/pydantic/schema/test_mutation.py @@ -1,5 +1,3 @@ -from typing import Union - import pydantic import strawberry @@ -179,7 +177,7 @@ class Query: @strawberry.type class Mutation: @strawberry.mutation - def create_user(self, input: CreateUserInput) -> Union[UserType, UserError]: + def create_user(self, input: CreateUserInput) -> UserType | UserError: try: data = input.to_pydantic() except pydantic.ValidationError as e: diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 99fd440042..27932687d7 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -1,6 +1,6 @@ import dataclasses from enum import Enum -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any import pydantic import pytest @@ -20,7 +20,7 @@ def test_basic_type_field_list(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None with pytest.deprecated_call(): @@ -46,7 +46,7 @@ class UserType: def test_basic_type_all_fields(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(User, all_fields=True) class UserType: @@ -71,7 +71,7 @@ class UserType: def test_basic_type_all_fields_warn(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None with pytest.raises( UserWarning, @@ -86,7 +86,7 @@ class UserType: def test_basic_type_auto_fields(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None other: float @strawberry.experimental.pydantic.type(User) @@ -115,7 +115,7 @@ class OtherSentinel: class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None other: int @strawberry.experimental.pydantic.type(User) @@ -149,7 +149,7 @@ class Group(pydantic.BaseModel): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None group: Group with pytest.raises( @@ -170,7 +170,7 @@ class Group(pydantic.BaseModel): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None group: Group @strawberry.experimental.pydantic.type(Group) @@ -241,7 +241,7 @@ class Friend(pydantic.BaseModel): name: str class User(pydantic.BaseModel): - friends: Optional[list[Optional[Friend]]] + friends: list[Friend | None] | None @strawberry.experimental.pydantic.type(Friend) class FriendType: @@ -266,7 +266,7 @@ class UserType: def test_basic_type_without_fields_throws_an_error(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None with pytest.raises(MissingFieldsListError): @@ -278,7 +278,7 @@ class UserType: def test_type_with_fields_coming_from_strawberry_and_pydantic(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(User) class UserType: @@ -304,7 +304,7 @@ class UserType: def test_default_and_default_factory(): class User1(pydantic.BaseModel): - friend: Optional[str] = "friend_value" + friend: str | None = "friend_value" @strawberry.experimental.pydantic.type(User1) class UserType1: @@ -314,7 +314,7 @@ class UserType1: assert UserType1().to_pydantic().friend == "friend_value" class User2(pydantic.BaseModel): - friend: Optional[str] = None + friend: str | None = None @strawberry.experimental.pydantic.type(User2) class UserType2: @@ -326,7 +326,7 @@ class UserType2: # Test instantiation using default_factory class User3(pydantic.BaseModel): - friend: Optional[str] = pydantic.Field(default_factory=lambda: "friend_value") + friend: str | None = pydantic.Field(default_factory=lambda: "friend_value") @strawberry.experimental.pydantic.type(User3) class UserType3: @@ -336,7 +336,7 @@ class UserType3: assert UserType3().to_pydantic().friend == "friend_value" class User4(pydantic.BaseModel): - friend: Optional[str] = pydantic.Field(default_factory=lambda: None) + friend: str | None = pydantic.Field(default_factory=lambda: None) @strawberry.experimental.pydantic.type(User4) class UserType4: @@ -350,10 +350,10 @@ def test_optional_and_default(): class UserModel(pydantic.BaseModel): age: int name: str = pydantic.Field("Michael", description="The user name") - password: Optional[str] = pydantic.Field(default="ABC") - passwordtwo: Optional[str] = None - some_list: Optional[list[str]] = pydantic.Field(default_factory=list) - check: Optional[bool] = False + password: str | None = pydantic.Field(default="ABC") + passwordtwo: str | None = None + some_list: list[str] | None = pydantic.Field(default_factory=list) + check: bool | None = False @strawberry.experimental.pydantic.type(UserModel, all_fields=True) class User: @@ -432,7 +432,7 @@ class UserType: def test_type_with_fields_coming_from_strawberry_and_pydantic_with_default(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(User) class UserType: @@ -465,7 +465,7 @@ class Name: class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(User) class UserType: @@ -492,7 +492,7 @@ class UserType: def test_type_with_aliased_pydantic_field(): class UserModel(pydantic.BaseModel): age_: int = pydantic.Field(..., alias="age") - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -522,7 +522,7 @@ class BranchB(pydantic.BaseModel): class User(pydantic.BaseModel): age: int - union_field: Union[BranchA, BranchB] + union_field: BranchA | BranchB @strawberry.experimental.pydantic.type(BranchA) class BranchAType: @@ -629,7 +629,7 @@ class Work(pydantic.BaseModel): class User(pydantic.BaseModel): name: str # Note that pydantic v2 requires an explicit default of None for Optionals - work: Optional[Work] = None + work: Work | None = None class Group(pydantic.BaseModel): users: list[User] @@ -738,7 +738,7 @@ class UserType: def test_type_with_aliased_pydantic_field_changed_type(): class UserModel(pydantic.BaseModel): age_: int = pydantic.Field(..., alias="age") - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -762,7 +762,7 @@ class User: def test_deprecated_fields(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None other: float @strawberry.experimental.pydantic.type(User) @@ -797,7 +797,7 @@ def has_permission( class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None other: float @strawberry.experimental.pydantic.type(User) @@ -828,7 +828,7 @@ class Sensitive: class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None other: float @strawberry.experimental.pydantic.type(User) @@ -937,8 +937,8 @@ class UserType: def test_nested_annotated(): class User(pydantic.BaseModel): - a: Optional[Annotated[int, "metadata"]] - b: Optional[list[Annotated[int, "metadata"]]] + a: Annotated[int, "metadata"] | None + b: list[Annotated[int, "metadata"]] | None @strawberry.experimental.pydantic.input(User, all_fields=True) class UserType: diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index cdc2dd23d2..b4b486bc88 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -2,7 +2,7 @@ import dataclasses import re from enum import Enum -from typing import Any, NewType, Optional, TypeVar, Union +from typing import Any, NewType, TypeVar import pytest from pydantic import BaseModel, Field, ValidationError @@ -32,7 +32,7 @@ def test_can_use_type_standalone(): class User(BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(User) class UserType: @@ -48,7 +48,7 @@ class UserType: def test_can_convert_pydantic_type_to_strawberry(): class User(BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(User) class UserType: @@ -127,7 +127,7 @@ class UserType: def test_can_convert_alias_pydantic_field_to_strawberry(): class UserModel(BaseModel): age_: int = Field(..., alias="age") - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -144,7 +144,7 @@ class User: def test_convert_alias_name(): class UserModel(BaseModel): age_: int = Field(..., alias="age") - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type( UserModel, all_fields=True, use_pydantic_alias=True @@ -162,7 +162,7 @@ class User: ... def test_do_not_convert_alias_name(): class UserModel(BaseModel): age_: int = Field(..., alias="age") - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type( UserModel, all_fields=True, use_pydantic_alias=False @@ -180,7 +180,7 @@ class User: ... def test_can_pass_pydantic_field_description_to_strawberry(): class UserModel(BaseModel): age: int - password: Optional[str] = Field(..., description="NOT 'password'.") + password: str | None = Field(..., description="NOT 'password'.") @strawberry.experimental.pydantic.type(UserModel) class User: @@ -382,7 +382,7 @@ class BranchB(BaseModel): class User(BaseModel): age: int - union_field: Union[BranchA, BranchB] + union_field: BranchA | BranchB @strawberry.experimental.pydantic.type(BranchA) class BranchAType: @@ -423,7 +423,7 @@ class BranchB: class User(BaseModel): age: int - union_field: Union[BranchA, BranchB] + union_field: BranchA | BranchB @strawberry.experimental.pydantic.type(User) class UserType: @@ -454,7 +454,7 @@ class BranchB(BaseModel): class User(BaseModel): age: int - union_field: Union[None, BranchA, BranchB] + union_field: None | BranchA | BranchB @strawberry.experimental.pydantic.type(BranchA) class BranchAType: @@ -560,7 +560,7 @@ class UserType: def test_can_convert_pydantic_type_to_strawberry_with_additional_fields(): class UserModel(BaseModel): - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -580,7 +580,7 @@ class Work: name: str class UserModel(BaseModel): - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -600,7 +600,7 @@ class Work: name: str class UserModel(BaseModel): - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -667,7 +667,7 @@ class Work: name: strawberry.auto class UserModel(BaseModel): - work: list[Optional[WorkModel]] + work: list[WorkModel | None] @strawberry.experimental.pydantic.type(UserModel) class User: @@ -707,7 +707,7 @@ class Work: year: int class UserModel(BaseModel): - work: Optional[WorkModel] + work: WorkModel | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -724,7 +724,7 @@ class User: def test_can_convert_pydantic_type_to_strawberry_with_optional_nested_value(): class UserModel(BaseModel): - names: Optional[list[str]] + names: list[str] | None @strawberry.experimental.pydantic.type(UserModel) class User: @@ -742,7 +742,7 @@ class User: def test_can_convert_input_types_to_pydantic(): class User(BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.input(User) class UserInput: @@ -759,7 +759,7 @@ class UserInput: def test_can_convert_input_types_to_pydantic_default_values(): class User(BaseModel): age: int - password: Optional[str] = None + password: str | None = None @strawberry.experimental.pydantic.input(User) class UserInput: @@ -776,7 +776,7 @@ class UserInput: def test_can_convert_input_types_to_pydantic_default_values_defaults_declared_first(): # test that we can declare a field with a default. before a field without a default class User(BaseModel): - password: Optional[str] = None + password: str | None = None age: int @strawberry.experimental.pydantic.input(User) @@ -811,7 +811,7 @@ def test_can_convert_pydantic_type_to_strawberry_newtype(): class User(BaseModel): age: int - password: Optional[Password] + password: Password | None @strawberry.experimental.pydantic.type(User) class UserType: @@ -911,9 +911,7 @@ def test_convert_input_types_to_pydantic_default_and_default_factory(): ): class User(BaseModel): - password: Optional[str] = Field( - default=None, default_factory=lambda: None - ) + password: str | None = Field(default=None, default_factory=lambda: None) else: with pytest.raises( @@ -922,9 +920,7 @@ class User(BaseModel): ): class User(BaseModel): - password: Optional[str] = Field( - default=None, default_factory=lambda: None - ) + password: str | None = Field(default=None, default_factory=lambda: None) def test_can_convert_pydantic_type_to_strawberry_with_additional_field_resolvers(): @@ -932,7 +928,7 @@ def some_resolver() -> int: return 84 class UserModel(BaseModel): - password: Optional[str] + password: str | None new_age: int @strawberry.experimental.pydantic.type(UserModel) @@ -959,7 +955,7 @@ class Work(BaseModel): class User(BaseModel): name: str - work: Optional[Work] + work: Work | None class Group(BaseModel): users: list[User] @@ -1009,7 +1005,7 @@ class GroupOutput: def test_custom_conversion_functions(): class User(BaseModel): age: int - password: Optional[str] + password: str | None @strawberry.experimental.pydantic.type(User) class UserType: @@ -1018,7 +1014,7 @@ class UserType: @staticmethod def from_pydantic( - instance: User, extra: Optional[dict[str, Any]] = None + instance: User, extra: dict[str, Any] | None = None ) -> "UserType": return UserType( age=str(instance.age), @@ -1048,7 +1044,7 @@ def to_pydantic(self) -> User: def test_nested_custom_conversion_functions(): class User(BaseModel): age: int - password: Optional[str] + password: str | None class Parent(BaseModel): user: User @@ -1060,7 +1056,7 @@ class UserType: @staticmethod def from_pydantic( - instance: User, extra: Optional[dict[str, Any]] = None + instance: User, extra: dict[str, Any] | None = None ) -> "UserType": return UserType( age=str(instance.age), @@ -1099,7 +1095,7 @@ class Work: class User(BaseModel): age: int - password: Optional[str] + password: str | None work: Work @strawberry.experimental.pydantic.input(User) @@ -1122,7 +1118,7 @@ class Work(BaseModel): class User(BaseModel): age: int - password: Optional[str] + password: str | None work: dict[str, Work] @strawberry.experimental.pydantic.input(Work) diff --git a/tests/experimental/pydantic/test_error_type.py b/tests/experimental/pydantic/test_error_type.py index ce5893ca02..cd50cfe140 100644 --- a/tests/experimental/pydantic/test_error_type.py +++ b/tests/experimental/pydantic/test_error_type.py @@ -1,5 +1,3 @@ -from typing import Optional - import pydantic import pytest @@ -93,7 +91,7 @@ class UserError: def test_basic_type_all_fields_warn(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None with pytest.raises( UserWarning, @@ -108,7 +106,7 @@ class UserError: def test_basic_error_type_without_fields_throws_an_error(): class User(pydantic.BaseModel): age: int - password: Optional[str] + password: str | None with pytest.raises(MissingFieldsListError): @@ -220,7 +218,7 @@ class UserError: def test_error_type_with_optional_field(): class UserModel(pydantic.BaseModel): - age: Optional[int] + age: int | None @strawberry.experimental.pydantic.error_type(UserModel) class UserError: @@ -239,7 +237,7 @@ class UserError: def test_error_type_with_list_of_optional_scalar(): class UserModel(pydantic.BaseModel): - age: list[Optional[int]] + age: list[int | None] @strawberry.experimental.pydantic.error_type(UserModel) class UserError: @@ -260,7 +258,7 @@ class UserError: def test_error_type_with_optional_list_scalar(): class UserModel(pydantic.BaseModel): - age: Optional[list[int]] + age: list[int] | None @strawberry.experimental.pydantic.error_type(UserModel) class UserError: @@ -281,7 +279,7 @@ class UserError: def test_error_type_with_optional_list_of_optional_scalar(): class UserModel(pydantic.BaseModel): - age: Optional[list[Optional[int]]] + age: list[int | None] | None @strawberry.experimental.pydantic.error_type(UserModel) class UserError: @@ -309,7 +307,7 @@ class FriendError: name: strawberry.auto class UserModel(pydantic.BaseModel): - friends: Optional[list[FriendModel]] + friends: list[FriendModel] | None @strawberry.experimental.pydantic.error_type(UserModel) class UserError: diff --git a/tests/fastapi/app.py b/tests/fastapi/app.py index 26f0d7eb61..17b0864125 100644 --- a/tests/fastapi/app.py +++ b/tests/fastapi/app.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any from fastapi import BackgroundTasks, Depends, FastAPI, Request, WebSocket from strawberry.fastapi import GraphQLRouter @@ -24,7 +24,7 @@ async def get_context( async def get_root_value( request: Request = None, ws: WebSocket = None -) -> Union[Request, WebSocket]: +) -> Request | WebSocket: return request or ws diff --git a/tests/fastapi/test_async.py b/tests/fastapi/test_async.py index aa134634bf..8c867631fe 100644 --- a/tests/fastapi/test_async.py +++ b/tests/fastapi/test_async.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pytest @@ -19,7 +19,7 @@ def test_client() -> TestClient: @strawberry.type class Query: @strawberry.field - async def hello(self, name: Optional[str] = None) -> str: + async def hello(self, name: str | None = None) -> str: return f"Hello {name or 'world'}" async_schema = strawberry.Schema(Query) diff --git a/tests/federation/printer/test_inaccessible.py b/tests/federation/printer/test_inaccessible.py index 01fc8d9eda..6c859a4c69 100644 --- a/tests/federation/printer/test_inaccessible.py +++ b/tests/federation/printer/test_inaccessible.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Annotated, Union +from typing import Annotated import strawberry @@ -249,9 +249,7 @@ class A: class B: b: str - MyUnion = Annotated[ - Union[A, B], strawberry.federation.union("Union", inaccessible=True) - ] + MyUnion = Annotated[A | B, strawberry.federation.union("Union", inaccessible=True)] @strawberry.federation.type class Query: diff --git a/tests/federation/printer/test_one_of.py b/tests/federation/printer/test_one_of.py index 5de1938007..4e39de1a4c 100644 --- a/tests/federation/printer/test_one_of.py +++ b/tests/federation/printer/test_one_of.py @@ -1,5 +1,4 @@ import textwrap -from typing import Optional import strawberry @@ -7,8 +6,8 @@ def test_prints_one_of_directive(): @strawberry.federation.input(one_of=True, tags=["myTag", "anotherTag"]) class Input: - a: Optional[str] = strawberry.UNSET - b: Optional[int] = strawberry.UNSET + a: str | None = strawberry.UNSET + b: int | None = strawberry.UNSET @strawberry.federation.type class Query: diff --git a/tests/federation/printer/test_tag.py b/tests/federation/printer/test_tag.py index 69e87fb342..a9e90fa4ab 100644 --- a/tests/federation/printer/test_tag.py +++ b/tests/federation/printer/test_tag.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Annotated, Union +from typing import Annotated import strawberry @@ -168,7 +168,7 @@ class B: b: str MyUnion = Annotated[ - Union[A, B], strawberry.federation.union("Union", tags=["myTag", "anotherTag"]) + A | B, strawberry.federation.union("Union", tags=["myTag", "anotherTag"]) ] @strawberry.federation.type diff --git a/tests/federation/test_entities.py b/tests/federation/test_entities.py index 73d9d1d102..51aabf83f9 100644 --- a/tests/federation/test_entities.py +++ b/tests/federation/test_entities.py @@ -1,5 +1,3 @@ -import typing - from graphql import located_error import strawberry @@ -18,7 +16,7 @@ def resolve_reference(cls, upc) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -58,7 +56,7 @@ def resolve_reference(cls, info: strawberry.Info, upc: str) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -102,7 +100,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -142,7 +140,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -189,7 +187,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -232,7 +230,7 @@ class Product: @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -280,7 +278,7 @@ def resolve_reference(cls, id: strawberry.ID) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def mock(self) -> typing.Optional[Product]: + def mock(self) -> Product | None: return None schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -330,7 +328,7 @@ def resolve_reference(cls, id: strawberry.ID) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def mock(self) -> typing.Optional[Product]: + def mock(self) -> Product | None: return None schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -383,7 +381,7 @@ def resolve_reference(cls, info: Info, id: strawberry.ID) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def mock(self) -> typing.Optional[Product]: + def mock(self) -> Product | None: return None schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -430,7 +428,7 @@ async def resolve_reference(cls, upc: str) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) @@ -469,7 +467,7 @@ async def resolve_reference(cls, upc: str) -> "Product": @strawberry.federation.type(extend=True) class Query: @strawberry.field - def top_products(self, first: int) -> typing.List[Product]: # pragma: no cover + def top_products(self, first: int) -> list[Product]: # pragma: no cover return [] schema = strawberry.federation.Schema(query=Query, enable_federation_2=True) diff --git a/tests/federation/test_schema.py b/tests/federation/test_schema.py index f50d37eace..083ae4c296 100644 --- a/tests/federation/test_schema.py +++ b/tests/federation/test_schema.py @@ -1,6 +1,6 @@ import textwrap import warnings -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar import pytest @@ -11,9 +11,9 @@ def test_entities_type_when_no_type_has_keys(): @strawberry.federation.type() class Product: upc: str - name: Optional[str] - price: Optional[int] - weight: Optional[int] + name: str | None + price: int | None + weight: int | None @strawberry.federation.type(extend=True) class Query: @@ -67,9 +67,9 @@ def test_entities_type(): @strawberry.federation.type(keys=["upc"]) class Product: upc: str - name: Optional[str] - price: Optional[int] - weight: Optional[int] + name: str | None + price: int | None + weight: int | None @strawberry.federation.type(extend=True) class Query: @@ -290,9 +290,9 @@ def test_can_create_schema_without_query(): @strawberry.federation.type() class Product: upc: str - name: Optional[str] - price: Optional[int] - weight: Optional[int] + name: str | None + price: int | None + weight: int | None schema = strawberry.federation.Schema(types=[Product], enable_federation_2=True) @@ -325,9 +325,9 @@ def test_federation_schema_warning(): @strawberry.federation.type(keys=["upc"]) class ProductFed: upc: str - name: Optional[str] - price: Optional[int] - weight: Optional[int] + name: str | None + price: int | None + weight: int | None with pytest.warns(UserWarning) as record: # noqa: PT030 strawberry.Schema( @@ -345,9 +345,9 @@ def test_does_not_warn_when_using_federation_schema(): @strawberry.federation.type(keys=["upc"]) class ProductFed: upc: str - name: Optional[str] - price: Optional[int] - weight: Optional[int] + name: str | None + price: int | None + weight: int | None @strawberry.type class Query: diff --git a/tests/fields/test_arguments.py b/tests/fields/test_arguments.py index 5e7959b789..f538ac6e76 100644 --- a/tests/fields/test_arguments.py +++ b/tests/fields/test_arguments.py @@ -1,4 +1,4 @@ -from typing import Annotated, Optional, Union +from typing import Annotated import pytest @@ -16,7 +16,7 @@ def test_basic_arguments(): class Query: @strawberry.field def name( - self, argument: str, optional_argument: Optional[str] + self, argument: str, optional_argument: str | None ) -> str: # pragma: no cover return "Name" @@ -45,7 +45,7 @@ class Input: class Query: @strawberry.field def name( - self, input: Input, optional_input: Optional[Input] + self, input: Input, optional_input: Input | None ) -> str: # pragma: no cover return input.name @@ -96,7 +96,7 @@ class Input: @strawberry.type class Query: @strawberry.field - def names(self, inputs: list[Optional[Input]]) -> list[str]: # pragma: no cover + def names(self, inputs: list[Input | None]) -> list[str]: # pragma: no cover return [input_.name for input_ in inputs if input_ is not None] definition = Query.__strawberry_definition__ @@ -114,7 +114,7 @@ def names(self, inputs: list[Optional[Input]]) -> list[str]: # pragma: no cover def test_basic_arguments_on_resolver(): def name_resolver( # pragma: no cover - id: strawberry.ID, argument: str, optional_argument: Optional[str] + id: strawberry.ID, argument: str, optional_argument: str | None ) -> str: return "Name" @@ -141,7 +141,7 @@ class Query: def test_arguments_when_extending_a_type(): def name_resolver( - id: strawberry.ID, argument: str, optional_argument: Optional[str] + id: strawberry.ID, argument: str, optional_argument: str | None ) -> str: # pragma: no cover return "Name" @@ -214,7 +214,7 @@ def test_argument_with_default_value_none(): @strawberry.type class Query: @strawberry.field - def name(self, argument: Optional[str] = None) -> str: # pragma: no cover + def name(self, argument: str | None = None) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -235,7 +235,7 @@ def test_argument_with_default_value_undefined(): @strawberry.type class Query: @strawberry.field - def name(self, argument: Optional[str]) -> str: # pragma: no cover + def name(self, argument: str | None) -> str: # pragma: no cover return "Name" definition = Query.__strawberry_definition__ @@ -282,7 +282,7 @@ class Query: @strawberry.field def name( # type: ignore argument: Annotated[ - Optional[str], + str | None, strawberry.argument(description="This is a description"), ], ) -> str: # pragma: no cover @@ -438,7 +438,7 @@ class Noun: class Verb: text: str - Word = Annotated[Union[Noun, Verb], strawberry.argument("Word")] + Word = Annotated[Noun | Verb, strawberry.argument("Word")] @strawberry.field def add_word(word: Word) -> bool: diff --git a/tests/fixtures/sample_package/sample_module.py b/tests/fixtures/sample_package/sample_module.py index 44574467fc..0de153ab0b 100644 --- a/tests/fixtures/sample_package/sample_module.py +++ b/tests/fixtures/sample_package/sample_module.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Annotated, NewType, Union +from typing import Annotated, NewType import strawberry @@ -20,7 +20,7 @@ class B: a: A -UnionExample = Annotated[Union[A, B], strawberry.union("UnionExample")] +UnionExample = Annotated[A | B, strawberry.union("UnionExample")] class SampleClass: @@ -41,7 +41,7 @@ class User: role: Role example_scalar: ExampleScalar union_example: UnionExample - inline_union: Annotated[Union[A, B], strawberry.union("InlineUnion")] + inline_union: Annotated[A | B, strawberry.union("InlineUnion")] @strawberry.type diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index e4865804dd..1331230274 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator, Mapping, Sequence from datetime import timedelta from io import BytesIO -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from aiohttp import web from aiohttp.client_ws import ClientWebSocketResponse @@ -39,7 +39,7 @@ class GraphQLView(OnWSConnectMixin, BaseGraphQLView[dict[str, object], object]): graphql_ws_handler_class = DebuggableGraphQLWSHandler async def get_context( - self, request: web.Request, response: Union[web.Response, web.WebSocketResponse] + self, request: web.Request, response: web.Response | web.WebSocketResponse ) -> dict[str, object]: context = await super().get_context(request, response) @@ -62,8 +62,8 @@ class AioHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, @@ -94,12 +94,12 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: async with TestClient(TestServer(self.app)) as client: @@ -136,7 +136,7 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: async with TestClient(TestServer(self.app)) as client: response = await getattr(client, method)(url, headers=headers) @@ -150,16 +150,16 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "get", headers=headers) async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: async with TestClient(TestServer(self.app)) as client: response = await client.post( @@ -189,7 +189,7 @@ async def ws_connect( class AioWebSocketClient(WebSocketClient): def __init__(self, ws: ClientWebSocketResponse): self.ws = ws - self._reason: Optional[str] = None + self._reason: str | None = None async def send_text(self, payload: str) -> None: await self.ws.send_str(payload) @@ -200,12 +200,12 @@ async def send_json(self, payload: Mapping[str, object]) -> None: async def send_bytes(self, payload: bytes) -> None: await self.ws.send_bytes(payload) - async def receive(self, timeout: Optional[float] = None) -> Message: + async def receive(self, timeout: float | None = None) -> Message: m = await self.ws.receive(timeout) self._reason = m.extra return Message(type=m.type, data=m.data, extra=m.extra) - async def receive_json(self, timeout: Optional[float] = None) -> object: + async def receive_json(self, timeout: float | None = None) -> object: m = await self.ws.receive(timeout) assert m.type == WSMsgType.TEXT return json.loads(m.data) @@ -214,7 +214,7 @@ async def close(self) -> None: await self.ws.close() @property - def accepted_subprotocol(self) -> Optional[str]: + def accepted_subprotocol(self) -> str | None: return self.ws.protocol @property @@ -227,5 +227,5 @@ def close_code(self) -> int: return self.ws.close_code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: return self._reason diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 9f676691f2..cc0446e1c5 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator, Mapping, Sequence from datetime import timedelta from io import BytesIO -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from starlette.requests import Request from starlette.responses import Response as StarletteResponse @@ -42,13 +42,13 @@ class GraphQLView(OnWSConnectMixin, BaseGraphQLView[dict[str, object], object]): graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler - async def get_root_value(self, request: Union[WebSocket, Request]) -> Query: + async def get_root_value(self, request: WebSocket | Request) -> Query: return Query() async def get_context( self, - request: Union[Request, WebSocket], - response: Union[StarletteResponse, WebSocket], + request: Request | WebSocket, + response: StarletteResponse | WebSocket, ) -> dict[str, object]: context = await super().get_context(request, response) @@ -67,8 +67,8 @@ class AsgiHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, @@ -98,12 +98,12 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -142,7 +142,7 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: response = getattr(self.client, method)(url, headers=headers) @@ -155,16 +155,16 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "get", headers=headers) async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: response = self.client.post(url, headers=headers, content=data, json=json) @@ -189,8 +189,8 @@ class AsgiWebSocketClient(WebSocketClient): def __init__(self, ws: WebSocketTestSession): self.ws = ws self._closed: bool = False - self._close_code: Optional[int] = None - self._close_reason: Optional[str] = None + self._close_code: int | None = None + self._close_reason: str | None = None async def send_text(self, payload: str) -> None: self.ws.send_text(payload) @@ -201,7 +201,7 @@ async def send_json(self, payload: Mapping[str, object]) -> None: async def send_bytes(self, payload: bytes) -> None: self.ws.send_bytes(payload) - async def receive(self, timeout: Optional[float] = None) -> Message: + async def receive(self, timeout: float | None = None) -> Message: if self._closed: # if close was received via exception, fake it so that recv works return Message( @@ -217,7 +217,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: return Message(type=m["type"], data=m["text"]) return Message(type=m["type"], data=m["data"], extra=m["extra"]) - async def receive_json(self, timeout: Optional[float] = None) -> Any: + async def receive_json(self, timeout: float | None = None) -> Any: m = self.ws.receive() assert m["type"] == "websocket.send" assert "text" in m @@ -228,7 +228,7 @@ async def close(self) -> None: self._closed = True @property - def accepted_subprotocol(self) -> Optional[str]: + def accepted_subprotocol(self) -> str | None: return self.ws.accepted_subprotocol @property @@ -241,5 +241,5 @@ def close_code(self) -> int: return self._close_code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: return self._close_reason diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index 5932f7664e..c263e36577 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import AsyncIterable -from typing import Optional from django.core.exceptions import BadRequest, SuspiciousOperation from django.http import Http404, HttpRequest, HttpResponse, StreamingHttpResponse @@ -45,8 +44,8 @@ class AsyncDjangoHttpClient(DjangoHttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, diff --git a/tests/http/clients/async_flask.py b/tests/http/clients/async_flask.py index a282faf151..06ddd679c2 100644 --- a/tests/http/clients/async_flask.py +++ b/tests/http/clients/async_flask.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from flask import Flask from flask import Request as FlaskRequest @@ -50,8 +50,8 @@ class AsyncFlaskHttpClient(FlaskHttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index e74f722ae1..cc24df0fbc 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -31,14 +31,14 @@ @dataclass class Response: status_code: int - data: Union[bytes, AsyncIterable[bytes]] + data: bytes | AsyncIterable[bytes] def __init__( self, status_code: int, - data: Union[bytes, AsyncIterable[bytes]], + data: bytes | AsyncIterable[bytes], *, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> None: self.status_code = status_code self.data = data @@ -66,7 +66,7 @@ async def streaming_json(self) -> AsyncIterable[JSON]: if not self.is_multipart: raise ValueError("Streaming not supported") - def parse_chunk(text: str) -> Union[JSON, None]: + def parse_chunk(text: str) -> JSON | None: # TODO: better parsing? :) with contextlib.suppress(json.JSONDecodeError): return json.loads(text) @@ -99,8 +99,8 @@ class HttpClient(abc.ABC): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, @@ -114,12 +114,12 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: ... @@ -128,34 +128,34 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: ... @abc.abstractmethod async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: ... @abc.abstractmethod async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: ... async def query( self, query: str, method: Literal["get", "post"] = "post", - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, ) -> Response: return await self._graphql_request( method, @@ -170,8 +170,8 @@ async def query( def _get_headers( self, method: Literal["get", "post"], - headers: Optional[dict[str, str]], - files: Optional[dict[str, BytesIO]], + headers: dict[str, str] | None, + files: dict[str, BytesIO] | None, ) -> dict[str, str]: additional_headers = {} headers = headers or {} @@ -188,13 +188,13 @@ def _get_headers( def _build_body( self, - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, method: Literal["get", "post"] = "post", - extensions: Optional[dict[str, Any]] = None, - ) -> Optional[dict[str, object]]: + extensions: dict[str, Any] | None = None, + ) -> dict[str, object] | None: if query is None: assert files is None assert variables is None @@ -271,7 +271,7 @@ def ws_connect( class Message: type: Any data: Any - extra: Optional[str] = None + extra: str | None = None def json(self) -> Any: return json.loads(self.data) @@ -291,17 +291,17 @@ async def send_json(self, payload: Mapping[str, object]) -> None: ... async def send_bytes(self, payload: bytes) -> None: ... @abc.abstractmethod - async def receive(self, timeout: Optional[float] = None) -> Message: ... + async def receive(self, timeout: float | None = None) -> Message: ... @abc.abstractmethod - async def receive_json(self, timeout: Optional[float] = None) -> Any: ... + async def receive_json(self, timeout: float | None = None) -> Any: ... @abc.abstractmethod async def close(self) -> None: ... @property @abc.abstractmethod - def accepted_subprotocol(self) -> Optional[str]: ... + def accepted_subprotocol(self) -> str | None: ... @property @abc.abstractmethod @@ -313,7 +313,7 @@ def close_code(self) -> int: ... @property @abc.abstractmethod - def close_reason(self) -> Optional[str]: ... + def close_reason(self) -> str | None: ... async def __aiter__(self) -> AsyncGenerator[Message, None]: while not self.closed: diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index d98868c655..ef3e7176fc 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -3,7 +3,7 @@ import urllib.parse from io import BytesIO from json import dumps -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from chalice.app import Chalice from chalice.app import Request as ChaliceRequest @@ -48,12 +48,12 @@ class ChaliceHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - schema_config: Optional[StrawberryConfig] = None, + schema_config: StrawberryConfig | None = None, ): self.app = Chalice(app_name="TheStackBadger") @@ -75,12 +75,12 @@ def handle_graphql(): async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -92,7 +92,7 @@ async def _graphql_request( extensions=extensions, ) - data: Union[dict[str, object], str, None] = None + data: dict[str, object] | str | None = None if body and files: body.update({name: (file, name) for name, file in files.items()}) @@ -124,7 +124,7 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: with Client(self.app) as client: response = getattr(client.http, method)(url, headers=headers) @@ -138,16 +138,16 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "get", headers=headers) async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: body = dumps(json) if json is not None else data diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 62c1957327..302e05d64e 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator, Mapping, Sequence from datetime import timedelta from io import BytesIO -from typing import Any, Literal, Optional +from typing import Any, Literal from urllib3 import encode_multipart_formdata @@ -44,8 +44,8 @@ def generate_get_path( path: str, query: str, - variables: Optional[dict[str, Any]] = None, - extensions: Optional[dict[str, Any]] = None, + variables: dict[str, Any] | None = None, + extensions: dict[str, Any] | None = None, ) -> str: body: dict[str, Any] = {"query": query} if variables is not None: @@ -148,8 +148,8 @@ class ChannelsHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, @@ -181,12 +181,12 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -219,7 +219,7 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, body: bytes = b"", ) -> Response: # HttpCommunicator expects tuples of bytestrings @@ -245,16 +245,16 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "get", headers=headers) async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: body = b"" if data is not None: @@ -287,8 +287,8 @@ class SyncChannelsHttpClient(ChannelsHttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, @@ -304,13 +304,11 @@ def __init__( class ChannelsWebSocketClient(WebSocketClient): - def __init__( - self, client: WebsocketCommunicator, accepted_subprotocol: Optional[str] - ): + def __init__(self, client: WebsocketCommunicator, accepted_subprotocol: str | None): self.ws = client self._closed: bool = False - self._close_code: Optional[int] = None - self._close_reason: Optional[str] = None + self._close_code: int | None = None + self._close_reason: str | None = None self._accepted_subprotocol = accepted_subprotocol def name(self) -> str: @@ -325,7 +323,7 @@ async def send_json(self, payload: Mapping[str, object]) -> None: async def send_bytes(self, payload: bytes) -> None: await self.ws.send_to(bytes_data=payload) - async def receive(self, timeout: Optional[float] = None) -> Message: + async def receive(self, timeout: float | None = None) -> Message: m = await self.ws.receive_output(timeout=timeout) # type: ignore if m["type"] == "websocket.close": self._closed = True @@ -336,7 +334,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: return Message(type=m["type"], data=m["text"]) return Message(type=m["type"], data=m["data"], extra=m["extra"]) - async def receive_json(self, timeout: Optional[float] = None) -> Any: + async def receive_json(self, timeout: float | None = None) -> Any: m = await self.ws.receive_output(timeout=timeout) # type: ignore assert m["type"] == "websocket.send" assert "text" in m @@ -347,7 +345,7 @@ async def close(self) -> None: self._closed = True @property - def accepted_subprotocol(self) -> Optional[str]: + def accepted_subprotocol(self) -> str | None: return self._accepted_subprotocol @property @@ -360,5 +358,5 @@ def close_code(self) -> int: return self._close_code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: return self._close_reason diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index 9bc694e4ed..4c70e7df28 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -2,7 +2,7 @@ from io import BytesIO from json import dumps -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from django.core.exceptions import BadRequest, SuspiciousOperation from django.core.files.uploadedfile import SimpleUploadedFile @@ -47,8 +47,8 @@ class DjangoHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, @@ -71,8 +71,8 @@ def _to_django_headers(self, headers: dict[str, str]) -> dict[str, str]: def _get_headers( self, method: Literal["get", "post"], - headers: Optional[dict[str, str]], - files: Optional[dict[str, BytesIO]], + headers: dict[str, str] | None, + files: dict[str, BytesIO] | None, ) -> dict[str, str]: headers = headers or {} headers = self._to_django_headers(headers) @@ -96,12 +96,12 @@ async def _do_request(self, request: HttpRequest) -> Response: async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: headers = self._get_headers(method=method, headers=headers, files=files) @@ -116,7 +116,7 @@ async def _graphql_request( extensions=extensions, ) - data: Union[dict[str, object], str, None] = None + data: dict[str, object] | str | None = None if body and files: body.update( @@ -144,7 +144,7 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: headers = headers or {} @@ -156,7 +156,7 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: django_headers = self._to_django_headers(headers or {}) return await self.request(url, "get", headers=django_headers) @@ -164,9 +164,9 @@ async def get( async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: headers = headers or {} content_type = headers.pop("Content-Type", "") diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 4db2dde30a..47b20b6e31 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator, Sequence from datetime import timedelta from io import BytesIO -from typing import Any, Literal, Optional +from typing import Any, Literal from fastapi import BackgroundTasks, Depends, FastAPI, Request, WebSocket from fastapi.testclient import TestClient @@ -77,8 +77,8 @@ class FastAPIHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, @@ -121,12 +121,12 @@ async def _handle_response(self, response: Any) -> Response: async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -161,7 +161,7 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: response = getattr(self.client, method)(url, headers=headers) @@ -170,16 +170,16 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "get", headers=headers) async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: response = self.client.post(url, headers=headers, content=data, json=json) diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index bf47741bda..1b5f730e09 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -6,7 +6,7 @@ import json import urllib.parse from io import BytesIO -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from flask import Flask from flask import Request as FlaskRequest @@ -58,8 +58,8 @@ class FlaskHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, @@ -85,12 +85,12 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -102,7 +102,7 @@ async def _graphql_request( extensions=extensions, ) - data: Union[dict[str, object], str, None] = None + data: dict[str, object] | str | None = None if body and files: body.update({name: (file, name) for name, file in files.items()}) @@ -125,7 +125,7 @@ def _do_request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, **kwargs: Any, ): with self.app.test_client() as client: @@ -141,7 +141,7 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, **kwargs: Any, ) -> Response: loop = asyncio.get_running_loop() @@ -154,15 +154,15 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "get", headers=headers) async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "post", headers=headers, data=data, json=json) diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 4a13e24ee3..583bf318d3 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator, Mapping, Sequence from datetime import timedelta from io import BytesIO -from typing import Any, Literal, Optional +from typing import Any, Literal from litestar import Litestar, Request from litestar.exceptions import WebSocketDisconnect @@ -52,8 +52,8 @@ class LitestarHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, @@ -99,12 +99,12 @@ async def process_result( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: if body := self._build_body( @@ -141,7 +141,7 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: response = getattr(self.client, method)(url, headers=headers) @@ -154,16 +154,16 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "get", headers=headers) async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: response = self.client.post(url, headers=headers, content=data, json=json) @@ -188,8 +188,8 @@ class LitestarWebSocketClient(WebSocketClient): def __init__(self, ws: WebSocketTestSession): self.ws = ws self._closed: bool = False - self._close_code: Optional[int] = None - self._close_reason: Optional[str] = None + self._close_code: int | None = None + self._close_reason: str | None = None async def send_text(self, payload: str) -> None: self.ws.send_text(payload) @@ -200,7 +200,7 @@ async def send_json(self, payload: Mapping[str, object]) -> None: async def send_bytes(self, payload: bytes) -> None: self.ws.send_bytes(payload) - async def receive(self, timeout: Optional[float] = None) -> Message: + async def receive(self, timeout: float | None = None) -> Message: if self._closed: # if close was received via exception, fake it so that recv works return Message( @@ -225,7 +225,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: assert "data" in m return Message(type=m["type"], data=m["data"], extra=m["extra"]) - async def receive_json(self, timeout: Optional[float] = None) -> Any: + async def receive_json(self, timeout: float | None = None) -> Any: m = self.ws.receive() assert m["type"] == "websocket.send" assert "text" in m @@ -237,7 +237,7 @@ async def close(self) -> None: self._closed = True @property - def accepted_subprotocol(self) -> Optional[str]: + def accepted_subprotocol(self) -> str | None: return self.ws.accepted_subprotocol @property @@ -250,5 +250,5 @@ def close_code(self) -> int: return self._close_code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: return self._close_reason diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 2164d7ec11..27f6b2fc4f 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -4,7 +4,7 @@ from collections.abc import AsyncGenerator, Sequence from datetime import timedelta from io import BytesIO -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send @@ -49,14 +49,12 @@ def __init__(self, *args: Any, **kwargs: Any): self.result_override = kwargs.pop("result_override", None) super().__init__(*args, **kwargs) - async def get_root_value( - self, request: Union[QuartRequest, QuartWebsocket] - ) -> Query: + async def get_root_value(self, request: QuartRequest | QuartWebsocket) -> Query: await super().get_root_value(request) # for coverage return Query() async def get_context( - self, request: Union[QuartRequest, QuartWebsocket], response: QuartResponse + self, request: QuartRequest | QuartWebsocket, response: QuartResponse ) -> dict[str, object]: context = await super().get_context(request, response) @@ -90,8 +88,8 @@ class QuartHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, keep_alive: bool = False, keep_alive_interval: float = 1, @@ -137,12 +135,12 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -176,7 +174,7 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, **kwargs: Any, ) -> Response: async with self.app.test_app() as test_app, self.app.app_context(): @@ -192,16 +190,16 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "get", headers=headers) async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: kwargs = {"headers": headers, "data": data, "json": json} return await self.request( diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index 4ea89a6955..8d0281476d 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -3,7 +3,7 @@ from io import BytesIO from json import dumps from random import randint -from typing import Any, Literal, Optional +from typing import Any, Literal from sanic import Sanic from sanic.request import Request as SanicRequest @@ -50,8 +50,8 @@ class SanicHttpClient(HttpClient): def __init__( self, schema: Schema, - graphiql: Optional[bool] = None, - graphql_ide: Optional[GraphQL_IDE] = "graphiql", + graphiql: bool | None = None, + graphql_ide: GraphQL_IDE | None = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, @@ -72,12 +72,12 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, - operation_name: Optional[str] = None, - variables: Optional[dict[str, object]] = None, - files: Optional[dict[str, BytesIO]] = None, - headers: Optional[dict[str, str]] = None, - extensions: Optional[dict[str, Any]] = None, + query: str | None = None, + operation_name: str | None = None, + variables: dict[str, object] | None = None, + files: dict[str, BytesIO] | None = None, + headers: dict[str, str] | None = None, + extensions: dict[str, Any] | None = None, **kwargs: Any, ) -> Response: body = self._build_body( @@ -115,7 +115,7 @@ async def request( self, url: str, method: Literal["head", "get", "post", "patch", "put", "delete"], - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: request, response = await self.app.asgi_client.request( method, @@ -132,16 +132,16 @@ async def request( async def get( self, url: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ) -> Response: return await self.request(url, "get", headers=headers) async def post( self, url: str, - data: Optional[bytes] = None, - json: Optional[JSON] = None, - headers: Optional[dict[str, str]] = None, + data: bytes | None = None, + json: JSON | None = None, + headers: dict[str, str] | None = None, ) -> Response: body = dumps(json) if json is not None else data diff --git a/tests/http/test_graphql_ide.py b/tests/http/test_graphql_ide.py index b4264141fd..0db87db7e1 100644 --- a/tests/http/test_graphql_ide.py +++ b/tests/http/test_graphql_ide.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Literal import pytest @@ -71,7 +71,7 @@ async def test_does_not_render_graphiql_if_wrong_accept( @pytest.mark.parametrize("graphql_ide", [False, None]) async def test_renders_graphiql_disabled( http_client_class: type[HttpClient], - graphql_ide: Union[bool, None], + graphql_ide: bool | None, ): http_client = http_client_class(schema, graphql_ide=graphql_ide) response = await http_client.get("/graphql", headers={"Accept": "text/html"}) diff --git a/tests/objects/generics/test_generic_objects.py b/tests/objects/generics/test_generic_objects.py index b1c0a0e32d..f0b23c30e8 100644 --- a/tests/objects/generics/test_generic_objects.py +++ b/tests/objects/generics/test_generic_objects.py @@ -1,5 +1,5 @@ import datetime -from typing import Annotated, Generic, Optional, TypeVar, Union +from typing import Annotated, Generic, TypeVar import pytest @@ -140,8 +140,8 @@ class Value(Generic[T]): class Foo: string: Value[str] strings: Value[list[str]] - optional_string: Value[Optional[str]] - optional_strings: Value[Optional[list[str]]] + optional_string: Value[str | None] + optional_strings: Value[list[str] | None] definition = get_object_definition(Foo, strict=True) assert not definition.is_graphql_generic @@ -161,7 +161,7 @@ class Foo: def test_generic_with_optional(): @strawberry.type class Edge(Generic[T]): - node: Optional[T] + node: T | None definition = get_object_definition(Edge, strict=True) assert definition.is_graphql_generic @@ -219,7 +219,7 @@ class Connection(Generic[T]): def test_generic_with_list_of_optionals(): @strawberry.type class Connection(Generic[T]): - edges: list[Optional[T]] + edges: list[T | None] definition = get_object_definition(Connection, strict=True) assert definition.is_graphql_generic @@ -254,7 +254,7 @@ class Error: @strawberry.type class Edge(Generic[T]): - node: Union[Error, T] + node: Error | T definition = get_object_definition(Edge, strict=True) assert definition.type_params == [T] @@ -400,7 +400,7 @@ class Edge(Generic[T]): @strawberry.type class Query: - user: Optional[Edge[str]] + user: Edge[str] | None query_definition = get_object_definition(Query, strict=True) assert query_definition.type_params == [] @@ -444,7 +444,7 @@ class Edge(Generic[T]): @strawberry.type class Query: - user: Union[Edge[str], Error] + user: Edge[str] | Error query_definition = get_object_definition(Query, strict=True) assert query_definition.type_params == [] @@ -465,7 +465,7 @@ class Edge(Generic[T]): @strawberry.type class Query: - user: Union[Edge[int], Edge[str]] + user: Edge[int] | Edge[str] query_definition = get_object_definition(Query, strict=True) assert query_definition.type_params == [] @@ -499,7 +499,7 @@ class Cat: class Connection(Generic[T]): nodes: list[T] - DogCat = Annotated[Union[Dog, Cat], strawberry.union("DogCat")] + DogCat = Annotated[Dog | Cat, strawberry.union("DogCat")] @strawberry.type class Query: @@ -538,7 +538,7 @@ class Connection(Generic[T]): @strawberry.type class Query: - connection: Connection[Union[Dog, Cat]] + connection: Connection[Dog | Cat] definition = get_object_definition(Query, strict=True) assert definition.type_params == [] diff --git a/tests/objects/test_inheritance.py b/tests/objects/test_inheritance.py index ae63b11223..1db6a3d87f 100644 --- a/tests/objects/test_inheritance.py +++ b/tests/objects/test_inheritance.py @@ -1,5 +1,3 @@ -from typing import Optional - import strawberry @@ -10,6 +8,6 @@ class A: @strawberry.type class B(A): - b: Optional[str] = strawberry.field(default=None) + b: str | None = strawberry.field(default=None) assert strawberry.Schema(query=B) diff --git a/tests/relay/schema.py b/tests/relay/schema.py index 16236f0f03..3d08760141 100644 --- a/tests/relay/schema.py +++ b/tests/relay/schema.py @@ -11,7 +11,6 @@ Annotated, Any, NamedTuple, - Optional, TypeAlias, cast, ) @@ -37,7 +36,7 @@ def resolve_nodes( info: strawberry.Info, node_ids: Iterable[str], required: bool = False, - ) -> Iterable[Optional[Self]]: + ) -> Iterable[Self | None]: if node_ids is not None: return [fruits[nid] if required else fruits.get(nid) for nid in node_ids] @@ -67,10 +66,10 @@ class FruitAsync(relay.Node): async def resolve_nodes( cls, *, - info: Optional[Info] = None, + info: Info | None = None, node_ids: Iterable[str], required: bool = False, - ) -> Iterable[Optional[Self]]: + ) -> Iterable[Self | None]: if node_ids is not None: return [ fruits_async[nid] if required else fruits_async.get(nid) @@ -95,12 +94,12 @@ def resolve_connection( cls, nodes: Iterable[Fruit], *, - info: Optional[Info] = None, - total_count: Optional[int] = None, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, + info: Info | None = None, + total_count: int | None = None, + before: str | None = None, + after: str | None = None, + first: int | None = None, + last: int | None = None, **kwargs: Any, ) -> Self: edges_mapping = { @@ -204,8 +203,8 @@ class Query: permission_classes=[DummyPermission] ) nodes: list[relay.Node] = relay.node() - node_optional: Optional[relay.Node] = relay.node() - nodes_optional: list[Optional[relay.Node]] = relay.node() + node_optional: relay.Node | None = relay.node() + nodes_optional: list[relay.Node | None] = relay.node() fruits: relay.ListConnection[Fruit] = relay.connection(resolver=fruits_resolver) fruits_lazy: relay.ListConnection[ Annotated["Fruit", strawberry.lazy("tests.relay.schema")] @@ -226,7 +225,7 @@ class Query: def fruits_concrete_resolver( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> list[Fruit]: # This is mimicing integrations, like Django return [ @@ -246,7 +245,7 @@ def fruits_concrete_resolver( def fruits_custom_resolver( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> list[Fruit]: return [ f @@ -258,7 +257,7 @@ def fruits_custom_resolver( def fruits_custom_resolver_lazy( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> list[Annotated["Fruit", strawberry.lazy("tests.relay.schema")]]: return [ f @@ -270,7 +269,7 @@ def fruits_custom_resolver_lazy( def fruits_custom_resolver_iterator( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> Iterator[Fruit]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -280,7 +279,7 @@ def fruits_custom_resolver_iterator( def fruits_custom_resolver_iterable( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> Iterator[Fruit]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -290,7 +289,7 @@ def fruits_custom_resolver_iterable( def fruits_custom_resolver_generator( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> Generator[Fruit, None, None]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -300,7 +299,7 @@ def fruits_custom_resolver_generator( async def fruits_custom_resolver_async_iterable( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> AsyncIterable[Fruit]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -310,7 +309,7 @@ async def fruits_custom_resolver_async_iterable( async def fruits_custom_resolver_async_iterator( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> AsyncIterator[Fruit]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -320,7 +319,7 @@ async def fruits_custom_resolver_async_iterator( async def fruits_custom_resolver_async_generator( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> AsyncGenerator[Fruit, None]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -330,7 +329,7 @@ async def fruits_custom_resolver_async_generator( def fruit_alike_connection_custom_resolver( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> list[FruitAlike]: return [ FruitAlike(f.id, f.name, f.color) diff --git a/tests/relay/schema_future_annotations.py b/tests/relay/schema_future_annotations.py index ab50179344..05a91be3cc 100644 --- a/tests/relay/schema_future_annotations.py +++ b/tests/relay/schema_future_annotations.py @@ -13,7 +13,6 @@ Annotated, Any, NamedTuple, - Optional, cast, ) from typing_extensions import Self @@ -38,7 +37,7 @@ def resolve_nodes( info: strawberry.Info, node_ids: Iterable[str], required: bool = False, - ) -> Iterable[Optional[Self]]: + ) -> Iterable[Self | None]: if node_ids is not None: return [fruits[nid] if required else fruits.get(nid) for nid in node_ids] @@ -68,10 +67,10 @@ class FruitAsync(relay.Node): async def resolve_nodes( cls, *, - info: Optional[Info] = None, + info: Info | None = None, node_ids: Iterable[str], required: bool = False, - ) -> Iterable[Optional[Self]]: + ) -> Iterable[Self | None]: if node_ids is not None: return [ fruits_async[nid] if required else fruits_async.get(nid) @@ -96,12 +95,12 @@ def resolve_connection( cls, nodes: Iterable[Fruit], *, - info: Optional[Info] = None, - total_count: Optional[int] = None, - before: Optional[str] = None, - after: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, + info: Info | None = None, + total_count: int | None = None, + before: str | None = None, + after: str | None = None, + first: int | None = None, + last: int | None = None, **kwargs: Any, ) -> Self: edges_mapping = { @@ -202,8 +201,8 @@ class Query: permission_classes=[DummyPermission] ) nodes: list[relay.Node] = relay.node() - node_optional: Optional[relay.Node] = relay.node() - nodes_optional: list[Optional[relay.Node]] = relay.node() + node_optional: relay.Node | None = relay.node() + nodes_optional: list[relay.Node | None] = relay.node() fruits: relay.ListConnection[Fruit] = relay.connection(resolver=fruits_resolver) fruits_lazy: relay.ListConnection[ Annotated[Fruit, strawberry.lazy("tests.relay.schema")] @@ -219,7 +218,7 @@ class Query: def fruits_concrete_resolver( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> list[Fruit]: # This is mimicing integrations, like Django return [ @@ -239,7 +238,7 @@ def fruits_concrete_resolver( def fruits_custom_resolver( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> list[Fruit]: return [ f @@ -251,7 +250,7 @@ def fruits_custom_resolver( def fruits_custom_resolver_lazy( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> list[Annotated[Fruit, strawberry.lazy("tests.relay.schema")]]: return [ f @@ -263,7 +262,7 @@ def fruits_custom_resolver_lazy( def fruits_custom_resolver_iterator( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> Iterator[Fruit]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -273,7 +272,7 @@ def fruits_custom_resolver_iterator( def fruits_custom_resolver_iterable( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> Iterator[Fruit]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -283,7 +282,7 @@ def fruits_custom_resolver_iterable( def fruits_custom_resolver_generator( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> Generator[Fruit, None, None]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -293,7 +292,7 @@ def fruits_custom_resolver_generator( async def fruits_custom_resolver_async_iterable( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> AsyncIterable[Fruit]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -303,7 +302,7 @@ async def fruits_custom_resolver_async_iterable( async def fruits_custom_resolver_async_iterator( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> AsyncIterator[Fruit]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -313,7 +312,7 @@ async def fruits_custom_resolver_async_iterator( async def fruits_custom_resolver_async_generator( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> AsyncGenerator[Fruit, None]: for f in fruits.values(): if name_endswith is None or f.name.endswith(name_endswith): @@ -323,7 +322,7 @@ async def fruits_custom_resolver_async_generator( def fruit_alike_connection_custom_resolver( self, info: strawberry.Info, - name_endswith: Optional[str] = None, + name_endswith: str | None = None, ) -> list[FruitAlike]: return [ FruitAlike(f.id, f.name, f.color) diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py index feaa47a9a3..73ccf458d2 100644 --- a/tests/relay/test_connection.py +++ b/tests/relay/test_connection.py @@ -31,13 +31,13 @@ def resolve_connection( nodes: Iterable[User], *, info: Any, - after: Optional[str] = None, - before: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, - max_results: Optional[int] = None, + after: str | None = None, + before: str | None = None, + first: int | None = None, + last: int | None = None, + max_results: int | None = None, **kwargs: Any, - ) -> Optional[Self]: + ) -> Self | None: return None @@ -49,13 +49,13 @@ def resolve_connection( nodes: Iterable[User], *, info: Any, - after: Optional[str] = None, - before: Optional[str] = None, - first: Optional[int] = None, - last: Optional[int] = None, - max_results: Optional[int] = None, + after: str | None = None, + before: str | None = None, + first: int | None = None, + last: int | None = None, + max_results: int | None = None, **kwargs: Any, - ) -> Optional[Self]: + ) -> Self | None: user_node_id = to_base64(User, "1") return cls( page_info=PageInfo( @@ -79,7 +79,7 @@ def test_nullable_connection_with_optional(): @strawberry.type class Query: @strawberry.relay.connection(Optional[EmptyUserConnection]) - def users(self) -> Optional[list[User]]: + def users(self) -> list[User] | None: return None schema = strawberry.Schema(query=Query) @@ -110,7 +110,7 @@ class Query: ] ] ) - def users(self) -> Optional[list[User]]: + def users(self) -> list[User] | None: return None schema = strawberry.Schema(query=Query) @@ -142,7 +142,7 @@ class Query: ] ] ) - def users(self) -> Optional[list[User]]: + def users(self) -> list[User] | None: return None schema = strawberry.Schema(query=Query) @@ -194,7 +194,7 @@ class Query: @strawberry.relay.connection( Optional[EmptyUserConnection], permission_classes=[TestPermission] ) - def users(self) -> Optional[list[User]]: # pragma: no cover + def users(self) -> list[User] | None: # pragma: no cover pytest.fail("Should not have been called...") schema = strawberry.Schema(query=Query) @@ -229,7 +229,7 @@ def users(self) -> Optional[list[User]]: # pragma: no cover ], ) def test_max_results( - field_max_results: Optional[int], + field_max_results: int | None, schema_max_results: int, results: int, expected: int, diff --git a/tests/relay/test_fields.py b/tests/relay/test_fields.py index 4f0b5108b3..5279304601 100644 --- a/tests/relay/test_fields.py +++ b/tests/relay/test_fields.py @@ -1,7 +1,6 @@ import dataclasses import textwrap from collections.abc import Iterable -from typing import Optional, Union from typing_extensions import Self import pytest @@ -1637,10 +1636,10 @@ class Fruit(relay.Node): def resolve_nodes( cls, *, - info: Optional[strawberry.Info] = None, + info: strawberry.Info | None = None, node_ids: Iterable[str], required: bool = False, - ) -> Iterable[Optional[Union[Self, FruitModel]]]: + ) -> Iterable[Self | FruitModel | None]: return [fruits[nid] if required else fruits.get(nid) for nid in node_ids] @strawberry.type @@ -1652,10 +1651,10 @@ class PublicFruit(relay.Node): def resolve_nodes( cls, *, - info: Optional[strawberry.Info] = None, + info: strawberry.Info | None = None, node_ids: Iterable[str], required: bool = False, - ) -> Iterable[Optional[Union[Self, FruitModel]]]: + ) -> Iterable[Self | FruitModel | None]: return [fruits[nid] if required else fruits.get(nid) for nid in node_ids] @strawberry.type diff --git a/tests/schema/extensions/schema_extensions/test_extensions.py b/tests/schema/extensions/schema_extensions/test_extensions.py index d2636079a2..08b077e34b 100644 --- a/tests/schema/extensions/schema_extensions/test_extensions.py +++ b/tests/schema/extensions/schema_extensions/test_extensions.py @@ -1,7 +1,7 @@ import contextlib import json import warnings -from typing import Any, Optional +from typing import Any from unittest.mock import patch import pytest @@ -982,7 +982,7 @@ def on_execute(self): @strawberry.type class Query: @strawberry.field - def ping(self, return_value: Optional[str] = None) -> str: + def ping(self, return_value: str | None = None) -> str: if return_value is not None: return return_value return "pong" diff --git a/tests/schema/extensions/test_field_extensions.py b/tests/schema/extensions/test_field_extensions.py index ed952ba431..4266e0fef9 100644 --- a/tests/schema/extensions/test_field_extensions.py +++ b/tests/schema/extensions/test_field_extensions.py @@ -1,6 +1,6 @@ import re from collections.abc import Callable -from typing import Annotated, Any, Optional +from typing import Annotated, Any import pytest @@ -360,7 +360,7 @@ class Query: def string( self, some_input: Annotated[str, strawberry.argument(metadata={"test": "foo"})], - another_input: Optional[str] = None, + another_input: str | None = None, ) -> str: return f"This is a test!! {some_input}" diff --git a/tests/schema/extensions/test_max_aliases.py b/tests/schema/extensions/test_max_aliases.py index 0ee3c69f0c..0fd536b158 100644 --- a/tests/schema/extensions/test_max_aliases.py +++ b/tests/schema/extensions/test_max_aliases.py @@ -1,5 +1,3 @@ -from typing import Optional - import strawberry from strawberry.extensions.max_aliases import MaxAliasesLimiter @@ -13,7 +11,7 @@ class Human: @strawberry.type class Query: @strawberry.field - def user(self, name: Optional[str] = None, email: Optional[str] = None) -> Human: + def user(self, name: str | None = None, email: str | None = None) -> Human: return Human(name="Jane Doe", email="jane@example.com") version: str diff --git a/tests/schema/extensions/test_max_tokens.py b/tests/schema/extensions/test_max_tokens.py index bd877dbfaf..0a3a58f32f 100644 --- a/tests/schema/extensions/test_max_tokens.py +++ b/tests/schema/extensions/test_max_tokens.py @@ -1,5 +1,3 @@ -from typing import Optional - import strawberry from strawberry.extensions.max_tokens import MaxTokensLimiter @@ -13,7 +11,7 @@ class Human: @strawberry.type class Query: @strawberry.field - def user(self, name: Optional[str] = None, email: Optional[str] = None) -> Human: + def user(self, name: str | None = None, email: str | None = None) -> Human: return Human(name="Jane Doe", email="jane@example.com") version: str diff --git a/tests/schema/extensions/test_query_depth_limiter.py b/tests/schema/extensions/test_query_depth_limiter.py index d8ca4526ef..b019b80b1d 100644 --- a/tests/schema/extensions/test_query_depth_limiter.py +++ b/tests/schema/extensions/test_query_depth_limiter.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import pytest from graphql import ( GraphQLError, @@ -61,15 +59,15 @@ class Query: @strawberry.field def user( self, - name: Optional[str], - id: Optional[int], - age: Optional[float], - is_cool: Optional[bool], + name: str | None, + id: int | None, + age: float | None, + is_cool: bool | None, ) -> Human: pass @strawberry.field - def users(self, names: Optional[list[str]]) -> list[Human]: + def users(self, names: list[str] | None) -> list[Human]: pass @strawberry.field @@ -87,7 +85,7 @@ def cat(bio: Biography) -> Cat: def run_query( query: str, max_depth: int, should_ignore: ShouldIgnoreType = None -) -> tuple[list[GraphQLError], Union[dict[str, int], None]]: +) -> tuple[list[GraphQLError], dict[str, int] | None]: document = parse(query) result = None diff --git a/tests/schema/test_annotated/type_a.py b/tests/schema/test_annotated/type_a.py index a57c145fe7..31e2615ba1 100644 --- a/tests/schema/test_annotated/type_a.py +++ b/tests/schema/test_annotated/type_a.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Annotated, Optional +from typing import Annotated from uuid import UUID import strawberry @@ -13,4 +13,4 @@ def get_testing( self, info: strawberry.Info, id_: Annotated[UUID, strawberry.argument(name="id")], - ) -> Optional[str]: ... + ) -> str | None: ... diff --git a/tests/schema/test_annotated/type_b.py b/tests/schema/test_annotated/type_b.py index 4ecbb6efa7..2bd8bc13b0 100644 --- a/tests/schema/test_annotated/type_b.py +++ b/tests/schema/test_annotated/type_b.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Annotated, Optional +from typing import Annotated from uuid import UUID import strawberry @@ -13,4 +13,4 @@ def get_testing( self, id_: Annotated[UUID, strawberry.argument(name="id")], info: strawberry.Info, - ) -> Optional[str]: ... + ) -> str | None: ... diff --git a/tests/schema/test_arguments.py b/tests/schema/test_arguments.py index 478c3723de..87d296405c 100644 --- a/tests/schema/test_arguments.py +++ b/tests/schema/test_arguments.py @@ -1,6 +1,6 @@ import textwrap from textwrap import dedent -from typing import Annotated, Optional +from typing import Annotated import strawberry from strawberry.types.unset import UNSET @@ -83,7 +83,7 @@ def test_argument_with_default_value_none(): @strawberry.type class Query: @strawberry.field - def hello(self, name: Optional[str] = None) -> str: + def hello(self, name: str | None = None) -> str: return f"Hi {name}" schema = strawberry.Schema(query=Query) @@ -100,7 +100,7 @@ def test_optional_argument_unset(): @strawberry.type class Query: @strawberry.field - def hello(self, name: Optional[str] = UNSET, age: Optional[int] = UNSET) -> str: + def hello(self, name: str | None = UNSET, age: int | None = UNSET) -> str: if name is UNSET: return "Hi there" return f"Hi {name}" @@ -128,8 +128,8 @@ def hello(self, name: Optional[str] = UNSET, age: Optional[int] = UNSET) -> str: def test_optional_input_field_unset(): @strawberry.input class TestInput: - name: Optional[str] = UNSET - age: Optional[int] = UNSET + name: str | None = UNSET + age: int | None = UNSET @strawberry.type class Query: diff --git a/tests/schema/test_basic.py b/tests/schema/test_basic.py index 77cbf5605b..8bc3a34b91 100644 --- a/tests/schema/test_basic.py +++ b/tests/schema/test_basic.py @@ -1,8 +1,6 @@ import textwrap -import typing from dataclasses import InitVar, dataclass from enum import Enum -from typing import Optional import pytest @@ -44,7 +42,7 @@ class Query: def test_basic_schema_optional(): @strawberry.type class Query: - example: typing.Optional[str] = None + example: str | None = None schema = strawberry.Schema(query=Query) @@ -63,7 +61,7 @@ class User: @strawberry.type class Query: - user: typing.Optional[User] = None + user: User | None = None schema = strawberry.Schema(query=Query) @@ -97,7 +95,7 @@ def hello_world(self, query_param: str) -> str: def test_can_rename_fields(): @strawberry.type class Hello: - value: typing.Optional[str] = strawberry.field(name="name") + value: str | None = strawberry.field(name="name") @strawberry.type class Query: @@ -446,8 +444,8 @@ class User: @strawberry.type class Query: - me: Optional[User] = None - you: Optional[User] = None + me: User | None = None + you: User | None = None schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_directives.py b/tests/schema/test_directives.py index e5a83bc88c..c0c9a191b8 100644 --- a/tests/schema/test_directives.py +++ b/tests/schema/test_directives.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Any, NoReturn, Optional +from typing import Any, NoReturn import pytest @@ -480,7 +480,7 @@ def test_directive_value(): @strawberry.type class Cake: - frosting: Optional[str] = None + frosting: str | None = None flavor: str = "Chocolate" @strawberry.type @@ -533,7 +533,7 @@ def cake(self) -> "Cake": @strawberry.type class Cake: - frosting: Optional[str] = None + frosting: str | None = None flavor: str = "Chocolate" diff --git a/tests/schema/test_enum.py b/tests/schema/test_enum.py index b3905959a3..41cee7961f 100644 --- a/tests/schema/test_enum.py +++ b/tests/schema/test_enum.py @@ -1,7 +1,6 @@ -import typing from enum import Enum from textwrap import dedent -from typing import Annotated, Optional +from typing import Annotated import pytest @@ -129,7 +128,7 @@ class IceCreamFlavour(Enum): @strawberry.input class Input: flavour: IceCreamFlavour - optional_flavour: typing.Optional[IceCreamFlavour] = None + optional_flavour: IceCreamFlavour | None = None @strawberry.type class Query: @@ -187,7 +186,7 @@ class IceCreamFlavour(Enum): @strawberry.type class Query: @strawberry.field - def best_flavours(self) -> Optional[list[IceCreamFlavour]]: + def best_flavours(self) -> list[IceCreamFlavour] | None: return None schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_execution.py b/tests/schema/test_execution.py index f86503f452..b38efdcd5a 100644 --- a/tests/schema/test_execution.py +++ b/tests/schema/test_execution.py @@ -1,6 +1,5 @@ import textwrap from textwrap import dedent -from typing import Optional from unittest.mock import patch import pytest @@ -15,7 +14,7 @@ def test_enabling_query_validation_sync(mock_validate, validate_queries): @strawberry.type class Query: - example: Optional[str] = None + example: str | None = None extensions = [] if validate_queries is False: @@ -47,7 +46,7 @@ class Query: async def test_enabling_query_validation(validate_queries): @strawberry.type class Query: - example: Optional[str] = None + example: str | None = None extensions = [] if validate_queries is False: @@ -79,7 +78,7 @@ class Query: async def test_invalid_query_with_validation_enabled(): @strawberry.type class Query: - example: Optional[str] = None + example: str | None = None schema = strawberry.Schema(query=Query) @@ -103,7 +102,7 @@ class Query: async def test_asking_for_wrong_field(): @strawberry.type class Query: - example: Optional[str] = None + example: str | None = None schema = strawberry.Schema(query=Query, extensions=[DisableValidation()]) @@ -420,8 +419,8 @@ def process_errors(self, errors, execution_context): def test_adding_custom_validation_rules(): @strawberry.type class Query: - example: Optional[str] = None - another_example: Optional[str] = None + example: str | None = None + another_example: str | None = None class CustomRule(ValidationRule): def enter_field(self, node, *args: str) -> None: @@ -457,7 +456,7 @@ def example(self) -> str: return "hi" @strawberry.field - def this_fails(self) -> Optional[str]: + def this_fails(self) -> str | None: raise ValueError("this field fails") schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_extensions.py b/tests/schema/test_extensions.py index 9254ce3273..8b56d58a8e 100644 --- a/tests/schema/test_extensions.py +++ b/tests/schema/test_extensions.py @@ -1,5 +1,5 @@ from enum import Enum, auto -from typing import Annotated, Union, cast +from typing import Annotated, cast from graphql import ( DirectiveLocation, @@ -139,7 +139,7 @@ class JsonThing: class StrThing: value: str - SomeThing = Annotated[Union[JsonThing, StrThing], strawberry.union("SomeThing")] + SomeThing = Annotated[JsonThing | StrThing, strawberry.union("SomeThing")] @strawberry.type() class Query: diff --git a/tests/schema/test_generics.py b/tests/schema/test_generics.py index eb547312d2..c4a42e73f1 100644 --- a/tests/schema/test_generics.py +++ b/tests/schema/test_generics.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Any, Generic, Optional, TypeVar, Union +from typing import Any, Generic, TypeVar from typing_extensions import Self import strawberry @@ -291,7 +291,7 @@ class User: @strawberry.type class Edge(Generic[T]): - node: Optional[T] = None + node: T | None = None @strawberry.type class Query: @@ -359,7 +359,7 @@ class User: @strawberry.type class Edge(Generic[T]): - nodes: list[Optional[T]] + nodes: list[T | None] @strawberry.type class Query: @@ -453,7 +453,7 @@ class Fallback: @strawberry.type class Query: @strawberry.field - def example(self) -> Union[Fallback, Edge[int]]: + def example(self) -> Fallback | Edge[int]: return Edge(cursor=strawberry.ID("1"), node=1) schema = strawberry.Schema(query=Query) @@ -496,7 +496,7 @@ class Codes(Enum): @strawberry.type class Query: @strawberry.field - def result(self) -> Union[Pet, ErrorNode[Codes]]: + def result(self) -> Pet | ErrorNode[Codes]: return ErrorNode(code=Codes.a) schema = strawberry.Schema(query=Query) @@ -532,7 +532,7 @@ class EstimatedValue(Generic[T]): @strawberry.type class Query: @strawberry.field - def estimated_value(self) -> Optional[EstimatedValue[int]]: + def estimated_value(self) -> EstimatedValue[int] | None: return EstimatedValue(value=1, type=EstimatedValueEnum.test) schema = strawberry.Schema(query=Query) @@ -573,7 +573,7 @@ class Fallback: @strawberry.type class Query: @strawberry.field - def example(self) -> Union[Fallback, Edge[int, str]]: + def example(self) -> Fallback | Edge[int, str]: return Edge(node="string", info=1) schema = strawberry.Schema(query=Query) @@ -608,7 +608,7 @@ class Edge(Generic[T]): @strawberry.type class Query: @strawberry.field - def example(self) -> list[Union[Edge[int], Edge[str]]]: + def example(self) -> list[Edge[int] | Edge[str]]: return [ Edge(cursor=strawberry.ID("1"), node=1), Edge(cursor=strawberry.ID("2"), node="string"), @@ -685,7 +685,7 @@ class Entity2: @strawberry.type class Query: - entities: Connection[Union[Entity1, Entity2]] + entities: Connection[Entity1 | Entity2] schema = strawberry.Schema(query=Query) @@ -777,7 +777,7 @@ class Edge(Generic[T]): @strawberry.type class Query: @strawberry.field - def user(self) -> Union[User, Edge[User]]: + def user(self) -> User | Edge[User]: return Edge(nodes=[User(name="P")]) schema = strawberry.Schema(query=Query) @@ -814,7 +814,7 @@ class Edge(Generic[T]): @strawberry.type class Query: @strawberry.field - def user(self) -> Union[User, Edge[User]]: + def user(self) -> User | Edge[User]: return Edge(nodes=[]) schema = strawberry.Schema(query=Query) @@ -851,7 +851,7 @@ class Edge(Generic[T]): @strawberry.type class Query: @strawberry.field - def user(self) -> Union[User, Edge[User]]: + def user(self) -> User | Edge[User]: return Edge(nodes=["bad example"]) # type: ignore schema = strawberry.Schema(query=Query) @@ -985,7 +985,7 @@ def test_generic_extending_with_type_var(): class Node(Generic[T]): id: strawberry.ID - def _resolve(self) -> Optional[T]: + def _resolve(self) -> T | None: return None @strawberry.type @@ -1021,7 +1021,7 @@ def books(self) -> list[Book]: def test_self(): @strawberry.interface class INode: - field: Optional[Self] + field: Self | None fields: list[Self] @strawberry.type @@ -1197,9 +1197,7 @@ class TestError: @strawberry.type class Query: @strawberry.field - def hello( - self, info: strawberry.Info - ) -> Union[Pagination[TestInterface], TestError]: + def hello(self, info: strawberry.Info) -> Pagination[TestInterface] | TestError: return Pagination(items=[Test1(data="test1")]) schema = strawberry.Schema(Query, types=[Test1]) diff --git a/tests/schema/test_generics_nested.py b/tests/schema/test_generics_nested.py index 215fad5ea6..54962913f8 100644 --- a/tests/schema/test_generics_nested.py +++ b/tests/schema/test_generics_nested.py @@ -1,5 +1,5 @@ import textwrap -from typing import Generic, Optional, TypeVar, Union +from typing import Generic, TypeVar import strawberry from strawberry.scalars import JSON @@ -68,7 +68,7 @@ class Query: @strawberry.field def blocks( self, - ) -> list[Union[BlockRowtype[int], BlockRowtype[str], JsonBlock]]: + ) -> list[BlockRowtype[int] | BlockRowtype[str] | JsonBlock]: return [ BlockRowtype(total=3, items=["a", "b", "c"]), BlockRowtype(total=1, items=[1, 2, 3, 4]), @@ -122,7 +122,7 @@ class Query: @strawberry.field def blocks( self, - ) -> list[Union[BlockRowtype[int], BlockRowtype[str], JsonBlock]]: + ) -> list[BlockRowtype[int] | BlockRowtype[str] | JsonBlock]: return [ BlockRowtype(total=3, items=[]), BlockRowtype(total=1, items=[]), @@ -176,7 +176,7 @@ class Query: @strawberry.field def blocks( self, - ) -> list[Union[BlockRowtype[int], BlockRowtype[str], JsonBlock]]: + ) -> list[BlockRowtype[int] | BlockRowtype[str] | JsonBlock]: return [ BlockRowtype(total=3, items=[["a", "b", "c"]]), BlockRowtype(total=1, items=[[1, 2, 3, 4]]), @@ -219,7 +219,7 @@ def test_using_generics_with_an_interface(): @strawberry.interface class BlockInterface: id: strawberry.ID - disclaimer: Optional[str] = strawberry.field(default=None) + disclaimer: str | None = strawberry.field(default=None) @strawberry.type class JsonBlock(BlockInterface): @@ -335,7 +335,7 @@ class Fallback: @strawberry.type class Query: @strawberry.field - def users(self) -> Union[Connection[User], Fallback]: + def users(self) -> Connection[User] | Fallback: return Connection(edge=Edge(node=User(name="Patrick"))) schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_info.py b/tests/schema/test_info.py index d4cdd18955..4ea55320ab 100644 --- a/tests/schema/test_info.py +++ b/tests/schema/test_info.py @@ -177,7 +177,7 @@ def test_info_arguments(): @strawberry.input class TestInput: name: str - age: Optional[int] = UNSET + age: int | None = UNSET selected_fields = None @@ -256,7 +256,7 @@ class Result: class Query: @strawberry.field def hello( - self, info: strawberry.Info[str, str], optional_input: Optional[str] = "hi" + self, info: strawberry.Info[str, str], optional_input: str | None = "hi" ) -> Result: nonlocal selected_fields selected_fields = info.selected_fields @@ -372,7 +372,7 @@ def field( self, info: strawberry.Info, arg_1: Annotated[str, strawberry.argument(description="Some description")], - arg_2: Optional[TestInput] = None, + arg_2: TestInput | None = None, ) -> str: nonlocal arg_1_def, arg_2_def, missing_arg_def arg_1_def = info.get_argument_definition("arg_1") diff --git a/tests/schema/test_input.py b/tests/schema/test_input.py index f8f1a7c8b6..286b67d87d 100644 --- a/tests/schema/test_input.py +++ b/tests/schema/test_input.py @@ -1,6 +1,5 @@ import re import textwrap -from typing import Optional import pytest @@ -13,7 +12,7 @@ def test_renaming_input_fields(): @strawberry.input class FilterInput: - in_: Optional[str] = strawberry.field(name="in", default=strawberry.UNSET) + in_: str | None = strawberry.field(name="in", default=strawberry.UNSET) @strawberry.type class Query: @@ -41,7 +40,7 @@ def test_input_with_nonscalar_field_default(): @strawberry.input class NonScalarField: id: int = 10 - nullable_field: Optional[int] = None + nullable_field: int | None = None @strawberry.input class Input: @@ -54,7 +53,7 @@ class Input: class ExampleOutput: input_id: int non_scalar_id: int - non_scalar_nullable_field: Optional[int] + non_scalar_nullable_field: int | None @strawberry.type class Query: diff --git a/tests/schema/test_lazy/test_lazy_generic.py b/tests/schema/test_lazy/test_lazy_generic.py index 1f1f62091e..b63985f3ae 100644 --- a/tests/schema/test_lazy/test_lazy_generic.py +++ b/tests/schema/test_lazy/test_lazy_generic.py @@ -5,7 +5,7 @@ import textwrap from collections.abc import Sequence from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Annotated, Generic, TypeVar import pytest @@ -142,12 +142,11 @@ def test_lazy_types_declared_within_optional(): @strawberry.type class Query: - normal_edges: list[Edge[Optional[TypeC]]] + normal_edges: list[Edge[TypeC | None]] lazy_edges: list[ Edge[ - Optional[ - Annotated["TypeC", strawberry.lazy("tests.schema.test_lazy.type_c")] - ] + Annotated["TypeC", strawberry.lazy("tests.schema.test_lazy.type_c")] + | None ] ] diff --git a/tests/schema/test_lazy/type_a.py b/tests/schema/test_lazy/type_a.py index 4042fca1dd..7904963bf6 100644 --- a/tests/schema/test_lazy/type_a.py +++ b/tests/schema/test_lazy/type_a.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Annotated, Optional +from typing import TYPE_CHECKING, Annotated import strawberry @@ -8,9 +8,10 @@ @strawberry.type class TypeA: - list_of_b: Optional[ + list_of_b: ( list[Annotated["TypeB", strawberry.lazy("tests.schema.test_lazy.type_b")]] - ] = None + | None + ) = None @strawberry.field def type_b(self) -> Annotated["TypeB", strawberry.lazy(".type_b")]: diff --git a/tests/schema/test_lazy_types/test_lazy_unions.py b/tests/schema/test_lazy_types/test_lazy_unions.py index e942143430..3bacdb82fd 100644 --- a/tests/schema/test_lazy_types/test_lazy_unions.py +++ b/tests/schema/test_lazy_types/test_lazy_unions.py @@ -1,5 +1,5 @@ import textwrap -from typing import Annotated, Union +from typing import Annotated import strawberry from strawberry.printer import print_schema @@ -15,9 +15,7 @@ class TypeB: b: int -ABUnion = Annotated[ - Union[TypeA, TypeB], strawberry.union("ABUnion", types=[TypeA, TypeB]) -] +ABUnion = Annotated[TypeA | TypeB, strawberry.union("ABUnion", types=[TypeA, TypeB])] TypeALazy = Annotated[ @@ -27,10 +25,7 @@ class TypeB: "TypeB", strawberry.lazy("tests.schema.test_lazy_types.test_lazy_unions") ] LazyABUnion = Annotated[ - Union[ - TypeALazy, - TypeBLazy, - ], + TypeALazy | TypeBLazy, strawberry.union("LazyABUnion", types=[TypeALazy, TypeBLazy]), ] diff --git a/tests/schema/test_lazy_types/type_a.py b/tests/schema/test_lazy_types/type_a.py index afcba30b30..ceb3bee24c 100644 --- a/tests/schema/test_lazy_types/type_a.py +++ b/tests/schema/test_lazy_types/type_a.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import strawberry @@ -10,9 +10,9 @@ @strawberry.type class TypeA: - list_of_b: Optional[ - list[strawberry.LazyType["TypeB", "tests.schema.test_lazy_types.type_b"]] - ] = None + list_of_b: ( + list[strawberry.LazyType["TypeB", "tests.schema.test_lazy_types.type_b"]] | None + ) = None @strawberry.field def type_b(self) -> strawberry.LazyType["TypeB", ".type_b"]: # noqa: F722 diff --git a/tests/schema/test_list.py b/tests/schema/test_list.py index 4119841d55..9321c3a7fe 100644 --- a/tests/schema/test_list.py +++ b/tests/schema/test_list.py @@ -1,5 +1,3 @@ -from typing import Optional - import strawberry @@ -24,7 +22,7 @@ def test_of_optional(): @strawberry.type class Query: @strawberry.field - def example(self) -> list[Optional[str]]: + def example(self) -> list[str | None]: return ["Example", None] schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_maybe.py b/tests/schema/test_maybe.py index 2a82fd7f6d..66cb31ceb9 100644 --- a/tests/schema/test_maybe.py +++ b/tests/schema/test_maybe.py @@ -1,5 +1,4 @@ from textwrap import dedent -from typing import Optional, Union import pytest @@ -11,7 +10,7 @@ def maybe_schema() -> strawberry.Schema: @strawberry.type class User: name: str - phone: Optional[str] + phone: str | None user = User(name="Patrick", phone=None) @@ -23,7 +22,7 @@ def user(self) -> User: @strawberry.input class UpdateUserInput: - phone: strawberry.Maybe[Union[str, None]] + phone: strawberry.Maybe[str | None] @strawberry.type class Mutation: @@ -45,7 +44,7 @@ def update_user(self, input: UpdateUserInput) -> User: """ -def set_phone(schema: strawberry.Schema, phone: Optional[str]) -> dict: +def set_phone(schema: strawberry.Schema, phone: str | None) -> dict: query = """ mutation ($phone: String) { updateUser(input: { phone: $phone }) { @@ -107,7 +106,7 @@ def test_optional_argument_maybe() -> None: @strawberry.type class Query: @strawberry.field - def hello(self, name: strawberry.Maybe[Union[str, None]] = None) -> str: + def hello(self, name: strawberry.Maybe[str | None] = None) -> str: if name: return "None" if name.value is None else name.value @@ -154,7 +153,7 @@ def hello(self, name: strawberry.Maybe[Union[str, None]] = None) -> str: def test_maybe_list(): @strawberry.input class InputData: - items: strawberry.Maybe[Union[list[str], None]] + items: strawberry.Maybe[list[str] | None] @strawberry.type class Query: @@ -306,7 +305,7 @@ def test_maybe_str_error_messages(): @strawberry.input class UpdateInput: name: strawberry.Maybe[str] # Rejects null at Python validation level - phone: strawberry.Maybe[Union[str, None]] # Can accept null + phone: strawberry.Maybe[str | None] # Can accept null @strawberry.type class Query: @@ -348,9 +347,9 @@ class UpdateUserInput: # Should accept value or absent, reject null username: strawberry.Maybe[str] # Can accept null, value, or absent - bio: strawberry.Maybe[Union[str, None]] + bio: strawberry.Maybe[str | None] # Can accept null, value, or absent - website: strawberry.Maybe[Union[str, None]] + website: strawberry.Maybe[str | None] @strawberry.type class Query: @@ -456,7 +455,7 @@ class UpdateItemsInput: # Cannot accept null list - only valid list or absent tags: strawberry.Maybe[list[str]] # Can accept null, valid list, or absent - categories: strawberry.Maybe[Union[list[str], None]] + categories: strawberry.Maybe[list[str] | None] @strawberry.type class Query: @@ -553,7 +552,7 @@ def search( # Cannot accept null - only value or absent query: strawberry.Maybe[str] = None, # Can accept null, value, or absent - filter_by: strawberry.Maybe[Union[str, None]] = None, + filter_by: strawberry.Maybe[str | None] = None, ) -> str: result = [] @@ -639,7 +638,7 @@ def test(self, input: Input1) -> str: # Schema with Maybe[str | None] @strawberry.input class Input2: - field: strawberry.Maybe[Union[str, None]] + field: strawberry.Maybe[str | None] @strawberry.type class Query2: @@ -667,14 +666,14 @@ def test_maybe_complex_types(): class AddressInput: street: str city: str - zip_code: Optional[str] = None + zip_code: str | None = None @strawberry.input class UpdateProfileInput: # Cannot accept null address - only valid address or absent address: strawberry.Maybe[AddressInput] # Can accept null, valid address, or absent - billing_address: strawberry.Maybe[Union[AddressInput, None]] + billing_address: strawberry.Maybe[AddressInput | None] @strawberry.type class Query: @@ -763,7 +762,7 @@ def test_maybe_union_with_none_works(): @strawberry.input class TestInput: # This should work correctly - can accept value, null, or absent - field: strawberry.Maybe[Union[str, None]] + field: strawberry.Maybe[str | None] @strawberry.type class Query: @@ -812,7 +811,7 @@ class CompareInput: # Generates String (optional) but rejects null at Python level required_field: strawberry.Maybe[str] # Generates String (optional) and accepts null - optional_field: strawberry.Maybe[Union[str, None]] + optional_field: strawberry.Maybe[str | None] @strawberry.type class Query: @@ -839,7 +838,7 @@ class Input1: @strawberry.input class Input2: - field: strawberry.Maybe[Union[str, None]] + field: strawberry.Maybe[str | None] @strawberry.type class Query: @@ -904,7 +903,7 @@ class ComprehensiveInput: # String (optional) - can be value or absent, but rejects null at Python level strict_field: strawberry.Maybe[str] # String (optional) - can be null, value, or absent - flexible_field: strawberry.Maybe[Union[str, None]] + flexible_field: strawberry.Maybe[str | None] @strawberry.type class Query: diff --git a/tests/schema/test_mutation.py b/tests/schema/test_mutation.py index b41476018b..7557871e47 100644 --- a/tests/schema/test_mutation.py +++ b/tests/schema/test_mutation.py @@ -1,5 +1,4 @@ import dataclasses -import typing from textwrap import dedent import strawberry @@ -91,12 +90,12 @@ class Query: @strawberry.input class InputExample: name: str - age: typing.Optional[int] = UNSET + age: int | None = UNSET @strawberry.type class Mutation: @strawberry.mutation - def say(self, name: typing.Optional[str] = UNSET) -> str: # type: ignore + def say(self, name: str | None = UNSET) -> str: # type: ignore if name is UNSET: return "Name is unset" @@ -127,12 +126,12 @@ class Query: @strawberry.input class InputExample: first_name: str - age: typing.Optional[str] = UNSET + age: str | None = UNSET @strawberry.type class Mutation: @strawberry.mutation - def say(self, first_name: typing.Optional[str] = UNSET) -> str: # type: ignore + def say(self, first_name: str | None = UNSET) -> str: # type: ignore if first_name is UNSET: return "Name is unset" @@ -178,7 +177,7 @@ class Query: @strawberry.type class Mutation: @strawberry.mutation - def say(self, first_name: typing.Optional[str] = UNSET) -> str: # type: ignore + def say(self, first_name: str | None = UNSET) -> str: # type: ignore return f"Hello {first_name}!" schema = strawberry.Schema(query=Query, mutation=Mutation) @@ -209,7 +208,7 @@ class Query: @strawberry.input class Input: - name: typing.Optional[str] = UNSET + name: str | None = UNSET @strawberry.type class Mutation: diff --git a/tests/schema/test_name_converter.py b/tests/schema/test_name_converter.py index fdc2bba07b..4b13d994a7 100644 --- a/tests/schema/test_name_converter.py +++ b/tests/schema/test_name_converter.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Annotated, Generic, Optional, TypeVar, Union +from typing import Annotated, Generic, TypeVar import strawberry from strawberry.directive import StrawberryDirective @@ -35,7 +35,7 @@ def from_union(self, union: StrawberryUnion) -> str: def from_generic( self, generic_type: StrawberryObjectDefinition, - types: list[Union[StrawberryType, type]], + types: list[StrawberryType | type], ) -> str: return super().from_generic(generic_type, types) + self.suffix @@ -43,7 +43,7 @@ def from_interface(self, interface: StrawberryObjectDefinition) -> str: return super().from_interface(interface) + self.suffix def from_directive( - self, directive: Union[StrawberryDirective, StrawberrySchemaDirective] + self, directive: StrawberryDirective | StrawberrySchemaDirective ) -> str: return super().from_directive(directive) + self.suffix @@ -115,11 +115,11 @@ class MyGeneric(Generic[T]): @strawberry.type class Query: @strawberry.field(directives=[MyDirective(name="my-directive")]) - def user(self, input: UserInput) -> Union[User, Error]: + def user(self, input: UserInput) -> User | Error: return User(name="Patrick") enum: MyEnum = MyEnum.A - field: Optional[MyGeneric[str]] = None + field: MyGeneric[str] | None = None field_with_lazy: MyGeneric[ Annotated[ "TypeWithDifferentNameThanClass", diff --git a/tests/schema/test_one_of.py b/tests/schema/test_one_of.py index 50e2a860cf..8f73ea118e 100644 --- a/tests/schema/test_one_of.py +++ b/tests/schema/test_one_of.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any import pytest @@ -8,14 +8,14 @@ @strawberry.input(one_of=True) class ExampleInputTagged: - a: strawberry.Maybe[Union[str, None]] - b: strawberry.Maybe[Union[int, None]] + a: strawberry.Maybe[str | None] + b: strawberry.Maybe[int | None] @strawberry.type class ExampleResult: - a: Union[str, None] - b: Union[int, None] + a: str | None + b: int | None @strawberry.type @@ -210,13 +210,13 @@ def test_works_with_camelcasing(): @strawberry.input(directives=[OneOf()]) class ExampleWithLongerNames: - a_field: strawberry.Maybe[Union[str, None]] - b_field: strawberry.Maybe[Union[int, None]] + a_field: strawberry.Maybe[str | None] + b_field: strawberry.Maybe[int | None] @strawberry.type class Result: - a_field: Union[str, None] - b_field: Union[int, None] + a_field: str | None + b_field: int | None @strawberry.type class Query: diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index 4cd4929414..144d549f9b 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -1,7 +1,6 @@ import re import textwrap import typing -from typing import Optional import pytest @@ -420,7 +419,7 @@ class Query: @strawberry.field( extensions=[PermissionExtension([IsAuthorized()], fail_silently=True)] ) - def name(self) -> Optional[str]: # pragma: no cover + def name(self) -> str | None: # pragma: no cover return "ABC" schema = strawberry.Schema(query=Query) @@ -443,7 +442,7 @@ class Query: @strawberry.field( extensions=[PermissionExtension([IsAuthorized()], fail_silently=True)] ) - def names(self) -> Optional[list[str]]: # pragma: no cover + def names(self) -> list[str] | None: # pragma: no cover return ["ABC"] schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_resolvers.py b/tests/schema/test_resolvers.py index 1ab97cea3f..ab4c771066 100644 --- a/tests/schema/test_resolvers.py +++ b/tests/schema/test_resolvers.py @@ -1,7 +1,6 @@ # type: ignore -import typing from contextlib import nullcontext -from typing import Any, Generic, NamedTuple, Optional, TypeVar, Union +from typing import Any, Generic, NamedTuple, TypeVar import pytest @@ -220,7 +219,7 @@ def get_users(cls) -> "list[User]": @strawberry.type class Query: - users: typing.List[User] = strawberry.field(resolver=User.get_users) + users: list[User] = strawberry.field(resolver=User.get_users) schema = strawberry.Schema(query=Query) @@ -391,12 +390,12 @@ class AType: T = TypeVar("T") - def resolver() -> Optional[T]: + def resolver() -> T | None: return AType(some=1) @strawberry.type class Query: - a_type: Optional[AType] = strawberry.field(resolver) + a_type: AType | None = strawberry.field(resolver) strawberry.Schema(query=Query) @@ -451,12 +450,12 @@ class AType: class OtherType: other: int - def resolver() -> Union[T, OtherType]: + def resolver() -> T | OtherType: return AType(some=1) @strawberry.type class Query: - union_type: Union[AType, OtherType] = strawberry.field(resolver) + union_type: AType | OtherType = strawberry.field(resolver) strawberry.Schema(query=Query) diff --git a/tests/schema/test_scalars.py b/tests/schema/test_scalars.py index 9995eb6d09..13363a2b29 100644 --- a/tests/schema/test_scalars.py +++ b/tests/schema/test_scalars.py @@ -1,7 +1,6 @@ from datetime import date, datetime, timedelta, timezone from decimal import Decimal from textwrap import dedent -from typing import Optional from uuid import UUID import pytest @@ -167,7 +166,7 @@ def echo_json(data: JSON) -> JSON: return data @strawberry.field - def echo_json_nullable(data: Optional[JSON]) -> Optional[JSON]: + def echo_json_nullable(data: JSON | None) -> JSON | None: return data schema = strawberry.Schema(query=Query) diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index a545521a11..c7cc6d95ff 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -6,7 +6,6 @@ from typing import ( Annotated, Any, - Union, ) import pytest @@ -151,7 +150,7 @@ class Query: @strawberry.type class Subscription: @strawberry.subscription - async def example_with_union(self) -> AsyncGenerator[Union[A, B], None]: + async def example_with_union(self) -> AsyncGenerator[A | B, None]: yield A(a="Hi") schema = strawberry.Schema(query=Query, subscription=Subscription) @@ -188,9 +187,7 @@ class Subscription: @strawberry.subscription async def example_with_annotated_union( self, - ) -> AsyncGenerator[ - Annotated[Union[C, D], strawberry.union("UnionName")], None - ]: + ) -> AsyncGenerator[Annotated[C | D, strawberry.union("UnionName")], None]: yield C(c="Hi") schema = strawberry.Schema(query=Query, subscription=Subscription) diff --git a/tests/schema/test_union.py b/tests/schema/test_union.py index 05d76219da..2f0cd9ef2b 100644 --- a/tests/schema/test_union.py +++ b/tests/schema/test_union.py @@ -1,7 +1,7 @@ import textwrap from dataclasses import dataclass from textwrap import dedent -from typing import Annotated, Generic, Optional, TypeVar, Union +from typing import Annotated, Generic, TypeVar, Union import pytest @@ -21,7 +21,7 @@ class B: @strawberry.type class Query: - ab: Union[A, B] = strawberry.field(default_factory=lambda: A(a=5)) + ab: A | B = strawberry.field(default_factory=lambda: A(a=5)) schema = strawberry.Schema(query=Query) query = """{ @@ -51,7 +51,7 @@ class B: @strawberry.type class Query: - ab: Union[A, B] = strawberry.field(default_factory=lambda: B(b=5)) + ab: A | B = strawberry.field(default_factory=lambda: B(b=5)) schema = strawberry.Schema(query=Query) query = """{ @@ -81,7 +81,7 @@ class B: @strawberry.type class Query: - ab: Union[A, B] = "ciao" + ab: A | B = "ciao" schema = strawberry.Schema(query=Query) query = """{ @@ -115,7 +115,7 @@ class B: @strawberry.type class Mutation: @strawberry.mutation - def hello(self) -> Union[A, B]: + def hello(self) -> A | B: return B(y=5) schema = strawberry.Schema(query=A, mutation=Mutation) @@ -158,7 +158,7 @@ class B: @strawberry.type class Mutation: @strawberry.mutation - def hello(self) -> Union[A, B]: + def hello(self) -> A | B: return Outside(c=5) # type:ignore schema = strawberry.Schema(query=A, mutation=Mutation, types=[Outside]) @@ -205,7 +205,7 @@ class B: @strawberry.type class Query: @strawberry.field - def hello(self) -> Union[A, B]: + def hello(self) -> A | B: return Outside(c=5) # type:ignore schema = strawberry.Schema(query=Query) @@ -234,7 +234,7 @@ class A: class B: b: int - Result = Annotated[Union[A, B], strawberry.union(name="Result")] + Result = Annotated[A | B, strawberry.union(name="Result")] @strawberry.type class Query: @@ -274,7 +274,7 @@ class B: b: int Result = Annotated[ - Union[A, B], strawberry.union(name="Result", description="Example Result") + A | B, strawberry.union(name="Result", description="Example Result") ] @strawberry.type @@ -314,11 +314,11 @@ class A: class B: b: int - Result = Annotated[Union[A, B], strawberry.union(name="Result")] + Result = Annotated[A | B, strawberry.union(name="Result")] @strawberry.type class Query: - ab: Optional[Result] = None + ab: Result | None = None schema = strawberry.Schema(query=Query) @@ -362,8 +362,8 @@ class UnionB1: class UnionB2: value: int - field1: Union[UnionA1, UnionA2] - field2: Union[UnionB1, UnionB2] + field1: UnionA1 | UnionA2 + field2: UnionB1 | UnionB2 schema = strawberry.Schema(query=CoolType) @@ -398,7 +398,7 @@ class A: class B: b: int - MyUnion = Annotated[Union[A, B], strawberry.union("MyUnion")] + MyUnion = Annotated[A | B, strawberry.union("MyUnion")] @strawberry.type class Query: @@ -443,7 +443,7 @@ def is_type_of(cls, obj, _info) -> bool: class B: b: int - MyUnion = Annotated[Union[A, B], strawberry.union("MyUnion")] + MyUnion = Annotated[A | B, strawberry.union("MyUnion")] @strawberry.type class Query: @@ -513,7 +513,7 @@ class B: @strawberry.input class Input: name: str - something: Union[A, B] + something: A | B @strawberry.type class Query: @@ -549,11 +549,11 @@ class B: @strawberry.type class Query: @strawberry.field - def container_a(self) -> Union[Container[A], A]: + def container_a(self) -> Container[A] | A: return Container(items=[A(a="hello")]) @strawberry.field - def container_b(self) -> Union[Container[B], B]: + def container_b(self) -> Container[B] | B: return Container(items=[B(b=3)]) schema = strawberry.Schema(query=Query) @@ -607,13 +607,13 @@ def test_lazy_union(): @strawberry.type class Query: @strawberry.field - def a(self) -> Union[TypeA, TypeB]: + def a(self) -> TypeA | TypeB: from tests.schema.test_lazy_types.type_a import TypeA return TypeA(list_of_b=[]) @strawberry.field - def b(self) -> Union[TypeA, TypeB]: + def b(self) -> TypeA | TypeB: from tests.schema.test_lazy_types.type_b import TypeB return TypeB() @@ -675,7 +675,7 @@ class Something: @strawberry.type class Query: - union: Union[Something, AnnotatedInt] + union: Something | AnnotatedInt strawberry.Schema(query=Query) @@ -692,7 +692,7 @@ class ICanBeInUnion: @strawberry.type class Query: - union: Union[ICanBeInUnion, int] + union: ICanBeInUnion | int strawberry.Schema(query=Query) @@ -712,7 +712,7 @@ class ICanBeInUnion: @strawberry.type class Query: - union: Union[ICanBeInUnion, list[str]] + union: ICanBeInUnion | list[str] strawberry.Schema(query=Query) @@ -732,7 +732,7 @@ class ICanBeInUnion: @strawberry.type class Query: - union: Union[ICanBeInUnion, list[str]] + union: ICanBeInUnion | list[str] strawberry.Schema(query=Query) @@ -754,7 +754,7 @@ class Always42: @strawberry.type class Query: union: Annotated[ - Union[Always42, ICanBeInUnion], strawberry.union(name="ExampleUnion") + Always42 | ICanBeInUnion, strawberry.union(name="ExampleUnion") ] strawberry.Schema(query=Query) @@ -781,8 +781,8 @@ class EvenMoreSpecificError: @strawberry.type class Query: - user: Union[User, Error] - error: Union[User, ErrorUnion] + user: User | Error + error: User | ErrorUnion schema = strawberry.Schema(query=Query) @@ -881,7 +881,7 @@ class ObjectQueries(Generic[T]): @strawberry.field def by_id( self, id: strawberry.ID - ) -> Annotated[Union[T, NotFoundError], strawberry.union("ByIdResult")]: ... + ) -> Annotated[T | NotFoundError, strawberry.union("ByIdResult")]: ... @strawberry.type class Query: @@ -937,7 +937,7 @@ class ObjectQueries(Generic[T]): @strawberry.field def by_id( self, id: strawberry.ID - ) -> Union[T, Annotated[NotFoundError, strawberry.union("ByIdResult")]]: ... + ) -> T | Annotated[NotFoundError, strawberry.union("ByIdResult")]: ... @strawberry.type class Query: @@ -995,9 +995,7 @@ class UnionObjectQueries(Generic[T, U]): @strawberry.field def by_id( self, id: strawberry.ID - ) -> Union[ - T, Annotated[Union[U, NotFoundError], strawberry.union("ByIdResult")] - ]: ... + ) -> T | Annotated[U | NotFoundError, strawberry.union("ByIdResult")]: ... @strawberry.type class Query: @@ -1102,9 +1100,9 @@ class B: class C: c: int - a = Annotated[Union[A, B], strawberry.union("AorB")] + a = Annotated[A | B, strawberry.union("AorB")] - b = Annotated[Union[B, C], strawberry.union("BorC")] + b = Annotated[B | C, strawberry.union("BorC")] c = Union[a, b] @@ -1153,11 +1151,11 @@ class B: class C: c: int - a = Annotated[Union[A, B], strawberry.union("AorB")] + a = Annotated[A | B, strawberry.union("AorB")] - b = Annotated[Union[B, C], strawberry.union("BorC")] + b = Annotated[B | C, strawberry.union("BorC")] - c = Annotated[Union[a, b], strawberry.union("ABC")] + c = Annotated[a | b, strawberry.union("ABC")] @strawberry.type class Query: @@ -1208,7 +1206,7 @@ class ProUser: class GenType(Generic[T]): data: T - GeneralUser = Annotated[Union[User, ProUser], strawberry.union("GeneralUser")] + GeneralUser = Annotated[User | ProUser, strawberry.union("GeneralUser")] @strawberry.type class Response(GenType[GeneralUser]): ... diff --git a/tests/schema/test_union_deprecated.py b/tests/schema/test_union_deprecated.py index b51ab6ae15..6678b678b8 100644 --- a/tests/schema/test_union_deprecated.py +++ b/tests/schema/test_union_deprecated.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from textwrap import dedent -from typing import Optional import pytest @@ -107,7 +106,7 @@ class B: @strawberry.type class Query: - ab: Optional[Result] = None + ab: Result | None = None schema = strawberry.Schema(query=Query) diff --git a/tests/schema/validation_rules/test_maybe_null.py b/tests/schema/validation_rules/test_maybe_null.py index 5a335ced48..74a26bf954 100644 --- a/tests/schema/validation_rules/test_maybe_null.py +++ b/tests/schema/validation_rules/test_maybe_null.py @@ -1,5 +1,3 @@ -from typing import Union - import strawberry @@ -9,7 +7,7 @@ def test_maybe_null_validation_rule_input_fields(): @strawberry.input class TestInput: strict_field: strawberry.Maybe[str] # Should reject null - flexible_field: strawberry.Maybe[Union[str, None]] # Should allow null + flexible_field: strawberry.Maybe[str | None] # Should allow null @strawberry.type class Query: @@ -67,7 +65,7 @@ class Query: def search( self, query: strawberry.Maybe[str] = None, # Should reject null - filter_by: strawberry.Maybe[Union[str, None]] = None, # Should allow null + filter_by: strawberry.Maybe[str | None] = None, # Should allow null ) -> str: return "success" @@ -113,7 +111,7 @@ def test_maybe_null_validation_rule_multiple_errors(): class TestInput: field1: strawberry.Maybe[str] field2: strawberry.Maybe[int] - field3: strawberry.Maybe[Union[str, None]] # This one allows null + field3: strawberry.Maybe[str | None] # This one allows null @strawberry.type class Query: diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index 6931572847..dc94590148 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -1,7 +1,7 @@ import asyncio from asyncio.futures import Future from collections.abc import Awaitable, Callable -from typing import Any, Optional, Union, cast +from typing import Any, Optional, cast import pytest from pytest_mock import MockerFixture @@ -73,7 +73,7 @@ async def test_max_batch_size(mocker: MockerFixture): @pytest.mark.asyncio async def test_error(): - async def idx(keys: list[int]) -> list[Union[int, ValueError]]: + async def idx(keys: list[int]) -> list[int | ValueError]: return [ValueError()] loader = DataLoader(load_fn=idx) @@ -84,7 +84,7 @@ async def idx(keys: list[int]) -> list[Union[int, ValueError]]: @pytest.mark.asyncio async def test_error_and_values(): - async def idx(keys: list[int]) -> list[Union[int, ValueError]]: + async def idx(keys: list[int]) -> list[int | ValueError]: return [2] if keys == [2] else [ValueError()] loader = DataLoader(load_fn=idx) @@ -97,7 +97,7 @@ async def idx(keys: list[int]) -> list[Union[int, ValueError]]: @pytest.mark.asyncio async def test_when_raising_error_in_loader(): - async def idx(keys: list[int]) -> list[Union[int, ValueError]]: + async def idx(keys: list[int]) -> list[int | ValueError]: raise ValueError loader = DataLoader(load_fn=idx) @@ -195,7 +195,7 @@ async def test_cache_disabled_immediate_await(mocker: MockerFixture): @pytest.mark.asyncio async def test_prime(): - async def idx(keys: list[Union[int, float]]) -> list[Union[int, float]]: + async def idx(keys: list[int | float]) -> list[int | float]: assert keys, "At least one key must be specified" return keys @@ -241,7 +241,7 @@ async def idx(keys: list[Union[int, float]]) -> list[Union[int, float]]: @pytest.mark.asyncio async def test_prime_nocache(): - async def idx(keys: list[Union[int, float]]) -> list[Union[int, float]]: + async def idx(keys: list[int | float]) -> list[int | float]: assert keys, "At least one key must be specified" return keys diff --git a/tests/test_printer/test_basic.py b/tests/test_printer/test_basic.py index bac73e9b16..cd7d73264f 100644 --- a/tests/test_printer/test_basic.py +++ b/tests/test_printer/test_basic.py @@ -1,5 +1,4 @@ import textwrap -from typing import Optional from uuid import UUID import strawberry @@ -77,7 +76,7 @@ class Query: def test_optional(): @strawberry.type class Query: - s: Optional[str] + s: str | None expected_type = """ type Query { @@ -133,7 +132,7 @@ def search(self, input: MyInput) -> str: def test_input_defaults(): @strawberry.input class MyInput: - s: Optional[str] = None + s: str | None = None i: int = 0 b: bool = False f: float = 0.0 @@ -141,7 +140,7 @@ class MyInput: id: strawberry.ID = strawberry.ID("some_id") id_number: strawberry.ID = strawberry.ID(123) # type: ignore id_number_string: strawberry.ID = strawberry.ID("123") - x: Optional[int] = UNSET + x: int | None = UNSET l: list[str] = strawberry.field(default_factory=list) # noqa: E741 list_with_values: list[str] = strawberry.field( default_factory=lambda: ["a", "b"] diff --git a/tests/test_printer/test_schema_directives.py b/tests/test_printer/test_schema_directives.py index da86f22aa9..3bf4d45599 100644 --- a/tests/test_printer/test_schema_directives.py +++ b/tests/test_printer/test_schema_directives.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any import strawberry from strawberry import BasePermission, Info @@ -70,12 +70,12 @@ class SensitiveValue: @strawberry.schema_directive(locations=[Location.OBJECT, Location.FIELD_DEFINITION]) class SensitiveData: reason: str - meta: Optional[list[SensitiveValue]] = UNSET + meta: list[SensitiveValue] | None = UNSET @strawberry.schema_directive(locations=[Location.INPUT_OBJECT]) class SensitiveInput: reason: str - meta: Optional[list[SensitiveValue]] = UNSET + meta: list[SensitiveValue] | None = UNSET @strawberry.schema_directive(locations=[Location.INPUT_FIELD_DEFINITION]) class RangeInput: @@ -612,7 +612,7 @@ class Sensitive: reason: str MyUnion = Annotated[ - Union[A, B], + A | B, strawberry.union(name="MyUnion", directives=[Sensitive(reason="example")]), ] @@ -718,13 +718,13 @@ def hello( def test_print_directive_with_unset_value(): @strawberry.input class FooInput: - a: Optional[str] = strawberry.UNSET - b: Optional[str] = strawberry.UNSET + a: str | None = strawberry.UNSET + b: str | None = strawberry.UNSET @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) class FooDirective: input: FooInput - optional_input: Optional[FooInput] = strawberry.UNSET + optional_input: FooInput | None = strawberry.UNSET @strawberry.type class Query: @@ -761,7 +761,7 @@ class FooInput: @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) class FooDirective: input: FooInput - optional_input: Optional[FooInput] = strawberry.UNSET + optional_input: FooInput | None = strawberry.UNSET @strawberry.type class Query: diff --git a/tests/types/resolving/test_lists.py b/tests/types/resolving/test_lists.py index a09a634602..a49c3839db 100644 --- a/tests/types/resolving/test_lists.py +++ b/tests/types/resolving/test_lists.py @@ -40,36 +40,36 @@ def test_basic_sequence(): def test_list_of_optional(): - annotation = StrawberryAnnotation(list[Optional[int]]) + annotation = StrawberryAnnotation(list[int | None]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type == Optional[int] assert resolved == StrawberryList(of_type=Optional[int]) - assert resolved == list[Optional[int]] + assert resolved == list[int | None] def test_sequence_of_optional(): - annotation = StrawberryAnnotation(Sequence[Optional[int]]) + annotation = StrawberryAnnotation(Sequence[int | None]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type == Optional[int] assert resolved == StrawberryList(of_type=Optional[int]) - assert resolved == Sequence[Optional[int]] + assert resolved == Sequence[int | None] def test_tuple_of_optional(): - annotation = StrawberryAnnotation(tuple[Optional[int]]) + annotation = StrawberryAnnotation(tuple[int | None]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type == Optional[int] assert resolved == StrawberryList(of_type=Optional[int]) - assert resolved == tuple[Optional[int]] + assert resolved == tuple[int | None] def test_list_of_lists(): @@ -114,14 +114,14 @@ class Animal: class Fungus: spore: bool - annotation = StrawberryAnnotation(list[Union[Animal, Fungus]]) + annotation = StrawberryAnnotation(list[Animal | Fungus]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type == Union[Animal, Fungus] assert resolved == StrawberryList(of_type=Union[Animal, Fungus]) - assert resolved == list[Union[Animal, Fungus]] + assert resolved == list[Animal | Fungus] def test_sequence_of_union(): @@ -133,14 +133,14 @@ class Animal: class Fungus: spore: bool - annotation = StrawberryAnnotation(Sequence[Union[Animal, Fungus]]) + annotation = StrawberryAnnotation(Sequence[Animal | Fungus]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryList) assert resolved.of_type == Union[Animal, Fungus] assert resolved == StrawberryList(of_type=Union[Animal, Fungus]) - assert resolved == Sequence[Union[Animal, Fungus]] + assert resolved == Sequence[Animal | Fungus] def test_list_builtin(): diff --git a/tests/types/resolving/test_optionals.py b/tests/types/resolving/test_optionals.py index 1c2784a03a..2dbcf1e6fe 100644 --- a/tests/types/resolving/test_optionals.py +++ b/tests/types/resolving/test_optionals.py @@ -18,7 +18,7 @@ def test_basic_optional(): def test_optional_with_unset(): - annotation = StrawberryAnnotation(Union[UnsetType, Optional[str]]) + annotation = StrawberryAnnotation(Union[UnsetType, str | None]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryOptional) @@ -29,7 +29,7 @@ def test_optional_with_unset(): def test_optional_with_type_of_unset(): - annotation = StrawberryAnnotation(Union[type[strawberry.UNSET], Optional[str]]) + annotation = StrawberryAnnotation(Union[type[strawberry.UNSET], str | None]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryOptional) @@ -63,14 +63,14 @@ def test_optional_list(): def test_optional_optional(): """Optional[Optional[...]] is squashed by Python to just Optional[...]""" - annotation = StrawberryAnnotation(Optional[Optional[bool]]) + annotation = StrawberryAnnotation(Optional[bool | None]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryOptional) assert resolved.of_type is bool assert resolved == StrawberryOptional(of_type=bool) - assert resolved == Optional[Optional[bool]] + assert resolved == Optional[bool | None] assert resolved == Optional[bool] @@ -83,22 +83,22 @@ class CoolType: class UncoolType: bar: bool - annotation = StrawberryAnnotation(Optional[Union[CoolType, UncoolType]]) + annotation = StrawberryAnnotation(Optional[CoolType | UncoolType]) resolved = annotation.resolve() assert isinstance(resolved, StrawberryOptional) assert resolved.of_type == Union[CoolType, UncoolType] assert resolved == StrawberryOptional(of_type=Union[CoolType, UncoolType]) - assert resolved == Optional[Union[CoolType, UncoolType]] + assert resolved == Optional[CoolType | UncoolType] # TODO: move to a field test file def test_type_add_type_definition_with_fields(): @strawberry.type class Query: - name: Optional[str] - age: Optional[int] + name: str | None + age: int | None definition = Query.__strawberry_definition__ assert definition.name == "Query" @@ -120,8 +120,8 @@ class Query: def test_passing_custom_names_to_fields(): @strawberry.type class Query: - x: Optional[str] = strawberry.field(name="name") - y: Optional[int] = strawberry.field(name="age") + x: str | None = strawberry.field(name="name") + y: int | None = strawberry.field(name="age") definition = Query.__strawberry_definition__ assert definition.name == "Query" @@ -143,8 +143,8 @@ class Query: def test_passing_nothing_to_fields(): @strawberry.type class Query: - name: Optional[str] = strawberry.field() - age: Optional[int] = strawberry.field() + name: str | None = strawberry.field() + age: int | None = strawberry.field() definition = Query.__strawberry_definition__ assert definition.name == "Query" @@ -167,7 +167,7 @@ def test_resolver_fields(): @strawberry.type class Query: @strawberry.field - def name(self) -> Optional[str]: + def name(self) -> str | None: return "Name" definition = Query.__strawberry_definition__ @@ -186,7 +186,7 @@ def test_resolver_fields_arguments(): @strawberry.type class Query: @strawberry.field - def name(self, argument: Optional[str]) -> Optional[str]: + def name(self, argument: str | None) -> str | None: return "Name" definition = Query.__strawberry_definition__ diff --git a/tests/types/resolving/test_string_annotations.py b/tests/types/resolving/test_string_annotations.py index ce187896c7..21ea1e645b 100644 --- a/tests/types/resolving/test_string_annotations.py +++ b/tests/types/resolving/test_string_annotations.py @@ -142,8 +142,8 @@ class Query: def test_optional(): @strawberry.type class Query: - name: "Optional[str]" - age: "Optional[int]" + name: "str | None" + age: "int | None" definition = Query.__strawberry_definition__ assert definition.name == "Query" diff --git a/tests/types/resolving/test_unions.py b/tests/types/resolving/test_unions.py index 0a5d1ca030..861759c2d5 100644 --- a/tests/types/resolving/test_unions.py +++ b/tests/types/resolving/test_unions.py @@ -62,7 +62,7 @@ class User: class Error: name: str - cool_union = Annotated[Union[User, Error], union(name="CoolUnion")] + cool_union = Annotated[User | Error, union(name="CoolUnion")] annotation = StrawberryAnnotation(cool_union) resolved = annotation.resolve() @@ -86,7 +86,7 @@ class Error: class Edge(Generic[T]): node: T - Result = Annotated[Union[Error, Edge[str]], strawberry.union("Result")] + Result = Annotated[Error | Edge[str], strawberry.union("Result")] strawberry_union = StrawberryAnnotation(Result).resolve() @@ -105,12 +105,7 @@ class Edge(Generic[T]): ) def test_error_with_scalar_types(): Something = Annotated[ - Union[ - int, - str, - float, - bool, - ], + int | str | float | bool, strawberry.union("Something"), ] diff --git a/tests/types/test_annotation.py b/tests/types/test_annotation.py index 4a1ab033d9..2a20530b56 100644 --- a/tests/types/test_annotation.py +++ b/tests/types/test_annotation.py @@ -1,6 +1,6 @@ import itertools from enum import Enum -from typing import Optional, TypeVar, Union +from typing import Optional, TypeVar import pytest @@ -39,7 +39,7 @@ class NumaNuma(Enum): @pytest.mark.parametrize( ("type1", "type2"), itertools.combinations_with_replacement(types, 2) ) -def test_annotation_hash(type1: Union[object, str], type2: Union[object, str]): +def test_annotation_hash(type1: object | str, type2: object | str): annotation1 = StrawberryAnnotation(type1) annotation2 = StrawberryAnnotation(type2) assert ( diff --git a/tests/types/test_argument_types.py b/tests/types/test_argument_types.py index 800f716654..74a3137931 100644 --- a/tests/types/test_argument_types.py +++ b/tests/types/test_argument_types.py @@ -79,7 +79,7 @@ def get_id(person_input: PersonInput) -> int: def test_optional(): @strawberry.field - def set_age(age: Optional[int]) -> bool: + def set_age(age: int | None) -> bool: _ = age return True diff --git a/tests/types/test_convert_to_dictionary.py b/tests/types/test_convert_to_dictionary.py index 496cc69641..b35920b081 100644 --- a/tests/types/test_convert_to_dictionary.py +++ b/tests/types/test_convert_to_dictionary.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Optional import strawberry from strawberry import asdict @@ -52,7 +51,7 @@ def test_convert_input_to_dictionary(): class QnaInput: title: str description: str - tags: Optional[list[str]] = strawberry.field(default=None) + tags: list[str] | None = strawberry.field(default=None) title = "Where is the capital of United Kingdom?" description = "London is the capital of United Kingdom." diff --git a/tests/types/test_lazy_types.py b/tests/types/test_lazy_types.py index d686dfd041..a517f7ee9d 100644 --- a/tests/types/test_lazy_types.py +++ b/tests/types/test_lazy_types.py @@ -1,7 +1,7 @@ # type: ignore import enum import textwrap -from typing import Annotated, Generic, TypeAlias, TypeVar, Union +from typing import Annotated, Generic, TypeAlias, TypeVar import strawberry from strawberry.annotation import StrawberryAnnotation @@ -163,7 +163,7 @@ def test_lazy_type_in_union(): ActiveType = LazyType("LaziestType", "tests.types.test_lazy_types") ActiveEnum = LazyType("LazyEnum", "tests.types.test_lazy_types") - something = Annotated[Union[ActiveType, ActiveEnum], union(name="CoolUnion")] + something = Annotated[ActiveType | ActiveEnum, union(name="CoolUnion")] annotation = StrawberryAnnotation(something) resolved = annotation.resolve() @@ -182,7 +182,7 @@ def test_lazy_function_in_union(): ] ActiveEnum = Annotated["LazyEnum", strawberry.lazy("tests.types.test_lazy_types")] - something = Annotated[Union[ActiveType, ActiveEnum], union(name="CoolUnion")] + something = Annotated[ActiveType | ActiveEnum, union(name="CoolUnion")] annotation = StrawberryAnnotation(something) resolved = annotation.resolve() diff --git a/tests/types/test_object_types.py b/tests/types/test_object_types.py index eb9e304354..93385df7a4 100644 --- a/tests/types/test_object_types.py +++ b/tests/types/test_object_types.py @@ -2,7 +2,7 @@ import dataclasses import re from enum import Enum -from typing import Annotated, Optional, TypeVar, Union +from typing import Annotated, Optional, TypeVar import pytest @@ -82,7 +82,7 @@ class TransitiveVerb: def test_optional(): @strawberry.type class HasChoices: - decision: Optional[bool] + decision: bool | None field: StrawberryField = get_object_definition(HasChoices).fields[0] @@ -110,7 +110,7 @@ class Europe: class UK: name: str - EU = Annotated[Union[Europe, UK], strawberry.union("EU")] + EU = Annotated[Europe | UK, strawberry.union("EU")] @strawberry.type class WishfulThinking: @@ -140,11 +140,11 @@ def test_fields_with_defaults_inheritance(): @strawberry.interface class A: text: str - delay: Optional[int] = None + delay: int | None = None @strawberry.type class B(A): - attachments: Optional[list[A]] = None + attachments: list[A] | None = None @strawberry.type class C(A): diff --git a/tests/types/test_resolver_types.py b/tests/types/test_resolver_types.py index 3e036a393f..2c7d8708f4 100644 --- a/tests/types/test_resolver_types.py +++ b/tests/types/test_resolver_types.py @@ -68,7 +68,7 @@ def get_2d_object() -> Polygon: def test_optional(): - def stock_market_tool() -> Optional[str]: ... + def stock_market_tool() -> str | None: ... resolver = StrawberryResolver(stock_market_tool) assert resolver.type == Optional[str] @@ -92,7 +92,7 @@ class Venn: class Diagram: bar: float - def get_overlap() -> Union[Venn, Diagram]: ... + def get_overlap() -> Venn | Diagram: ... resolver = StrawberryResolver(get_overlap) assert resolver.type == Union[Venn, Diagram] diff --git a/tests/utils/test_arguments_converter.py b/tests/utils/test_arguments_converter.py index de68b37314..9ac1bf29a2 100644 --- a/tests/utils/test_arguments_converter.py +++ b/tests/utils/test_arguments_converter.py @@ -74,12 +74,12 @@ def test_list(): StrawberryArgument( graphql_name="optionalIntegerList", python_name="optional_integer_list", - type_annotation=StrawberryAnnotation(list[Optional[int]]), + type_annotation=StrawberryAnnotation(list[int | None]), ), StrawberryArgument( graphql_name="optionalStringList", python_name="optional_string_list", - type_annotation=StrawberryAnnotation(list[Optional[str]]), + type_annotation=StrawberryAnnotation(list[str | None]), ), ] @@ -265,7 +265,7 @@ class ReleaseFileStatus(Enum): class AddReleaseFileCommentInput: pr_number: int status: ReleaseFileStatus - release_info: Optional[ReleaseInfo] + release_info: ReleaseInfo | None args = { "input": { @@ -361,8 +361,8 @@ class Number: @strawberry.input class Input: - numbers: Optional[Number] = UNSET - numbers_second: Optional[Number] = UNSET + numbers: Number | None = UNSET + numbers_second: Number | None = UNSET # case 1 args = {"input": {}} @@ -408,8 +408,8 @@ class Number: @strawberry.input class Input: - numbers: Optional[Number] = UNSET - numbers_second: Optional[Number] = UNSET + numbers: Number | None = UNSET + numbers_second: Number | None = UNSET args = {} diff --git a/tests/utils/test_typing.py b/tests/utils/test_typing.py index 0a6fe06122..ffef81f728 100644 --- a/tests/utils/test_typing.py +++ b/tests/utils/test_typing.py @@ -12,13 +12,10 @@ class Fruit: ... def test_get_optional_annotation(): # Pair Union - assert get_optional_annotation(Optional[Union[str, bool]]) == Union[str, bool] + assert get_optional_annotation(Optional[str | bool]) == Union[str, bool] # More than pair Union - assert ( - get_optional_annotation(Optional[Union[str, int, bool]]) - == Union[str, int, bool] - ) + assert get_optional_annotation(Optional[str | int | bool]) == Union[str, int, bool] def test_eval_type(): diff --git a/tests/utils/test_typing_forward_refs.py b/tests/utils/test_typing_forward_refs.py index a14b1174c5..7b8bb517f6 100644 --- a/tests/utils/test_typing_forward_refs.py +++ b/tests/utils/test_typing_forward_refs.py @@ -18,11 +18,11 @@ class Foo: ... ) assert ( eval_type(ForwardRef("list[Foo | str] | None"), globals(), locals()) - == Union[list[Union[Foo, str]], None] + == Union[list[Foo | str], None] ) assert ( eval_type(ForwardRef("list[Foo | str] | None | int"), globals(), locals()) - == Union[list[Union[Foo, str]], int, None] + == Union[list[Foo | str], int, None] ) assert eval_type(ForwardRef("JSON | None"), globals(), locals()) == Optional[JSON] @@ -38,11 +38,11 @@ class Foo: ... ) assert ( eval_type(ForwardRef("list[Foo | str] | None"), globals(), locals()) - == Union[list[Union[Foo, str]], None] # type: ignore + == Union[list[Foo | str], None] # type: ignore ) assert ( eval_type(ForwardRef("list[Foo | str] | None | int"), globals(), locals()) - == Union[list[Union[Foo, str]], int, None] # type: ignore + == Union[list[Foo | str], int, None] # type: ignore ) diff --git a/tests/views/schema.py b/tests/views/schema.py index b78f6ee3dd..d078b99642 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -2,7 +2,7 @@ import contextlib from collections.abc import AsyncGenerator from enum import Enum -from typing import Any, Optional, Union +from typing import Any from graphql import GraphQLError from graphql.version import VersionInfo, version_info @@ -67,7 +67,7 @@ class FolderInput: @strawberry.type class DebugInfo: num_active_result_handlers: int - is_connection_init_timeout_task_done: Optional[bool] + is_connection_init_timeout_task_done: bool | None @strawberry.type @@ -90,16 +90,16 @@ def greetings(self) -> str: return "hello" @strawberry.field - def hello(self, name: Optional[str] = None) -> str: + def hello(self, name: str | None = None) -> str: return f"Hello {name or 'world'}" @strawberry.field - async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str: + async def async_hello(self, name: str | None = None, delay: float = 0) -> str: await asyncio.sleep(delay) return f"Hello {name or 'world'}" @strawberry.field(permission_classes=[AlwaysFailPermission]) - def always_fail(self) -> Optional[str]: + def always_fail(self) -> str | None: return "Hey" @strawberry.field @@ -111,7 +111,7 @@ async def exception(self, message: str) -> str: raise ValueError(message) @strawberry.field - async def some_error(self) -> Optional[str]: + async def some_error(self) -> str | None: raise ValueError("Some error") @strawberry.field @@ -270,8 +270,8 @@ async def debug(self, info: strawberry.Info) -> AsyncGenerator[DebugInfo, None]: async def listener( self, info: strawberry.Info, - timeout: Optional[float] = None, - group: Optional[str] = None, + timeout: float | None = None, + group: str | None = None, ) -> AsyncGenerator[str, None]: yield info.context["request"].channel_name @@ -287,9 +287,9 @@ async def listener( async def listener_with_confirmation( self, info: strawberry.Info, - timeout: Optional[float] = None, - group: Optional[str] = None, - ) -> AsyncGenerator[Union[str, None], None]: + timeout: float | None = None, + group: str | None = None, + ) -> AsyncGenerator[str | None, None]: async with info.context["request"].listen_to_channel( type="test.message", timeout=timeout, @@ -320,7 +320,7 @@ async def long_finalizer( class Schema(strawberry.Schema): def process_errors( - self, errors: list, execution_context: Optional[ExecutionContext] = None + self, errors: list, execution_context: ExecutionContext | None = None ) -> None: import traceback diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 12f8541edd..d5a212d4d2 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -6,7 +6,7 @@ import time from collections.abc import AsyncGenerator from datetime import timedelta -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from unittest.mock import AsyncMock, Mock, patch import pytest @@ -53,7 +53,7 @@ def assert_next( next_message: NextMessage, id: str, data: dict[str, object], - extensions: Optional[dict[str, object]] = None, + extensions: dict[str, object] | None = None, ): """ Assert that the NextMessage payload contains the provided data. @@ -1076,9 +1076,9 @@ async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient): ) while True: - next_or_complete_message: Union[ - NextMessage, CompleteMessage - ] = await ws.receive_json() + next_or_complete_message: ( + NextMessage | CompleteMessage + ) = await ws.receive_json() assert next_or_complete_message["type"] in ("next", "complete") diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 0635185036..b41a127d3b 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -3,7 +3,7 @@ import asyncio import json from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from unittest import mock import pytest @@ -313,9 +313,9 @@ async def test_sends_keep_alive(http_client_class: type[HttpClient]): # get but they should be more than one. keepalive_count = 0 while True: - ka_or_data_message: Union[ - ConnectionKeepAliveMessage, DataMessage - ] = await ws.receive_json() + ka_or_data_message: ( + ConnectionKeepAliveMessage | DataMessage + ) = await ws.receive_json() if ka_or_data_message["type"] == "ka": keepalive_count += 1 else: diff --git a/tests/websockets/views.py b/tests/websockets/views.py index bad7e989f5..f33af88b1a 100644 --- a/tests/websockets/views.py +++ b/tests/websockets/views.py @@ -1,5 +1,3 @@ -from typing import Union - from strawberry import UNSET from strawberry.exceptions import ConnectionRejectionError from strawberry.http.async_base_view import AsyncBaseHTTPView @@ -26,7 +24,7 @@ class OnWSConnectMixin( ): async def on_ws_connect( self, context: dict[str, object] - ) -> Union[UnsetType, None, dict[str, object]]: + ) -> UnsetType | None | dict[str, object]: connection_params = context["connection_params"] if isinstance(connection_params, dict):