Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions federation-compatibility/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ class Custom: ...
@strawberry.federation.type(extend=True, keys=["email"])
class User:
email: strawberry.ID = strawberry.federation.field(external=True)
name: Optional[str] = strawberry.federation.field(override="users")
total_products_created: Optional[int] = strawberry.federation.field(external=True)
name: str | None = strawberry.federation.field(override="users")
total_products_created: int | None = strawberry.federation.field(external=True)
years_of_employment: int = strawberry.federation.field(external=True)

# TODO: the camel casing will be fixed in a future release of Strawberry
@strawberry.federation.field(requires=["totalProductsCreated", "yearsOfEmployment"])
def average_products_created_per_year(self) -> Optional[int]:
def average_products_created_per_year(self) -> int | None:
if self.total_products_created is not None:
return round(self.total_products_created / self.years_of_employment)

Expand All @@ -150,9 +150,9 @@ def resolve_reference(cls, **data: Any) -> Optional["User"]:

@strawberry.federation.type(shareable=True)
class ProductDimension:
size: Optional[str]
weight: Optional[float]
unit: Optional[str] = strawberry.federation.field(inaccessible=True)
size: str | None
weight: float | None
unit: str | None = strawberry.federation.field(inaccessible=True)


@strawberry.type
Expand All @@ -163,13 +163,13 @@ class ProductVariation:
@strawberry.type
class CaseStudy:
case_number: strawberry.ID
description: Optional[str]
description: str | None


@strawberry.federation.type(keys=["study { caseNumber }"])
class ProductResearch:
study: CaseStudy
outcome: Optional[str]
outcome: str | None

@classmethod
def from_data(cls, data: dict) -> "ProductResearch":
Expand Down Expand Up @@ -206,8 +206,8 @@ def resolve_reference(cls, **data: Any) -> Optional["ProductResearch"]:
class DeprecatedProduct:
sku: str
package: str
reason: Optional[str]
created_by: Optional[User]
reason: str | None
created_by: User | None

@classmethod
def resolve_reference(cls, **data: Any) -> Optional["DeprecatedProduct"]:
Expand All @@ -231,27 +231,27 @@ def resolve_reference(cls, **data: Any) -> Optional["DeprecatedProduct"]:
)
class Product:
id: strawberry.ID
sku: Optional[str]
package: Optional[str]
sku: str | None
package: str | None
variation_id: strawberry.Private[str]

@strawberry.field
def variation(self) -> Optional[ProductVariation]:
def variation(self) -> ProductVariation | None:
return (
ProductVariation(strawberry.ID(self.variation_id))
if self.variation_id
else None
)

@strawberry.field
def dimensions(self) -> Optional[ProductDimension]:
def dimensions(self) -> ProductDimension | None:
return ProductDimension(**dimension)

@strawberry.federation.field(provides=["totalProductsCreated"])
def created_by(self) -> Optional[User]:
def created_by(self) -> User | None:
return User(**user)

notes: Optional[str] = strawberry.federation.field(tags=["internal"])
notes: str | None = strawberry.federation.field(tags=["internal"])
research: list[ProductResearch]

@classmethod
Expand Down Expand Up @@ -301,10 +301,10 @@ def resolve_reference(cls, id: strawberry.ID) -> "Inventory":

@strawberry.federation.type(extend=True)
class Query:
product: Optional[Product] = strawberry.field(resolver=get_product_by_id)
product: Product | None = strawberry.field(resolver=get_product_by_id)

@strawberry.field(deprecation_reason="Use product query instead")
def deprecated_product(self, sku: str, package: str) -> Optional[DeprecatedProduct]:
def deprecated_product(self, sku: str, package: str) -> DeprecatedProduct | None:
return None


Expand Down
9 changes: 2 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,6 @@ src = ["strawberry", "tests"]
[tool.ruff.lint]
select = ["ALL"]
ignore = [
# https://github.com/astral-sh/ruff/pull/4427
# equivalent to keep-runtime-typing. We might want to enable those
# after we drop support for Python 3.9
"UP006",
"UP007",
"UP045",

# we use asserts in tests and to hint mypy
"S101",

Expand Down Expand Up @@ -372,6 +365,8 @@ ignore = [
"TCH002",
"TCH003",
"TRY002",
"UP007",
"UP045",
]

[tool.ruff.lint.isort]
Expand Down
15 changes: 7 additions & 8 deletions strawberry/aiohttp/test/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import (
TYPE_CHECKING,
Any,
Optional,
)

from strawberry.test.client import BaseGraphQLTestClient, Response
Expand All @@ -17,11 +16,11 @@ class GraphQLTestClient(BaseGraphQLTestClient):
async def query(
self,
query: str,
variables: Optional[dict[str, Mapping]] = None,
headers: Optional[dict[str, object]] = None,
asserts_errors: Optional[bool] = None,
files: Optional[dict[str, object]] = None,
assert_no_errors: Optional[bool] = True,
variables: dict[str, Mapping] | None = None,
headers: dict[str, object] | None = None,
asserts_errors: bool | None = None,
files: dict[str, object] | None = None,
assert_no_errors: bool | None = True,
) -> Response:
body = self._build_body(query, variables, files)

Expand Down Expand Up @@ -54,8 +53,8 @@ async def query(
async def request(
self,
body: dict[str, object],
headers: Optional[dict[str, object]] = None,
files: Optional[dict[str, object]] = None,
headers: dict[str, object] | None = None,
files: dict[str, object] | None = None,
) -> Any:
return await self._client.post(
self.url,
Expand Down
18 changes: 8 additions & 10 deletions strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from json.decoder import JSONDecodeError
from typing import (
TYPE_CHECKING,
Optional,
TypeGuard,
Union,
)

from lia import AiohttpHTTPRequestAdapter, HTTPException
Expand Down Expand Up @@ -72,7 +70,7 @@ async def close(self, code: int, reason: str) -> None:
class GraphQLView(
AsyncBaseHTTPView[
web.Request,
Union[web.Response, web.StreamResponse],
web.Response | web.StreamResponse,
web.Response,
web.Request,
web.WebSocketResponse,
Expand All @@ -91,8 +89,8 @@ class GraphQLView(
def __init__(
self,
schema: BaseSchema,
graphiql: Optional[bool] = None,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
graphiql: bool | None = None,
graphql_ide: GraphQL_IDE | None = "graphiql",
allow_queries_via_get: bool = True,
keep_alive: bool = True,
keep_alive_interval: float = 1,
Expand Down Expand Up @@ -131,12 +129,12 @@ def is_websocket_request(self, request: web.Request) -> TypeGuard[web.Request]:
ws = web.WebSocketResponse(protocols=self.subscription_protocols)
return ws.can_prepare(request).ok

async def pick_websocket_subprotocol(self, request: web.Request) -> Optional[str]:
async def pick_websocket_subprotocol(self, request: web.Request) -> str | None:
ws = web.WebSocketResponse(protocols=self.subscription_protocols)
return ws.can_prepare(request).protocol

async def create_websocket_response(
self, request: web.Request, subprotocol: Optional[str]
self, request: web.Request, subprotocol: str | None
) -> web.WebSocketResponse:
protocols = [subprotocol] if subprotocol else []
ws = web.WebSocketResponse(protocols=protocols)
Expand All @@ -152,17 +150,17 @@ async def __call__(self, request: web.Request) -> web.StreamResponse:
status=e.status_code,
)

async def get_root_value(self, request: web.Request) -> Optional[RootValue]:
async def get_root_value(self, request: web.Request) -> RootValue | None:
return None

async def get_context(
self, request: web.Request, response: Union[web.Response, web.WebSocketResponse]
self, request: web.Request, response: web.Response | web.WebSocketResponse
) -> Context:
return {"request": request, "response": response} # type: ignore

def create_response(
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
response_data: GraphQLHTTPResponse | list[GraphQLHTTPResponse],
sub_response: web.Response,
) -> web.Response:
sub_response.text = self.encode_json(response_data)
Expand Down
27 changes: 13 additions & 14 deletions strawberry/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Annotated,
Any,
ForwardRef,
Optional,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -61,14 +60,14 @@ class StrawberryAnnotation:

def __init__(
self,
annotation: Union[object, str],
annotation: object | str,
*,
namespace: Optional[dict[str, Any]] = None,
namespace: dict[str, Any] | None = None,
) -> None:
self.raw_annotation = annotation
self.namespace = namespace

self.__resolve_cache__: Optional[Union[StrawberryType, type]] = None
self.__resolve_cache__: StrawberryType | type | None = None

def __eq__(self, other: object) -> bool:
if not isinstance(other, StrawberryAnnotation):
Expand All @@ -81,8 +80,8 @@ def __hash__(self) -> int:

@staticmethod
def from_annotation(
annotation: object, namespace: Optional[dict[str, Any]] = None
) -> Optional[StrawberryAnnotation]:
annotation: object, namespace: dict[str, Any] | None = None
) -> StrawberryAnnotation | None:
if annotation is None:
return None

Expand All @@ -91,7 +90,7 @@ def from_annotation(
return annotation

@property
def annotation(self) -> Union[object, str]:
def annotation(self) -> object | str:
"""Return evaluated type on success or fallback to raw (string) annotation."""
try:
return self.evaluate()
Expand All @@ -101,7 +100,7 @@ def annotation(self) -> Union[object, str]:
return self.raw_annotation

@annotation.setter
def annotation(self, value: Union[object, str]) -> None:
def annotation(self, value: object | str) -> None:
self.raw_annotation = value

self.__resolve_cache__ = None
Expand Down Expand Up @@ -131,8 +130,8 @@ def _get_type_with_args(
def resolve(
self,
*,
type_definition: Optional[StrawberryObjectDefinition] = None,
) -> Union[StrawberryType, type]:
type_definition: StrawberryObjectDefinition | None = None,
) -> StrawberryType | type:
"""Return resolved (transformed) annotation."""
if (resolved := self.__resolve_cache__) is None:
resolved = self._resolve()
Expand Down Expand Up @@ -161,11 +160,11 @@ def resolve(

return resolved

def _resolve(self) -> Union[StrawberryType, type]:
def _resolve(self) -> StrawberryType | type:
evaled_type = cast("Any", self.evaluate())
return self._resolve_evaled_type(evaled_type)

def _resolve_evaled_type(self, evaled_type: Any) -> Union[StrawberryType, type]:
def _resolve_evaled_type(self, evaled_type: Any) -> StrawberryType | type:
if is_private(evaled_type):
return evaled_type

Expand Down Expand Up @@ -247,7 +246,7 @@ def create_optional(self, evaled_type: Any) -> StrawberryOptional:
# passed as we can safely use `Union` for both optional types
# (e.g. `Optional[str]`) and optional unions (e.g.
# `Optional[Union[TypeA, TypeB]]`)
child_type = Union[non_optional_types] # type: ignore
child_type = Union[non_optional_types] # type: ignore # noqa: UP007

of_type = StrawberryAnnotation(
annotation=child_type,
Expand Down Expand Up @@ -324,7 +323,7 @@ def _is_enum(cls, annotation: Any) -> bool:
return issubclass(annotation, Enum)

@classmethod
def _is_type_generic(cls, type_: Union[StrawberryType, type]) -> bool:
def _is_type_generic(cls, type_: StrawberryType | type) -> bool:
"""Returns True if `resolver_type` is generic else False."""
from strawberry.types.base import StrawberryType

Expand Down
Loading
Loading