3939from .protocol import blocking_proto # type: ignore [attr-defined, unused-ignore]
4040from .protocol .protocol import InputLanguage , OutputFormat
4141
42- from ._internal ._save import make_save_executor_constructor
42+ from ._internal ._save import (
43+ QueryBatch ,
44+ QueryRefetch ,
45+ SaveExecutor ,
46+ make_save_executor_constructor ,
47+ )
4348
4449if typing .TYPE_CHECKING :
4550 from ._internal ._qbmodel ._pydantic import GelModel
@@ -651,6 +656,7 @@ class Client(
651656
652657 __slots__ = ()
653658 _impl_class = _PoolImpl
659+ _save_executor_type = SaveExecutor
654660
655661 def _save_impl (
656662 self ,
@@ -659,12 +665,9 @@ def _save_impl(
659665 objs : tuple [GelModel , ...],
660666 warn_on_large_sync_set : bool = False ,
661667 ) -> None :
662- opts = self ._get_debug_options ()
663-
664- make_executor = make_save_executor_constructor (
665- objs ,
668+ make_executor = self ._get_make_save_executor (
666669 refetch = refetch ,
667- save_postcheck = opts . save_postcheck ,
670+ objs = objs ,
668671 warn_on_large_sync_set = warn_on_large_sync_set ,
669672 )
670673
@@ -675,22 +678,53 @@ def _save_impl(
675678 with executor :
676679 for batches in executor :
677680 for batch in batches :
678- tx . send_query ( batch . query , batch . args )
681+ self . _send_batch_query ( tx , batch )
679682 batch_ids = tx .wait ()
680683 for ids , batch in zip (batch_ids , batches , strict = True ):
681684 batch .feed_db_data (ids )
682685
683686 if refetch :
684687 ref_queries = executor .get_refetch_queries ()
685688 for ref in ref_queries :
686- tx . send_query ( ref . query , ** ref . args )
689+ self . _send_refetch_query ( tx , ref )
687690
688691 refetch_data = tx .wait ()
689692 for ref_data , ref in zip (
690693 refetch_data , ref_queries , strict = True
691694 ):
692695 ref .feed_db_data (ref_data )
693696
697+ def _get_make_save_executor (
698+ self ,
699+ * ,
700+ refetch : bool ,
701+ objs : tuple [GelModel , ...],
702+ warn_on_large_sync_set : bool = False ,
703+ ) -> typing .Callable [[], SaveExecutor ]:
704+ opts = self ._get_debug_options ()
705+
706+ return make_save_executor_constructor (
707+ objs ,
708+ refetch = refetch ,
709+ save_postcheck = opts .save_postcheck ,
710+ warn_on_large_sync_set = warn_on_large_sync_set ,
711+ executor_type = self ._save_executor_type ,
712+ )
713+
714+ def _send_batch_query (
715+ self ,
716+ tx : BatchIteration ,
717+ batch : QueryBatch ,
718+ ) -> None :
719+ tx .send_query (batch .query , batch .args )
720+
721+ def _send_refetch_query (
722+ self ,
723+ tx : BatchIteration ,
724+ ref : QueryRefetch ,
725+ ) -> None :
726+ tx .send_query (ref .query , ** ref .args )
727+
694728 def save (
695729 self ,
696730 * objs : GelModel ,
@@ -718,6 +752,7 @@ def __debug_save__(self, *objs: GelModel) -> SaveDebug:
718752 make_executor = make_save_executor_constructor (
719753 objs ,
720754 refetch = False , # TODO
755+ executor_type = self ._save_executor_type ,
721756 )
722757 plan_time = time .monotonic_ns () - ns
723758
0 commit comments