Skip to content

Commit d753850

Browse files
committed
add stop_extension logic, use it for stopping dask
1 parent d7c1fec commit d753850

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

jupyter_scheduler/extension.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,25 @@ def initialize_settings(self):
9292
if scheduler.task_runner:
9393
loop = asyncio.get_event_loop()
9494
loop.create_task(scheduler.task_runner.start())
95+
96+
async def stop_extension(self):
97+
"""
98+
Public method called by Jupyter Server when the server is stopping.
99+
This calls the cleanup code defined in `self._stop_exception()` inside
100+
an exception handler, as the server halts if this method raises an
101+
exception.
102+
"""
103+
try:
104+
await self._stop_extension()
105+
except Exception as e:
106+
self.log.error("Jupyter Scheduler raised an exception while stopping:")
107+
self.log.exception(e)
108+
109+
async def _stop_extension(self):
110+
"""
111+
Private method that defines the cleanup code to run when the server is
112+
stopping.
113+
"""
114+
if "scheduler" in self.settings:
115+
scheduler: SchedulerApp = self.settings["scheduler"]
116+
await scheduler.stop_extension()

jupyter_scheduler/scheduler.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,23 +97,12 @@ def _default_staging_path(self):
9797
)
9898

9999
def __init__(
100-
self,
101-
root_dir: str,
102-
environments_manager: Type[EnvironmentManager],
103-
config=None,
104-
**kwargs,
100+
self, root_dir: str, environments_manager: Type[EnvironmentManager], config=None, **kwargs
105101
):
106102
super().__init__(config=config, **kwargs)
107103
self.root_dir = root_dir
108104
self.environments_manager = environments_manager
109105

110-
loop = asyncio.get_event_loop()
111-
self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client())
112-
113-
async def _get_dask_client(self):
114-
"""Creates and configures a Dask client."""
115-
return DaskClient(processes=False, asynchronous=True)
116-
117106
def create_job(self, model: CreateJob) -> str:
118107
"""Creates a new job record, may trigger execution of the job.
119108
In case a task runner is actually handling execution of the jobs,
@@ -393,6 +382,12 @@ def get_local_output_path(
393382
else:
394383
return os.path.join(self.root_dir, self.output_directory, output_dir_name)
395384

385+
async def stop_extension(self):
386+
"""
387+
Placeholder method for a cleanup code to run when the server is stopping.
388+
"""
389+
pass
390+
396391

397392
class Scheduler(BaseScheduler):
398393
_db_session = None
@@ -426,6 +421,13 @@ def __init__(
426421
if self.task_runner_class:
427422
self.task_runner = self.task_runner_class(scheduler=self, config=config)
428423

424+
loop = asyncio.get_event_loop()
425+
self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client())
426+
427+
async def _get_dask_client(self):
428+
"""Creates and configures a Dask client."""
429+
return DaskClient(processes=False, asynchronous=True)
430+
429431
@property
430432
def db_session(self):
431433
if not self._db_session:
@@ -783,6 +785,14 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
783785

784786
return staging_paths
785787

788+
async def stop_extension(self):
789+
"""
790+
Cleanup code to run when the server is stopping.
791+
"""
792+
if self.dask_client_future:
793+
dask_client: DaskClient = await self.dask_client_future
794+
await dask_client.close()
795+
786796

787797
class ArchivingScheduler(Scheduler):
788798
"""Scheduler that captures all files in output directory in an archive."""

0 commit comments

Comments
 (0)