@@ -565,6 +565,7 @@ def goal(_):
565
565
566
566
self .task = self .ioloop .create_task (self ._run ())
567
567
self .saving_task = None
568
+ self .callbacks = []
568
569
if in_ipynb () and not self .ioloop .is_running ():
569
570
warnings .warn (
570
571
"The runner has been scheduled, but the asyncio "
@@ -669,6 +670,31 @@ def elapsed_time(self):
669
670
end_time = time .time ()
670
671
return end_time - self .start_time
671
672
673
+ def add_periodic_callback (
674
+ self ,
675
+ method : Callable [[AsyncRunner ]],
676
+ interval : int = 30 ,
677
+ ):
678
+ """Start a periodic callback that calls the given method on the runner.
679
+
680
+ Parameters
681
+ ----------
682
+ method : callable
683
+ The method to call periodically.
684
+ interval : int
685
+ The interval in seconds between the calls.
686
+ """
687
+
688
+ async def _callback ():
689
+ while self .status () == "running" :
690
+ method (self )
691
+ await asyncio .sleep (interval )
692
+ method (self ) # one last time
693
+
694
+ task = self .ioloop .create_task (_callback ())
695
+ self .callbacks .append (task )
696
+ return task
697
+
672
698
def start_periodic_saving (
673
699
self ,
674
700
save_kwargs : dict [str , Any ] | None = None ,
@@ -697,6 +723,8 @@ def start_periodic_saving(
697
723
... save_kwargs=dict(fname='data/test.pickle'),
698
724
... interval=600)
699
725
"""
726
+ if self .saving_task is not None :
727
+ raise RuntimeError ("Already saving." )
700
728
701
729
def default_save (learner ):
702
730
learner .save (** save_kwargs )
@@ -706,13 +734,7 @@ def default_save(learner):
706
734
if save_kwargs is None :
707
735
raise ValueError ("Must provide `save_kwargs` if method=None." )
708
736
709
- async def _saver ():
710
- while self .status () == "running" :
711
- method (self .learner )
712
- await asyncio .sleep (interval )
713
- method (self .learner ) # one last time
714
-
715
- self .saving_task = self .ioloop .create_task (_saver ())
737
+ self .saving_task = self .add_periodic_callback (method , interval = interval )
716
738
return self .saving_task
717
739
718
740
0 commit comments