
Commit b558560

use default process-based Dask distributed cluster
1 parent 4c102d1 commit b558560

2 files changed: +27 -31 lines changed

jupyter_scheduler/job_files_manager.py

Lines changed: 17 additions & 16 deletions
@@ -1,7 +1,8 @@
 import os
 import random
 import tarfile
-from typing import Awaitable, Dict, List, Optional, Type
+from multiprocessing import Process
+from typing import Dict, List, Optional, Type

 import fsspec
 from dask.distributed import Client as DaskClient
@@ -14,10 +15,7 @@
 class JobFilesManager:
     scheduler = None

-    def __init__(
-        self,
-        scheduler: Type[BaseScheduler],
-    ):
+    def __init__(self, scheduler: Type[BaseScheduler]):
         self.scheduler = scheduler

     async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = False):
@@ -26,17 +24,20 @@ async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = Fals
         output_filenames = self.scheduler.get_job_filenames(job)
         output_dir = self.scheduler.get_local_output_path(model=job, root_dir_relative=True)

-        dask_client: DaskClient = await self.scheduler.dask_client_future
-        dask_client.submit(
-            Downloader(
-                output_formats=job.output_formats,
-                output_filenames=output_filenames,
-                staging_paths=staging_paths,
-                output_dir=output_dir,
-                redownload=redownload,
-                include_staging_files=job.package_input_folder,
-            ).download
-        )
+        download = Downloader(
+            output_formats=job.output_formats,
+            output_filenames=output_filenames,
+            staging_paths=staging_paths,
+            output_dir=output_dir,
+            redownload=redownload,
+            include_staging_files=job.package_input_folder,
+        ).download
+        if self.scheduler.dask_client:
+            dask_client: DaskClient = self.scheduler.dask_client
+            dask_client.submit(download)
+        else:
+            p = Process(target=download)
+            p.start()


class Downloader:
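
For context, the new copy_from_staging builds the Downloader(...).download callable once and then dispatches it either to the scheduler's Dask client or, when no client is available, to a separate OS process. A minimal standalone sketch of that dispatch pattern, assuming a picklable zero-argument callable; the run_in_background helper below is illustrative and not part of jupyter-scheduler:

from multiprocessing import Process
from typing import Callable, Optional

from dask.distributed import Client as DaskClient


def run_in_background(task: Callable[[], None], dask_client: Optional[DaskClient] = None) -> None:
    """Fire-and-forget: run `task` on the Dask cluster if a client exists, else in a child process."""
    if dask_client:
        # Dask serializes the callable and runs it on one of its worker processes.
        dask_client.submit(task)
    else:
        # No cluster available: fall back to a plain OS process so the server is not blocked.
        Process(target=task).start()

Both Dask and multiprocessing pickle the callable, so it needs to be defined at import time, for example a module-level function or a bound method of a picklable class.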

jupyter_scheduler/scheduler.py

Lines changed: 10 additions & 15 deletions
@@ -2,7 +2,7 @@
 import os
 import random
 import shutil
-from typing import Awaitable, Dict, List, Optional, Type, Union
+from typing import Dict, List, Optional, Type, Union

 import fsspec
 import psutil
@@ -421,12 +421,11 @@ def __init__(
         if self.task_runner_class:
             self.task_runner = self.task_runner_class(scheduler=self, config=config)

-        loop = asyncio.get_event_loop()
-        self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client())
+        self.dask_client: DaskClient = self._get_dask_client()

-    async def _get_dask_client(self):
+    def _get_dask_client(self):
         """Creates and configures a Dask client."""
-        return DaskClient(processes=False, asynchronous=True)
+        return DaskClient()

     @property
     def db_session(self):
@@ -451,7 +450,7 @@ def copy_input_folder(self, input_uri: str, nb_copy_to_path: str) -> List[str]:
             destination_dir=staging_dir,
         )

-    async def create_job(self, model: CreateJob) -> str:
+    def create_job(self, model: CreateJob) -> str:
         if not model.job_definition_id and not self.file_exists(model.input_uri):
             raise InputUriError(model.input_uri)

@@ -492,8 +491,7 @@ async def create_job(self, model: CreateJob) -> str:
         else:
             self.copy_input_file(model.input_uri, staging_paths["input"])

-        dask_client: DaskClient = await self.dask_client_future
-        future = dask_client.submit(
+        future = self.dask_client.submit(
             self.execution_manager_class(
                 job_id=job.job_id,
                 staging_paths=staging_paths,
@@ -755,16 +753,14 @@ def list_job_definitions(self, query: ListJobDefinitionsQuery) -> ListJobDefinit

         return list_response

-    async def create_job_from_definition(
-        self, job_definition_id: str, model: CreateJobFromDefinition
-    ):
+    def create_job_from_definition(self, job_definition_id: str, model: CreateJobFromDefinition):
         job_id = None
         definition = self.get_job_definition(job_definition_id)
         if definition:
             input_uri = self.get_staging_paths(definition)["input"]
             attributes = definition.dict(exclude={"schedule", "timezone"}, exclude_none=True)
             attributes = {**attributes, **model.dict(exclude_none=True), "input_uri": input_uri}
-            job_id = await self.create_job(CreateJob(**attributes))
+            job_id = self.create_job(CreateJob(**attributes))

         return job_id

@@ -789,9 +785,8 @@ async def stop_extension(self):
         """
         Cleanup code to run when the server is stopping.
         """
-        if self.dask_client_future:
-            dask_client: DaskClient = await self.dask_client_future
-            await dask_client.close()
+        if self.dask_client:
+            self.dask_client.close()


class ArchivingScheduler(Scheduler):
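
With the asynchronous in-process client gone, DaskClient() is now created with its defaults, which start a local process-based cluster and return a synchronous client, so create_job and create_job_from_definition can submit work without awaiting anything. A minimal sketch of that synchronous usage, assuming default dask.distributed behavior; the double function is illustrative only and not part of jupyter-scheduler:

from dask.distributed import Client as DaskClient


def double(x):
    # Illustrative task submitted to the cluster.
    return x * 2


if __name__ == "__main__":
    client = DaskClient()             # defaults: local cluster, worker processes, synchronous API
    future = client.submit(double, 21)
    print(future.result())            # 42
    client.close()

The __main__ guard matters with a process-based local cluster, since worker processes re-import the launching module.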

0 commit comments