Skip to content

Commit 4d9e18c

Browse files
committed
package input files and folders (backend)
1 parent 72125a1 commit 4d9e18c

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

jupyter_scheduler/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class CreateJob(BaseModel):
8585
name: str
8686
output_filename_template: Optional[str] = OUTPUT_FILENAME_TEMPLATE
8787
compute_type: Optional[str] = None
88+
package_input_folder: Optional[bool] = None
8889

8990
@root_validator
9091
def compute_input_filename(cls, values) -> Dict:
@@ -145,6 +146,7 @@ class DescribeJob(BaseModel):
145146
status: Status = Status.CREATED
146147
status_message: Optional[str] = None
147148
downloaded: bool = False
149+
package_input_folder: Optional[bool] = None
148150

149151
class Config:
150152
orm_mode = True
@@ -209,6 +211,7 @@ class CreateJobDefinition(BaseModel):
209211
compute_type: Optional[str] = None
210212
schedule: Optional[str] = None
211213
timezone: Optional[str] = None
214+
package_input_folder: Optional[bool] = None
212215

213216
@root_validator
214217
def compute_input_filename(cls, values) -> Dict:
@@ -234,6 +237,7 @@ class DescribeJobDefinition(BaseModel):
234237
create_time: int
235238
update_time: int
236239
active: bool
240+
package_input_folder: Optional[bool] = None
237241

238242
class Config:
239243
orm_mode = True

jupyter_scheduler/orm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class CommonColumns:
8585
output_filename_template = Column(String(256))
8686
update_time = Column(Integer, default=get_utc_timestamp, onupdate=get_utc_timestamp)
8787
create_time = Column(Integer, default=get_utc_timestamp)
88+
package_input_folder = Column(Boolean)
8889

8990

9091
class Job(CommonColumns, Base):

jupyter_scheduler/scheduler.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,25 @@ def copy_input_file(self, input_uri: str, copy_to_path: str):
371371
with fsspec.open(copy_to_path, "wb") as output_file:
372372
output_file.write(input_file.read())
373373

374+
def copy_input_folder(self, input_uri: str, nb_copy_to_path: str):
375+
"""Copies the input file along with the input directory to the staging directory"""
376+
input_dir_path = os.path.dirname(os.path.join(self.root_dir, input_uri))
377+
staging_dir = os.path.dirname(nb_copy_to_path)
378+
379+
# Copy the input file
380+
self.copy_input_file(input_uri, nb_copy_to_path)
381+
382+
# Copy the rest of the input folder excluding the input file
383+
for item in os.listdir(input_dir_path):
384+
source = os.path.join(input_dir_path, item)
385+
destination = os.path.join(staging_dir, item)
386+
if os.path.isdir(source):
387+
shutil.copytree(source, destination)
388+
elif os.path.isfile(source) and item != os.path.basename(input_uri):
389+
with fsspec.open(source) as src_file:
390+
with fsspec.open(destination, "wb") as output_file:
391+
output_file.write(src_file.read())
392+
374393
def create_job(self, model: CreateJob) -> str:
375394
if not model.job_definition_id and not self.file_exists(model.input_uri):
376395
raise InputUriError(model.input_uri)
@@ -401,7 +420,10 @@ def create_job(self, model: CreateJob) -> str:
401420
session.commit()
402421

403422
staging_paths = self.get_staging_paths(DescribeJob.from_orm(job))
404-
self.copy_input_file(model.input_uri, staging_paths["input"])
423+
if model.package_input_folder:
424+
self.copy_input_folder(model.input_uri, staging_paths["input"])
425+
else:
426+
self.copy_input_file(model.input_uri, staging_paths["input"])
405427

406428
# The MP context forces new processes to not be forked on Linux.
407429
# This is necessary because `asyncio.get_event_loop()` is bugged in
@@ -541,7 +563,10 @@ def create_job_definition(self, model: CreateJobDefinition) -> str:
541563
job_definition_id = job_definition.job_definition_id
542564

543565
staging_paths = self.get_staging_paths(DescribeJobDefinition.from_orm(job_definition))
544-
self.copy_input_file(model.input_uri, staging_paths["input"])
566+
if model.package_input_folder:
567+
self.copy_input_folder(model.input_uri, staging_paths["input"])
568+
else:
569+
self.copy_input_file(model.input_uri, staging_paths["input"])
545570

546571
if self.task_runner and job_definition.schedule:
547572
self.task_runner.add_job_definition(job_definition_id)
@@ -690,6 +715,10 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
690715

691716
staging_paths["input"] = os.path.join(self.staging_path, id, model.input_filename)
692717

718+
if model.package_input_folder:
719+
notebook_dir = os.path.dirname(staging_paths["input"])
720+
staging_paths["input_dir"] = notebook_dir
721+
693722
return staging_paths
694723

695724

0 commit comments

Comments
 (0)