Skip to content

Commit eb84326

Browse files
committed
Add async benchmark.
1 parent 4ad5bcd commit eb84326

File tree

3 files changed

+110
-24
lines changed

3 files changed

+110
-24
lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ sentinel = ">=0.3,<1.1"
4646
greenlet = {version = ">=3.0.0rc1", python = ">=3.12"}
4747

4848
[tool.poetry.group.dev.dependencies]
49+
"testing.postgresql" = ">=1.3.0"
50+
asgiref = "^3.7.2"
4951
asyncpg = "^0.28.0"
5052
black = ">=22,<24"
5153
importlib-metadata = ">=4.11.1,<7.0.0"
@@ -60,11 +62,10 @@ pytest-asyncio = ">=0.20.3,<0.22.0"
6062
pytest-codspeed = "^2.0.1"
6163
pytest-cov = "^4.0.0"
6264
pytest-emoji = "^0.2.0"
65+
pytest-mock = "^3.11.1"
6366
pytest-mypy-plugins = ">=1.10,<4.0"
6467
pytest-xdist = {extras = ["psutil"], version = "^3.1.0"}
6568
setuptools = ">=67.8.0"
66-
"testing.postgresql" = ">=1.3.0"
67-
asgiref = "^3.7.2"
6869

6970
[tool.black]
7071
line-length = 88

src/strawberry_sqlalchemy_mapper/loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(
4343

4444
async def _scalars(self, *args, **kwargs):
4545
if self._async_bind_factory:
46-
return await self._async_bind_factory().scalars(*args, **kwargs)
46+
async with self._async_bind_factory() as bind:
47+
return await bind.scalars(*args, **kwargs)
4748
else:
4849
# Deprecated, but supported for now.
4950
assert self._bind is not None

tests/benchmarks/test_relationship_loading.py

Lines changed: 105 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
from typing import List
23

34
import pytest
@@ -7,13 +8,12 @@
78
import strawberry_sqlalchemy_mapper
89
from asgiref.sync import async_to_sync
910
from pytest_codspeed.plugin import BenchmarkFixture
11+
from sqlalchemy import orm
1012
from strawberry.types import Info
1113

1214

13-
@pytest.mark.benchmark
14-
def test_load_many_relationships(
15-
benchmark: BenchmarkFixture, engine, base, sessionmaker
16-
):
15+
@pytest.fixture
16+
def populated_tables(engine, base, sessionmaker):
1717
class A(base):
1818
__tablename__ = "a"
1919
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
@@ -48,24 +48,7 @@ class Parent(base):
4848
d = sa.orm.relationship("D", backref="parents")
4949
e = sa.orm.relationship("E", backref="parents")
5050

51-
mapper = strawberry_sqlalchemy_mapper.StrawberrySQLAlchemyMapper()
52-
53-
@mapper.type(Parent)
54-
class StrawberryParent:
55-
pass
56-
57-
@strawberry.type
58-
class Query:
59-
@strawberry.field
60-
@staticmethod
61-
async def parents(info: Info) -> List[StrawberryParent]:
62-
return info.context["session"].scalars(sa.select(Parent)).all()
63-
64-
mapper.finalize()
6551
base.metadata.create_all(engine)
66-
67-
schema = strawberry.Schema(Query)
68-
6952
with sessionmaker() as session:
7053
for _ in range(1000):
7154
session.add(A())
@@ -85,6 +68,43 @@ async def parents(info: Info) -> List[StrawberryParent]:
8568
session.add(parent)
8669
session.commit()
8770

71+
return A, B, C, D, E, Parent
72+
73+
74+
@pytest.mark.benchmark
75+
def test_load_many_relationships(
76+
benchmark: BenchmarkFixture, populated_tables, sessionmaker, mocker
77+
):
78+
A, B, C, D, E, Parent = populated_tables
79+
80+
mapper = strawberry_sqlalchemy_mapper.StrawberrySQLAlchemyMapper()
81+
82+
@mapper.type(Parent)
83+
class StrawberryParent:
84+
pass
85+
86+
@strawberry.type
87+
class Query:
88+
@strawberry.field
89+
@staticmethod
90+
async def parents(info: Info) -> List[StrawberryParent]:
91+
return info.context["session"].scalars(sa.select(Parent)).all()
92+
93+
mapper.finalize()
94+
95+
schema = strawberry.Schema(Query)
96+
97+
# Now that we've seeded the database, let's add some delay to simulate network lag
98+
# to the database.
99+
old_execute_internal = orm.Session._execute_internal
100+
mocker.patch.object(orm.Session, "_execute_internal", autospec=True)
101+
102+
def sleep_then_execute(self, *args, **kwargs):
103+
time.sleep(0.01)
104+
return old_execute_internal(self, *args, **kwargs)
105+
106+
orm.Session._execute_internal.side_effect = sleep_then_execute
107+
88108
async def execute():
89109
with sessionmaker() as session:
90110
# Notice how we use a sync session but call Strawberry's async execute.
@@ -113,3 +133,67 @@ async def execute():
113133
assert len(result.data["parents"]) == 10
114134

115135
benchmark(async_to_sync(execute))
136+
137+
138+
@pytest.mark.benchmark
139+
def test_load_many_relationships_async(
140+
benchmark: BenchmarkFixture, populated_tables, async_sessionmaker, mocker
141+
):
142+
A, B, C, D, E, Parent = populated_tables
143+
144+
mapper = strawberry_sqlalchemy_mapper.StrawberrySQLAlchemyMapper()
145+
146+
@mapper.type(Parent)
147+
class StrawberryParent:
148+
pass
149+
150+
@strawberry.type
151+
class Query:
152+
@strawberry.field
153+
@staticmethod
154+
async def parents(info: Info) -> List[StrawberryParent]:
155+
async with info.context["async_sessionmaker"]() as session:
156+
return (await session.scalars(sa.select(Parent))).all()
157+
158+
mapper.finalize()
159+
160+
schema = strawberry.Schema(Query)
161+
162+
# Now that we've seeded the database, let's add some delay to simulate network lag
163+
# to the database.
164+
old_execute_internal = orm.Session._execute_internal
165+
mocker.patch.object(orm.Session, "_execute_internal", autospec=True)
166+
167+
def sleep_then_execute(self, *args, **kwargs):
168+
time.sleep(0.01)
169+
return old_execute_internal(self, *args, **kwargs)
170+
171+
orm.Session._execute_internal.side_effect = sleep_then_execute
172+
173+
async def execute():
174+
# Notice how we use a sync session but call Strawberry's async execute.
175+
# This is not an ideal combination, but it's certainly a common one that
176+
# we need to support efficiently.
177+
result = await schema.execute(
178+
"""
179+
query {
180+
parents {
181+
a { id },
182+
b { id },
183+
c { id },
184+
d { id },
185+
e { id },
186+
}
187+
}
188+
""",
189+
context_value={
190+
"async_sessionmaker": async_sessionmaker,
191+
"sqlalchemy_loader": strawberry_sqlalchemy_mapper.StrawberrySQLAlchemyLoader(
192+
async_bind_factory=async_sessionmaker
193+
),
194+
},
195+
)
196+
assert not result.errors
197+
assert len(result.data["parents"]) == 10
198+
199+
benchmark(async_to_sync(execute))

0 commit comments

Comments
 (0)