diff --git a/conftest.py b/conftest.py index 71915074..29296764 100644 --- a/conftest.py +++ b/conftest.py @@ -5,7 +5,7 @@ from jupyter_scheduler.orm import create_session, create_tables from jupyter_scheduler.scheduler import Scheduler -from jupyter_scheduler.tests.mocks import MockEnvironmentManager +from jupyter_scheduler.tests.mocks import MockDownloadManager, MockEnvironmentManager pytest_plugins = ("jupyter_server.pytest_plugin",) @@ -48,5 +48,8 @@ def jp_scheduler_db(): @pytest.fixture def jp_scheduler(): return Scheduler( - db_url=DB_URL, root_dir=str(TEST_ROOT_DIR), environments_manager=MockEnvironmentManager() + db_url=DB_URL, + root_dir=str(TEST_ROOT_DIR), + environments_manager=MockEnvironmentManager(), + download_manager=MockDownloadManager(DB_URL), ) diff --git a/jupyter_scheduler/download_manager.py b/jupyter_scheduler/download_manager.py new file mode 100644 index 00000000..1cb3fefb --- /dev/null +++ b/jupyter_scheduler/download_manager.py @@ -0,0 +1,84 @@ +from multiprocessing import Queue +from typing import List, Optional + +from jupyter_scheduler.models import DescribeDownload +from jupyter_scheduler.orm import Download, create_session, generate_uuid +from jupyter_scheduler.pydantic_v1 import BaseModel +from jupyter_scheduler.utils import get_utc_timestamp + + +def initiate_download_standalone( + job_id: str, download_queue: Queue, db_session, redownload: bool = False +): + """ + This method initiates a download in a standalone manner independent of the DownloadManager instance. It is suitable for use in multiprocessing environment where a direct reference to DownloadManager instance is not feasible. 
+ """ + download_initiated_time = get_utc_timestamp() + download_id = generate_uuid() + download = DescribeDownload( + job_id=job_id, + download_id=download_id, + download_initiated_time=download_initiated_time, + redownload=redownload, + ) + download_record = Download(**download.dict()) + db_session.add(download_record) + db_session.commit() + download_queue.put(download) + + +class DownloadRecordManager: + def __init__(self, db_url): + self.session = create_session(db_url) + + def put(self, download: DescribeDownload): + with self.session() as session: + download = Download(**download.dict()) + session.add(download) + session.commit() + + def get(self, download_id: str) -> Optional[DescribeDownload]: + with self.session() as session: + download = session.query(Download).filter(Download.download_id == download_id).first() + + if download: + return DescribeDownload.from_orm(download) + else: + return None + + def get_downloads(self) -> List[DescribeDownload]: + with self.session() as session: + return [DescribeDownload.from_orm(download) for download in session.query(Download).order_by(Download.download_initiated_time).all()] + + def delete_download(self, download_id: str): + with self.session() as session: + session.query(Download).filter(Download.download_id == download_id).delete() + session.commit() + + def delete_job_downloads(self, job_id: str): + with self.session() as session: + session.query(Download).filter(Download.job_id == job_id).delete() + session.commit() + + +class DownloadManager: + def __init__(self, db_url: str): + self.record_manager = DownloadRecordManager(db_url=db_url) + self.queue = Queue() + + def initiate_download(self, job_id: str, redownload: bool): + with self.record_manager.session() as session: + initiate_download_standalone( + job_id=job_id, download_queue=self.queue, db_session=session, redownload=redownload + ) + + def delete_download(self, download_id: str): + self.record_manager.delete_download(download_id) + + def delete_job_downloads(self, job_id: str): 
self.record_manager.delete_job_downloads(job_id) + + def populate_queue(self): + downloads = self.record_manager.get_downloads() + for download in downloads: + self.queue.put(download) diff --git a/jupyter_scheduler/download_runner.py b/jupyter_scheduler/download_runner.py new file mode 100644 index 00000000..0e284e3c --- /dev/null +++ b/jupyter_scheduler/download_runner.py @@ -0,0 +1,57 @@ +import asyncio + +import traitlets +from jupyter_server.transutils import _i18n +from traitlets.config import LoggingConfigurable + +from jupyter_scheduler.download_manager import DownloadManager +from jupyter_scheduler.job_files_manager import JobFilesManager + + +class BaseDownloadRunner(LoggingConfigurable): + """Base download runner, this class's start method is called + at the start of jupyter server, and is responsible for + polling for downloads to download. + """ + + def __init__(self, config=None, **kwargs): + super().__init__(config=config) + + downloads_poll_interval = traitlets.Integer( + default_value=3, + config=True, + help=_i18n( + "The interval in seconds that the download runner polls for downloads to download." + ), + ) + + def start(self): + raise NotImplementedError("Must be implemented by subclass") + + +class DownloadRunner(BaseDownloadRunner): + """Default download runner that maintains a record and a queue of initiated downloads , and polls the queue every `poll_interval` seconds + for downloads to download. 
+ """ + + def __init__( + self, download_manager: DownloadManager, job_files_manager: JobFilesManager, config=None + ): + super().__init__(config=config) + self.download_manager = download_manager + self.job_files_manager = job_files_manager + + async def process_download_queue(self): + while not self.download_manager.queue.empty(): + download = self.download_manager.queue.get() + download_record = self.download_manager.record_manager.get(download.download_id) + if not download_record: + continue + await self.job_files_manager.copy_from_staging(download.job_id, download.redownload) + self.download_manager.delete_download(download.download_id) + + async def start(self): + self.download_manager.populate_queue() + while True: + await self.process_download_queue() + await asyncio.sleep(self.downloads_poll_interval) diff --git a/jupyter_scheduler/executors.py b/jupyter_scheduler/executors.py index 7e1a9974..97fd4423 100644 --- a/jupyter_scheduler/executors.py +++ b/jupyter_scheduler/executors.py @@ -11,6 +11,7 @@ import nbformat from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor +from jupyter_scheduler.download_manager import initiate_download_standalone from jupyter_scheduler.models import DescribeJob, JobFeature, Status from jupyter_scheduler.orm import Job, create_session from jupyter_scheduler.parameterize import add_parameters @@ -29,11 +30,19 @@ class ExecutionManager(ABC): _model = None _db_session = None - def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]): + def __init__( + self, + job_id: str, + root_dir: str, + db_url: str, + staging_paths: Dict[str, str], + download_queue, + ): self.job_id = job_id self.staging_paths = staging_paths self.root_dir = root_dir self.db_url = db_url + self.download_queue = download_queue @property def model(self): @@ -143,6 +152,13 @@ def execute(self): finally: self.add_side_effects_files(staging_dir) self.create_output_files(job, nb) + with self.db_session() as 
session: + initiate_download_standalone( + job_id=job.job_id, + download_queue=self.download_queue, + db_session=session, + redownload=True, + ) def add_side_effects_files(self, staging_dir: str): """Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files""" diff --git a/jupyter_scheduler/extension.py b/jupyter_scheduler/extension.py index 1a4ba373..d2a5e0f8 100644 --- a/jupyter_scheduler/extension.py +++ b/jupyter_scheduler/extension.py @@ -1,10 +1,13 @@ import asyncio +import multiprocessing from jupyter_core.paths import jupyter_data_dir from jupyter_server.extension.application import ExtensionApp from jupyter_server.transutils import _i18n from traitlets import Bool, Type, Unicode, default +from jupyter_scheduler.download_manager import DownloadManager +from jupyter_scheduler.download_runner import DownloadRunner from jupyter_scheduler.orm import create_tables from .handlers import ( @@ -67,27 +70,48 @@ def _db_url_default(self): ) def initialize_settings(self): + # Forces new processes to not be forked on Linux. + # This is necessary because `asyncio.get_event_loop()` is bugged in + # forked processes in Python versions below 3.12. This method is + # called by `jupyter_core` by `nbconvert` in the default executor. 
+ + # See: https://github.com/python/cpython/issues/66285 + # See also: https://github.com/jupyter/jupyter_core/pull/362 + multiprocessing.set_start_method("spawn", force=True) + super().initialize_settings() create_tables(self.db_url, self.drop_tables) environments_manager = self.environment_manager_class() + download_manager = DownloadManager(db_url=self.db_url) + scheduler = self.scheduler_class( root_dir=self.serverapp.root_dir, environments_manager=environments_manager, db_url=self.db_url, + download_manager=download_manager, config=self.config, ) job_files_manager = self.job_files_manager_class(scheduler=scheduler) + download_runner = DownloadRunner( + download_manager=download_manager, job_files_manager=job_files_manager + ) + self.settings.update( environments_manager=environments_manager, scheduler=scheduler, job_files_manager=job_files_manager, + initiate_download=download_manager.initiate_download, ) if scheduler.task_runner: loop = asyncio.get_event_loop() loop.create_task(scheduler.task_runner.start()) + + if download_runner: + loop = asyncio.get_event_loop() + loop.create_task(download_runner.start()) diff --git a/jupyter_scheduler/handlers.py b/jupyter_scheduler/handlers.py index 8e773b75..f0827f0a 100644 --- a/jupyter_scheduler/handlers.py +++ b/jupyter_scheduler/handlers.py @@ -395,20 +395,20 @@ def get(self): class FilesDownloadHandler(ExtensionHandlerMixin, APIHandler): - _job_files_manager = None + _initiate_download = None @property - def job_files_manager(self): - if not self._job_files_manager: - self._job_files_manager = self.settings.get("job_files_manager", None) + def initiate_download(self): + if not self._initiate_download: + self._initiate_download = self.settings.get("initiate_download", None) - return self._job_files_manager + return self._initiate_download @authenticated async def get(self, job_id): redownload = self.get_query_argument("redownload", False) try: - await self.job_files_manager.copy_from_staging(job_id=job_id, 
redownload=redownload) + self.initiate_download(job_id, redownload) except Exception as e: self.log.exception(e) raise HTTPError(500, str(e)) from e diff --git a/jupyter_scheduler/models.py b/jupyter_scheduler/models.py index 38e240e0..b80c9556 100644 --- a/jupyter_scheduler/models.py +++ b/jupyter_scheduler/models.py @@ -295,3 +295,16 @@ class JobFeature(str, Enum): output_filename_template = "output_filename_template" stop_job = "stop_job" delete_job = "delete_job" + + +class DescribeDownload(BaseModel): + job_id: str + download_id: str + download_initiated_time: int + redownload: bool + + class Config: + orm_mode = True + + def __str__(self) -> str: + return self.json() diff --git a/jupyter_scheduler/orm.py b/jupyter_scheduler/orm.py index 24a915b3..c3480ede 100644 --- a/jupyter_scheduler/orm.py +++ b/jupyter_scheduler/orm.py @@ -1,5 +1,4 @@ import json -import os from sqlite3 import OperationalError from uuid import uuid4 @@ -112,6 +111,14 @@ class JobDefinition(CommonColumns, Base): active = Column(Boolean, default=True) +class Download(Base): + __tablename__ = "downloads" + job_id = Column(String(36), primary_key=True) + download_id = Column(String(36), primary_key=True) + download_initiated_time = Column(Integer) + redownload = Column(Boolean, default=False) + + def create_tables(db_url, drop_tables=False): engine = create_engine(db_url) try: diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index 867034c6..a3861c73 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -1,7 +1,7 @@ -import multiprocessing as mp import os import random import shutil +from multiprocessing import Process from typing import Dict, List, Optional, Type, Union import fsspec @@ -15,6 +15,7 @@ from traitlets import Unicode, default from traitlets.config import LoggingConfigurable +from jupyter_scheduler.download_manager import DownloadManager from jupyter_scheduler.environments import EnvironmentManager from 
jupyter_scheduler.exceptions import ( IdempotencyTokenError, @@ -404,6 +405,7 @@ def __init__( root_dir: str, environments_manager: Type[EnvironmentManager], db_url: str, + download_manager: DownloadManager, config=None, **kwargs, ): @@ -413,6 +415,7 @@ def __init__( self.db_url = db_url if self.task_runner_class: self.task_runner = self.task_runner_class(scheduler=self, config=config) + self.download_manager = download_manager @property def db_session(self): @@ -478,20 +481,13 @@ def create_job(self, model: CreateJob) -> str: else: self.copy_input_file(model.input_uri, staging_paths["input"]) - # The MP context forces new processes to not be forked on Linux. - # This is necessary because `asyncio.get_event_loop()` is bugged in - # forked processes in Python versions below 3.12. This method is - # called by `jupyter_core` by `nbconvert` in the default executor. - # - # See: https://github.com/python/cpython/issues/66285 - # See also: https://github.com/jupyter/jupyter_core/pull/362 - mp_ctx = mp.get_context("spawn") - p = mp_ctx.Process( + p = Process( target=self.execution_manager_class( job_id=job.job_id, staging_paths=staging_paths, root_dir=self.root_dir, db_url=self.db_url, + download_queue=self.download_manager.queue, ).process ) p.start() @@ -583,6 +579,7 @@ def delete_job(self, job_id: str): session.query(Job).filter(Job.job_id == job_id).delete() session.commit() + self.download_manager.delete_job_downloads(job_id) def stop_job(self, job_id): with self.db_session() as session: diff --git a/jupyter_scheduler/tests/mocks.py b/jupyter_scheduler/tests/mocks.py index 9a60e6b7..304915dd 100644 --- a/jupyter_scheduler/tests/mocks.py +++ b/jupyter_scheduler/tests/mocks.py @@ -1,5 +1,8 @@ +from multiprocessing import Queue from typing import Dict, List +from unittest.mock import Mock +from jupyter_scheduler.download_manager import DownloadManager from jupyter_scheduler.environments import EnvironmentManager from jupyter_scheduler.executors import ExecutionManager 
from jupyter_scheduler.models import JobFeature, RuntimeEnvironment, UpdateJobDefinition @@ -73,3 +76,8 @@ def pause_jobs(self, job_definition_id: str): def resume_jobs(self, job_definition_id: str): pass + + +class MockDownloadManager(DownloadManager): + def __init__(self, db_url: str): + self.queue = Queue() diff --git a/jupyter_scheduler/tests/test_execution_manager.py b/jupyter_scheduler/tests/test_execution_manager.py index a9393eb9..fad3f3fe 100644 --- a/jupyter_scheduler/tests/test_execution_manager.py +++ b/jupyter_scheduler/tests/test_execution_manager.py @@ -7,6 +7,7 @@ from conftest import DB_URL from jupyter_scheduler.executors import DefaultExecutionManager from jupyter_scheduler.orm import Job +from jupyter_scheduler.tests.mocks import MockDownloadManager JOB_ID = "69856f4e-ce94-45fd-8f60-3a587457fce7" NOTEBOOK_NAME = "side_effects.ipynb" @@ -30,11 +31,13 @@ def load_job(jp_scheduler_db): def test_add_side_effects_files(jp_scheduler_db, load_job): + download_manager = MockDownloadManager(DB_URL) manager = DefaultExecutionManager( job_id=JOB_ID, root_dir=str(NOTEBOOK_DIR), db_url=DB_URL, staging_paths={"input": str(NOTEBOOK_PATH)}, + download_queue=download_manager.queue, ) manager.add_side_effects_files(str(NOTEBOOK_DIR)) diff --git a/jupyter_scheduler/tests/test_job_files_manager.py b/jupyter_scheduler/tests/test_job_files_manager.py index e6fcb5d6..a95876f2 100644 --- a/jupyter_scheduler/tests/test_job_files_manager.py +++ b/jupyter_scheduler/tests/test_job_files_manager.py @@ -58,6 +58,7 @@ async def test_copy_from_staging(): redownload=False, include_staging_files=None, ) + mock_process.assert_called_once() HERE = Path(__file__).parent.resolve()