@@ -97,23 +97,12 @@ def _default_staging_path(self):
97
97
)
98
98
99
99
def __init__ (
100
- self ,
101
- root_dir : str ,
102
- environments_manager : Type [EnvironmentManager ],
103
- config = None ,
104
- ** kwargs ,
100
+ self , root_dir : str , environments_manager : Type [EnvironmentManager ], config = None , ** kwargs
105
101
):
106
102
super ().__init__ (config = config , ** kwargs )
107
103
self .root_dir = root_dir
108
104
self .environments_manager = environments_manager
109
105
110
- loop = asyncio .get_event_loop ()
111
- self .dask_client_future : Awaitable [DaskClient ] = loop .create_task (self ._get_dask_client ())
112
-
113
- async def _get_dask_client (self ):
114
- """Creates and configures a Dask client."""
115
- return DaskClient (processes = False , asynchronous = True )
116
-
117
106
def create_job (self , model : CreateJob ) -> str :
118
107
"""Creates a new job record, may trigger execution of the job.
119
108
In case a task runner is actually handling execution of the jobs,
@@ -393,6 +382,12 @@ def get_local_output_path(
393
382
else :
394
383
return os .path .join (self .root_dir , self .output_directory , output_dir_name )
395
384
385
+ async def stop_extension (self ):
386
+ """
387
+ Placeholder method for a cleanup code to run when the server is stopping.
388
+ """
389
+ pass
390
+
396
391
397
392
class Scheduler (BaseScheduler ):
398
393
_db_session = None
@@ -426,6 +421,13 @@ def __init__(
426
421
if self .task_runner_class :
427
422
self .task_runner = self .task_runner_class (scheduler = self , config = config )
428
423
424
+ loop = asyncio .get_event_loop ()
425
+ self .dask_client_future : Awaitable [DaskClient ] = loop .create_task (self ._get_dask_client ())
426
+
427
+ async def _get_dask_client (self ):
428
+ """Creates and configures a Dask client."""
429
+ return DaskClient (processes = False , asynchronous = True )
430
+
429
431
@property
430
432
def db_session (self ):
431
433
if not self ._db_session :
@@ -783,6 +785,14 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
783
785
784
786
return staging_paths
785
787
788
+ async def stop_extension (self ):
789
+ """
790
+ Cleanup code to run when the server is stopping.
791
+ """
792
+ if self .dask_client_future :
793
+ dask_client : DaskClient = await self .dask_client_future
794
+ await dask_client .close ()
795
+
786
796
787
797
class ArchivingScheduler (Scheduler ):
788
798
"""Scheduler that captures all files in output directory in an archive."""
0 commit comments