Skip to content

Commit b01cff9

Browse files
committed
Remove DownloadTask data class, use DescribeDownload for both queue and db records
1 parent b5f98fe commit b01cff9

File tree

4 files changed

+36
-65
lines changed

4 files changed

+36
-65
lines changed

jupyter_scheduler/download_manager.py

Lines changed: 15 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,42 @@
1-
from dataclasses import dataclass
2-
from datetime import datetime
31
from multiprocessing import Queue
42
from typing import List, Optional
53

6-
from jupyter_scheduler.orm import Downloads, create_session, generate_uuid
7-
from jupyter_scheduler.pydantic_v1 import BaseModel
4+
from jupyter_scheduler.models import DescribeDownload
5+
from jupyter_scheduler.orm import Download, create_session, generate_uuid
86
from jupyter_scheduler.utils import get_utc_timestamp
97

108

11-
class DescribeDownload(BaseModel):
12-
job_id: str
13-
download_id: str
14-
download_initiated_time: int
15-
16-
class Config:
17-
orm_mode = True
18-
19-
20-
@dataclass
21-
class DownloadTask:
22-
job_id: str
23-
download_id: str
24-
download_initiated_time: int
25-
26-
def __lt__(self, other):
27-
return self.download_initiated_time < other.download_initiated_time
28-
29-
def __str__(self):
30-
download_initiated_time = datetime.fromtimestamp(self.download_initiated_time / 1e3)
31-
return f"Id: {self.job_id}, Download initiated: {download_initiated_time}"
32-
33-
349
class DownloadRecordManager:
3510
def __init__(self, db_url):
3611
self.session = create_session(db_url)
3712

3813
def put(self, download: DescribeDownload):
3914
with self.session() as session:
40-
new_download = Downloads(**download.dict())
41-
session.add(new_download)
15+
download = Download(**download.dict())
16+
session.add(download)
4217
session.commit()
4318

4419
def get(self, job_id: str) -> Optional[DescribeDownload]:
4520
with self.session() as session:
46-
download = session.query(Downloads).filter(Downloads.job_id == job_id).first()
21+
download = session.query(Download).filter(Download.job_id == job_id).first()
4722

4823
if download:
4924
return DescribeDownload.from_orm(download)
5025
else:
5126
return None
5227

53-
def get_tasks(self) -> List[DescribeDownload]:
28+
def get_downloads(self) -> List[DescribeDownload]:
5429
with self.session() as session:
55-
return session.query(Downloads).order_by(Downloads.download_initiated_time).all()
30+
return session.query(Download).order_by(Download.download_initiated_time).all()
5631

5732
def delete_download(self, download_id: str):
5833
with self.session() as session:
59-
session.query(Downloads).filter(Downloads.download_id == download_id).delete()
34+
session.query(Download).filter(Download.download_id == download_id).delete()
6035
session.commit()
6136

6237
def delete_job_downloads(self, job_id: str):
6338
with self.session() as session:
64-
session.query(Downloads).filter(Downloads.job_id == job_id).delete()
39+
session.query(Download).filter(Download.job_id == job_id).delete()
6540
session.commit()
6641

6742

@@ -73,18 +48,13 @@ def __init__(self, db_url: str):
7348
def download_from_staging(self, job_id: str):
7449
download_initiated_time = get_utc_timestamp()
7550
download_id = generate_uuid()
76-
download_cache = DescribeDownload(
77-
job_id=job_id,
78-
download_id=download_id,
79-
download_initiated_time=download_initiated_time,
80-
)
81-
self.record_manager.put(download_cache)
82-
download_task = DownloadTask(
51+
download = DescribeDownload(
8352
job_id=job_id,
8453
download_id=download_id,
8554
download_initiated_time=download_initiated_time,
8655
)
87-
self.queue.put(download_task)
56+
self.record_manager.put(download)
57+
self.queue.put(download)
8858

8959
def delete_download(self, download_id: str):
9060
self.record_manager.delete_download(download_id)
@@ -93,11 +63,6 @@ def delete_job_downloads(self, job_id: str):
9363
self.record_manager.delete_job_downloads(job_id)
9464

9565
def populate_queue(self):
96-
tasks = self.record_manager.get_tasks()
97-
for task in tasks:
98-
download_task = DownloadTask(
99-
job_id=task.job_id,
100-
download_id=task.download_id,
101-
download_initiated_time=task.download_initiated_time,
102-
)
103-
self.queue.put(download_task)
66+
downloads = self.record_manager.get_downloads()
67+
for download in downloads:
68+
self.queue.put(download)

jupyter_scheduler/executors.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import io
2-
import multiprocessing
32
import os
43
import shutil
54
import tarfile
@@ -12,8 +11,8 @@
1211
import nbformat
1312
from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor
1413

15-
from jupyter_scheduler.download_manager import DescribeDownload, Downloads, DownloadTask
16-
from jupyter_scheduler.models import DescribeJob, JobFeature, JobFile, Status
14+
from jupyter_scheduler.download_manager import DescribeDownload, Download
15+
from jupyter_scheduler.models import DescribeJob, JobFeature, Status
1716
from jupyter_scheduler.orm import Job, create_session, generate_uuid
1817
from jupyter_scheduler.parameterize import add_parameters
1918
from jupyter_scheduler.utils import get_utc_timestamp
@@ -157,26 +156,21 @@ def execute(self):
157156
output, _ = cls().from_notebook_node(nb)
158157
with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f:
159158
f.write(output)
160-
self.download_from_staging(job.job_id)
159+
self._download_from_staging(job.job_id)
161160

162-
def download_from_staging(self, job_id: str):
161+
def _download_from_staging(self, job_id: str):
163162
download_initiated_time = get_utc_timestamp()
164163
download_id = generate_uuid()
165-
download_cache = DescribeDownload(
164+
download = DescribeDownload(
166165
job_id=job_id,
167166
download_id=download_id,
168167
download_initiated_time=download_initiated_time,
169168
)
170169
with self.db_session() as session:
171-
new_download = Downloads(**download_cache.dict())
172-
session.add(new_download)
170+
download_record = Download(**download.dict())
171+
session.add(download_record)
173172
session.commit()
174-
download_task = DownloadTask(
175-
job_id=job_id,
176-
download_id=download_id,
177-
download_initiated_time=download_initiated_time,
178-
)
179-
self.download_queue.put(download_task)
173+
self.download_queue.put(download)
180174

181175
def add_side_effects_files(self, staging_dir):
182176
"""Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files"""

jupyter_scheduler/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,3 +294,15 @@ class JobFeature(str, Enum):
294294
output_filename_template = "output_filename_template"
295295
stop_job = "stop_job"
296296
delete_job = "delete_job"
297+
298+
299+
class DescribeDownload(BaseModel):
300+
job_id: str
301+
download_id: str
302+
download_initiated_time: int
303+
304+
class Config:
305+
orm_mode = True
306+
307+
def __str__(self) -> str:
308+
return self.json()

jupyter_scheduler/orm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ class JobDefinition(CommonColumns, Base):
111111
active = Column(Boolean, default=True)
112112

113113

114-
class Downloads(Base):
114+
class Download(Base):
115115
__tablename__ = "downloads"
116116
job_id = Column(String(36), primary_key=True)
117117
download_id = Column(String(36), primary_key=True)

0 commit comments

Comments
 (0)