Skip to content

Commit c8a2b76

Browse files
committed
Add profiling for sync.
1 parent 460eee7 commit c8a2b76

File tree

5 files changed

+939
-24
lines changed

5 files changed

+939
-24
lines changed

gel/_internal/_save.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1160,13 +1160,14 @@ def make_save_executor_constructor(
11601160
refetch: bool,
11611161
warn_on_large_sync_set: bool = False,
11621162
save_postcheck: bool = False,
1163+
executor_type: type,
11631164
) -> Callable[[], SaveExecutor]:
11641165
plan = make_plan(
11651166
objs,
11661167
refetch=refetch,
11671168
warn_on_large_sync_set=warn_on_large_sync_set,
11681169
)
1169-
return lambda: SaveExecutor(
1170+
return lambda: executor_type(
11701171
objs=objs,
11711172
create_batches=plan.create_batches,
11721173
updates=plan.update_batch,

gel/_internal/_testbase/_base.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,7 @@ class BranchTestCase(InstanceTestCase):
630630
BASE_TEST_CLASS = True
631631
TEARDOWN_RETRY_DROP_DB = 1
632632

633+
CLIENT_TYPE: ClassVar[type[TestClient | TestAsyncIOClient] | None]
633634
client: ClassVar[TestClient | TestAsyncIOClient]
634635

635636
@classmethod
@@ -678,7 +679,9 @@ def setUp(self) -> None:
678679
if self.ISOLATED_TEST_BRANCHES:
679680
cls = type(self)
680681
testdb = cls.loop.run_until_complete(self.setup_branch_copy())
681-
client = cls.make_test_client(database=testdb)._with_debug(
682+
client = cls.make_test_client(
683+
database=testdb, client_class=self.CLIENT_TYPE
684+
)._with_debug(
682685
save_postcheck=True,
683686
)
684687
self.client = client # type: ignore[misc]
@@ -717,6 +720,7 @@ def tearDown(self) -> None:
717720
def make_test_client(
718721
cls,
719722
*,
723+
client_class: type[TestClient | TestAsyncIOClient] | None = None,
720724
connection_class: type[
721725
asyncio_client.AsyncIOConnection
722726
| blocking_client.BlockingIOConnection
@@ -758,14 +762,17 @@ def make_blocking_test_client(
758762
cls,
759763
*,
760764
instance: _server.BaseInstance,
765+
client_class: type[TestClient] | None = None,
761766
connection_class: type[blocking_client.BlockingIOConnection]
762767
| None = None,
763768
**kwargs: str,
764769
) -> TestClient:
770+
if client_class is None:
771+
client_class = TestClient
765772
if connection_class is None:
766773
connection_class = blocking_client.BlockingIOConnection
767774
client = instance.create_blocking_client(
768-
client_class=TestClient,
775+
client_class=client_class,
769776
connection_class=connection_class,
770777
**cls.get_connect_args(instance, **kwargs),
771778
)
@@ -799,13 +806,16 @@ def make_async_test_client(
799806
cls,
800807
*,
801808
instance: _server.BaseInstance,
809+
client_class: type[TestAsyncIOClient] | None = None,
802810
connection_class: type[asyncio_client.AsyncIOConnection] | None = None,
803811
**kwargs: str,
804812
) -> TestAsyncIOClient:
813+
if client_class is None:
814+
client_class = TestAsyncIOClient
805815
if connection_class is None:
806816
connection_class = asyncio_client.AsyncIOConnection
807817
client = instance.create_async_client(
808-
client_class=TestAsyncIOClient,
818+
client_class=client_class,
809819
connection_class=connection_class,
810820
**cls.get_connect_args(instance, **kwargs),
811821
)
@@ -881,7 +891,9 @@ async def setup_and_connect(cls) -> None:
881891
await cls._create_empty_branch(dbname)
882892

883893
if not cls.ISOLATED_TEST_BRANCHES:
884-
cls.client = cls.make_test_client(database=dbname)
894+
cls.client = cls.make_test_client(
895+
database=dbname, client_class=cls.CLIENT_TYPE
896+
)
885897
if isinstance(cls.client, gel.AsyncIOClient):
886898
await cls.client.ensure_connected()
887899
else:
@@ -1021,11 +1033,13 @@ class AsyncQueryTestCase(BranchTestCase):
10211033
def make_test_client( # pyright: ignore [reportIncompatibleMethodOverride]
10221034
cls,
10231035
*,
1036+
client_class: type[TestAsyncIOClient] | None = None,
10241037
connection_class: type[asyncio_client.AsyncIOConnection] | None = None, # type: ignore [override]
10251038
**kwargs: str,
10261039
) -> TestAsyncIOClient:
10271040
return cls.make_async_test_client(
10281041
instance=cls.instance,
1042+
client_class=client_class,
10291043
connection_class=connection_class,
10301044
**kwargs,
10311045
)
@@ -1062,12 +1076,14 @@ def adapt_call(cls, coro: Any) -> Any:
10621076
def make_test_client( # pyright: ignore [reportIncompatibleMethodOverride]
10631077
cls,
10641078
*,
1079+
client_class: type[TestClient] | None = None,
10651080
connection_class: type[blocking_client.BlockingIOConnection] # type: ignore [override]
10661081
| None = None,
10671082
**kwargs: str,
10681083
) -> TestClient:
10691084
return cls.make_blocking_test_client(
10701085
instance=cls.instance,
1086+
client_class=client_class,
10711087
connection_class=connection_class,
10721088
**kwargs,
10731089
)

gel/asyncio_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from .protocol import asyncio_proto # type: ignore [attr-defined, unused-ignore]
3838
from .protocol.protocol import InputLanguage, OutputFormat
3939

40-
from ._internal._save import make_save_executor_constructor
40+
from ._internal._save import make_save_executor_constructor, SaveExecutor
4141

4242
if typing.TYPE_CHECKING:
4343
from ._internal._qbmodel._pydantic import GelModel
@@ -675,6 +675,7 @@ async def _save_impl(
675675
refetch=refetch,
676676
save_postcheck=opts.save_postcheck,
677677
warn_on_large_sync_set=warn_on_large_sync_set,
678+
executor_type=SaveExecutor,
678679
)
679680

680681
async for tx in self._batch():

gel/blocking_client.py

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@
3838
from .protocol import blocking_proto # type: ignore [attr-defined, unused-ignore]
3939
from .protocol.protocol import InputLanguage, OutputFormat
4040

41-
from ._internal._save import make_save_executor_constructor
41+
from ._internal._save import (
42+
QueryBatch,
43+
QueryRefetch,
44+
SaveExecutor,
45+
make_save_executor_constructor,
46+
)
4247

4348
if typing.TYPE_CHECKING:
4449
from ._internal._qbmodel._pydantic import GelModel
@@ -681,6 +686,7 @@ class Client(
681686

682687
__slots__ = ()
683688
_impl_class = _PoolImpl
689+
_save_executor_type = SaveExecutor
684690

685691
def _save_impl(
686692
self,
@@ -689,12 +695,9 @@ def _save_impl(
689695
objs: tuple[GelModel, ...],
690696
warn_on_large_sync_set: bool = False,
691697
) -> None:
692-
opts = self._get_debug_options()
693-
694-
make_executor = make_save_executor_constructor(
695-
objs,
698+
make_executor = self._get_make_save_executor(
696699
refetch=refetch,
697-
save_postcheck=opts.save_postcheck,
700+
objs=objs,
698701
warn_on_large_sync_set=warn_on_large_sync_set,
699702
)
700703

@@ -703,23 +706,13 @@ def _save_impl(
703706
executor = make_executor()
704707

705708
for batches in executor:
706-
for batch in batches:
707-
tx.send_query(batch.query, batch.args)
708-
batch_ids = tx.wait()
709+
batch_ids = self._send_batch_queries(tx, batches)
709710
for ids, batch in zip(batch_ids, batches, strict=True):
710711
batch.record_inserted_data(ids)
711712

712713
if refetch:
713714
ref_queries = executor.get_refetch_queries()
714-
for ref in ref_queries:
715-
tx.send_query(
716-
ref.query,
717-
spec=ref.args.spec,
718-
new=ref.args.new,
719-
existing=ref.args.existing,
720-
)
721-
722-
refetch_data = tx.wait()
715+
refetch_data = self._send_refetch_queries(tx, ref_queries)
723716

724717
for ref_data, ref in zip(
725718
refetch_data, ref_queries, strict=True
@@ -728,6 +721,60 @@ def _save_impl(
728721

729722
executor.commit()
730723

724+
def _get_make_save_executor(
725+
self,
726+
*,
727+
refetch: bool,
728+
objs: tuple[GelModel, ...],
729+
warn_on_large_sync_set: bool = False,
730+
) -> typing.Callable[[], SaveExecutor]:
731+
opts = self._get_debug_options()
732+
733+
return make_save_executor_constructor(
734+
objs,
735+
refetch=refetch,
736+
save_postcheck=opts.save_postcheck,
737+
warn_on_large_sync_set=warn_on_large_sync_set,
738+
executor_type=self._save_executor_type,
739+
)
740+
741+
def _send_batch_queries(
742+
self,
743+
tx: BatchIteration,
744+
batches: list[QueryBatch],
745+
) -> list[Any]:
746+
for batch in batches:
747+
self._send_batch_query(tx, batch)
748+
return tx.wait()
749+
750+
def _send_refetch_queries(
751+
self,
752+
tx: BatchIteration,
753+
ref_queries: list[QueryRefetch],
754+
) -> list[Any]:
755+
for ref in ref_queries:
756+
self._send_refetch_query(tx, ref)
757+
return tx.wait()
758+
759+
def _send_batch_query(
760+
self,
761+
tx: BatchIteration,
762+
batch: QueryBatch,
763+
) -> None:
764+
tx.send_query(batch.query, batch.args)
765+
766+
def _send_refetch_query(
767+
self,
768+
tx: BatchIteration,
769+
ref: QueryRefetch,
770+
) -> None:
771+
tx.send_query(
772+
ref.query,
773+
spec=ref.args.spec,
774+
new=ref.args.new,
775+
existing=ref.args.existing,
776+
)
777+
731778
def save(
732779
self,
733780
*objs: GelModel,

0 commit comments

Comments
 (0)