@@ -371,6 +371,25 @@ def copy_input_file(self, input_uri: str, copy_to_path: str):
371
371
with fsspec .open (copy_to_path , "wb" ) as output_file :
372
372
output_file .write (input_file .read ())
373
373
374
def copy_input_folder(self, input_uri: str, nb_copy_to_path: str):
    """Copy the input file and the rest of its folder to the staging directory.

    The input file itself is copied through ``self.copy_input_file`` to
    ``nb_copy_to_path``. Every other entry found next to the input file is
    then copied into the directory that contains ``nb_copy_to_path``,
    keeping its original name.

    Args:
        input_uri: Path of the input file, joined onto ``self.root_dir``
            to locate the folder to package.
        nb_copy_to_path: Destination path of the input file; its parent
            directory is used as the staging directory for siblings.
    """
    input_dir_path = os.path.dirname(os.path.join(self.root_dir, input_uri))
    staging_dir = os.path.dirname(nb_copy_to_path)
    # Loop-invariant: name of the input file, which is copied separately.
    input_filename = os.path.basename(input_uri)

    # Copy the input file
    self.copy_input_file(input_uri, nb_copy_to_path)

    # Copy the rest of the input folder excluding the input file
    for item in os.listdir(input_dir_path):
        source = os.path.join(input_dir_path, item)
        destination = os.path.join(staging_dir, item)
        if os.path.isdir(source):
            # dirs_exist_ok avoids FileExistsError when the staging
            # directory already contains a sub-directory of this name.
            shutil.copytree(source, destination, dirs_exist_ok=True)
        elif os.path.isfile(source) and item != input_filename:
            # Stream in chunks rather than loading the whole file in memory.
            with fsspec.open(source, "rb") as src_file, fsspec.open(
                destination, "wb"
            ) as output_file:
                shutil.copyfileobj(src_file, output_file)
392
+
374
393
def create_job (self , model : CreateJob ) -> str :
375
394
if not model .job_definition_id and not self .file_exists (model .input_uri ):
376
395
raise InputUriError (model .input_uri )
@@ -401,7 +420,10 @@ def create_job(self, model: CreateJob) -> str:
401
420
session .commit ()
402
421
403
422
staging_paths = self .get_staging_paths (DescribeJob .from_orm (job ))
404
- self .copy_input_file (model .input_uri , staging_paths ["input" ])
423
+ if model .package_input_folder :
424
+ self .copy_input_folder (model .input_uri , staging_paths ["input" ])
425
+ else :
426
+ self .copy_input_file (model .input_uri , staging_paths ["input" ])
405
427
406
428
# The MP context forces new processes to not be forked on Linux.
407
429
# This is necessary because `asyncio.get_event_loop()` is bugged in
@@ -541,7 +563,10 @@ def create_job_definition(self, model: CreateJobDefinition) -> str:
541
563
job_definition_id = job_definition .job_definition_id
542
564
543
565
staging_paths = self .get_staging_paths (DescribeJobDefinition .from_orm (job_definition ))
544
- self .copy_input_file (model .input_uri , staging_paths ["input" ])
566
+ if model .package_input_folder :
567
+ self .copy_input_folder (model .input_uri , staging_paths ["input" ])
568
+ else :
569
+ self .copy_input_file (model .input_uri , staging_paths ["input" ])
545
570
546
571
if self .task_runner and job_definition .schedule :
547
572
self .task_runner .add_job_definition (job_definition_id )
@@ -690,6 +715,10 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
690
715
691
716
staging_paths ["input" ] = os .path .join (self .staging_path , id , model .input_filename )
692
717
718
+ if model .package_input_folder :
719
+ notebook_dir = os .path .dirname (staging_paths ["input" ])
720
+ staging_paths ["input_dir" ] = notebook_dir
721
+
693
722
return staging_paths
694
723
695
724
0 commit comments