Skip to content

Commit 0cf6f0a

Browse files
novagjuniper-moss@outlook.com
authored andcommitted
feat: support sqlalchemy select API
1 parent 9f98569 commit 0cf6f0a

File tree

4 files changed

+95
-46
lines changed

4 files changed

+95
-46
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Release type: minor
2+
3+
Support SQLAlchemy select API when resolving.

src/strawberry_sqlalchemy_mapper/field.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030

3131
from sqlakeyset.types import Keyset
3232
from sqlalchemy.ext.asyncio import AsyncSession
33-
from sqlalchemy.orm import Query, Session
33+
from sqlalchemy.future import select
34+
from sqlalchemy.orm import Session
35+
from sqlalchemy.sql.expression import Select
3436
from strawberry import relay
3537
from strawberry.annotation import StrawberryAnnotation
3638
from strawberry.extensions.field_extension import (
@@ -59,11 +61,11 @@
5961
assert argument # type: ignore[truthy-function]
6062

6163

62-
connection_session: contextvars.ContextVar[
63-
Union[Session, AsyncSession, None]
64-
] = contextvars.ContextVar(
65-
"connection-session",
66-
default=None,
64+
connection_session: contextvars.ContextVar[Union[Session, AsyncSession, None]] = (
65+
contextvars.ContextVar(
66+
"connection-session",
67+
default=None,
68+
)
6769
)
6870

6971

@@ -97,7 +99,7 @@ def __init__(
9799
@dataclasses.dataclass
98100
class StrawberrySQLAlchemyAsyncQuery:
99101
session: AsyncSession
100-
query: Callable[[Session], Query]
102+
query: Callable[[], Select]
101103
iterator: Iterator[Any] | None = None
102104
limit: int | None = None
103105
offset: int | None = None
@@ -120,16 +122,13 @@ def __aiter__(self):
120122

121123
async def __anext__(self):
122124
if self.iterator is None:
125+
q = self.query()
126+
if self.limit is not None:
127+
q = q.limit(self.limit)
128+
if self.offset is not None:
129+
q = q.offset(self.offset)
123130

124-
def query_runner(s: Session):
125-
q = self.query(s)
126-
if self.limit is not None:
127-
q = q.limit(self.limit)
128-
if self.offset is not None:
129-
q = q.offset(self.offset)
130-
return list(q)
131-
132-
self.iterator = iter(await self.session.run_sync(query_runner))
131+
self.iterator = iter(await self.session.scalars(q))
133132

134133
try:
135134
return next(self.iterator)
@@ -325,7 +324,7 @@ def default_resolver(
325324
if session is None:
326325
session = field_sessionmaker()
327326

328-
def _get_query(s: Session):
327+
def _get_orm_query(s: Session):
329328
if root is not None:
330329
# root won't be None when resolving nested connections.
331330
# TODO: Maybe we want to send this to a dataloader?
@@ -338,16 +337,29 @@ def _get_query(s: Session):
338337

339338
return query
340339

340+
def _get_select_query():
341+
if root is not None:
342+
# root won't be None when resolving nested connections.
343+
# TODO: Maybe we want to send this to a dataloader?
344+
query = getattr(root, field.python_name)
345+
else:
346+
query = select(model)
347+
348+
if field.keyset is not None:
349+
query = query.order_by(*field.keyset)
350+
351+
return query
352+
341353
if isinstance(session, AsyncSession):
342354
return cast(
343355
Iterable[Any],
344356
StrawberrySQLAlchemyAsyncQuery(
345357
session=session,
346-
query=lambda s: _get_query(s),
358+
query=_get_select_query,
347359
),
348360
)
349361

350-
return _get_query(session)
362+
return _get_orm_query(session)
351363

352364
field.base_resolver = StrawberryResolver(default_resolver)
353365

@@ -415,8 +427,7 @@ def field(
415427
graphql_type: Any | None = None,
416428
extensions: Sequence[FieldExtension] = (),
417429
sessionmaker: _SessionMaker | None = None,
418-
) -> _T:
419-
...
430+
) -> _T: ...
420431

421432

422433
@overload
@@ -437,8 +448,7 @@ def field(
437448
graphql_type: Any | None = None,
438449
extensions: Sequence[FieldExtension] = (),
439450
sessionmaker: _SessionMaker | None = None,
440-
) -> Any:
441-
...
451+
) -> Any: ...
442452

443453

444454
@overload
@@ -459,8 +469,7 @@ def field(
459469
graphql_type: Any | None = None,
460470
extensions: Sequence[FieldExtension] = (),
461471
sessionmaker: _SessionMaker | None = None,
462-
) -> StrawberrySQLAlchemyField:
463-
...
472+
) -> StrawberrySQLAlchemyField: ...
464473

465474

466475
def field(
@@ -599,8 +608,7 @@ def connection(
599608
extensions: Sequence[FieldExtension] = (),
600609
sessionmaker: _SessionMaker | None = None,
601610
keyset: Keyset | None = None,
602-
) -> Any:
603-
...
611+
) -> Any: ...
604612

605613

606614
@overload
@@ -622,8 +630,7 @@ def connection(
622630
extensions: Sequence[FieldExtension] = (),
623631
sessionmaker: _SessionMaker | None = None,
624632
keyset: Keyset | None = None,
625-
) -> Any:
626-
...
633+
) -> Any: ...
627634

628635

629636
def connection(

src/strawberry_sqlalchemy_mapper/relay.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
)
1515

1616
import sqlakeyset
17+
import sqlakeyset.asyncio
1718
import strawberry
18-
from sqlalchemy import and_, or_
19+
from sqlalchemy.engine import Row
1920
from sqlalchemy.exc import NoResultFound
2021
from sqlalchemy.ext.asyncio import AsyncSession
2122
from sqlalchemy.inspection import inspect as sqlalchemy_inspect
23+
from sqlalchemy.orm import Query
24+
from sqlalchemy.sql.expression import Select, and_, or_
2225
from strawberry import relay
2326
from strawberry.relay.exceptions import NodeIDAnnotationError
2427
from strawberry.relay.types import NodeType
@@ -27,7 +30,7 @@
2730
if TYPE_CHECKING:
2831
from typing_extensions import Literal, Self
2932

30-
from sqlalchemy.orm import Query, Session
33+
from sqlalchemy.orm import Session
3134
from strawberry.types.info import Info
3235
from strawberry.utils.await_maybe import AwaitableOrValue
3336

@@ -64,7 +67,7 @@ class KeysetConnection(relay.Connection[NodeType]):
6467
@classmethod
6568
def resolve_connection(
6669
cls,
67-
nodes: Union[Query, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override]
70+
nodes: Union[Query, Select, StrawberrySQLAlchemyAsyncQuery], # type: ignore[override]
6871
*,
6972
info: Info,
7073
before: Optional[str] = None,
@@ -110,40 +113,76 @@ def resolve_connection(page: sqlakeyset.Page):
110113
end_cursor=page.paging.get_bookmark_at(-1) if page else None,
111114
),
112115
edges=[
113-
edge_class.resolve_edge(n, cursor=page.paging.get_bookmark_at(i))
116+
edge_class.resolve_edge(
117+
n[0] if isinstance(n, Row) else n,
118+
cursor=page.paging.get_bookmark_at(i),
119+
)
114120
for i, n in enumerate(page)
115121
],
116122
)
117123

118-
def resolve_nodes(s: Session, nodes=nodes):
119-
if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery):
120-
nodes = nodes.query(s)
124+
def resolve_nodes(s: Session, nodes: Union[Query, Select]):
125+
if isinstance(nodes, Select):
126+
return resolve_connection(
127+
sqlakeyset.select_page(
128+
s,
129+
nodes,
130+
per_page=per_page,
131+
after=(
132+
sqlakeyset.unserialize_bookmark(after).place
133+
if after
134+
else None
135+
),
136+
before=(
137+
sqlakeyset.unserialize_bookmark(before).place
138+
if before
139+
else None
140+
),
141+
)
142+
)
121143

122144
return resolve_connection(
123145
sqlakeyset.get_page(
124146
nodes,
147+
per_page=per_page,
148+
after=(
149+
sqlakeyset.unserialize_bookmark(after).place if after else None
150+
),
125151
before=(
126152
sqlakeyset.unserialize_bookmark(before).place
127153
if before
128154
else None
129155
),
156+
)
157+
)
158+
159+
async def resolve_nodes_async(s: AsyncSession, nodes: Select):
160+
# the asynchronous SQLAlchemy API only supports select
161+
return resolve_connection(
162+
await sqlakeyset.asyncio.select_page(
163+
s,
164+
nodes,
165+
per_page=per_page,
130166
after=(
131167
sqlakeyset.unserialize_bookmark(after).place if after else None
132168
),
133-
per_page=per_page,
169+
before=(
170+
sqlakeyset.unserialize_bookmark(before).place
171+
if before
172+
else None
173+
),
134174
)
135175
)
136176

137-
# TODO: It would be better to aboid session.run_sync in here but
138-
# sqlakeyset doesn't have a `get_page` async counterpart.
139177
if isinstance(session, AsyncSession):
178+
if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery):
179+
nodes = nodes.query()
140180

141-
async def resolve_async(nodes=nodes):
142-
return await session.run_sync(lambda s: resolve_nodes(s))
143-
144-
return resolve_async()
181+
assert isinstance(nodes, Select)
182+
return resolve_nodes_async(session, nodes)
145183

146-
return resolve_nodes(session)
184+
assert isinstance(nodes, (Query, Select))
185+
return resolve_nodes(session, nodes)
147186

148187

149188
@overload

tests/relay/test_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class Query:
109109
await session.commit()
110110

111111
session.add_all([f1, f2, f3])
112-
session.commit()
112+
await session.commit()
113113

114114
for f in [f1, f2, f3]:
115115
result = await schema.execute(query, {"id": relay.to_base64("Fruit", f.id)})
@@ -266,7 +266,7 @@ class Query:
266266
await session.commit()
267267

268268
session.add_all([f1, f2, f3])
269-
session.commit()
269+
await session.commit()
270270

271271
result = await schema.execute(
272272
query,

0 commit comments

Comments
 (0)