@@ -181,6 +181,8 @@ def __init__(
181
181
config_dict = config .as_dict ()
182
182
self .config_path_original = str (config_dict ["config" ][0 ])
183
183
_ , self .config_file_original = os .path .split (self .config_path_original )
184
+
185
+ self .deploy_config_file = config_dict ["deploy_config_file" ]
184
186
185
187
def setup (self , trainer : Trainer , pl_module : LightningModule , stage : str ) -> None :
186
188
if self .already_saved :
@@ -212,26 +214,28 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non
212
214
# save only on rank zero to avoid race conditions.
213
215
# the `log_dir` needs to be created as we rely on the logger to do it usually
214
216
# but it hasn't logged anything at this point
215
- fs .makedirs (log_dir , exist_ok = True )
216
- self .parser .save (
217
- self .config , config_path , skip_none = True , overwrite = self .overwrite , multifile = self .multifile
218
- )
217
+ if self .deploy_config_file :
218
+ fs .makedirs (log_dir , exist_ok = True )
219
+ self .parser .save (
220
+ self .config , config_path , skip_none = True , overwrite = self .overwrite , multifile = self .multifile
221
+ )
219
222
220
223
if trainer .is_global_zero :
221
- # also save the config that will be deployed
222
- config_name , config_ext = os .path .splitext (self .config_filename )
223
- config_name += "_deploy"
224
- config_name += config_ext
225
- config_path = os .path .join (log_dir , config_name )
226
- self .parser .save (
227
- self .config ,
228
- config_path ,
229
- format = "deploy_config" ,
230
- skip_none = True ,
231
- overwrite = self .overwrite ,
232
- multifile = self .multifile ,
233
- )
234
- self .already_saved = True
224
+ if self .deploy_config_file :
225
+ # also save the config that will be deployed
226
+ config_name , config_ext = os .path .splitext (self .config_filename )
227
+ config_name += "_deploy"
228
+ config_name += config_ext
229
+ config_path = os .path .join (log_dir , config_name )
230
+ self .parser .save (
231
+ self .config ,
232
+ config_path ,
233
+ format = "deploy_config" ,
234
+ skip_none = True ,
235
+ overwrite = self .overwrite ,
236
+ multifile = self .multifile ,
237
+ )
238
+ self .already_saved = True
235
239
236
240
config_path_dir , config_path_file = os .path .split (config_path )
237
241
self .config_path_new = os .path .join (config_path_dir , self .config_file_original )
@@ -294,6 +298,7 @@ class MyLightningCLI(LightningCLI):
294
298
def add_arguments_to_parser (self , parser : LightningArgumentParser ) -> None :
295
299
parser .add_argument ("--predict_output_dir" , default = None )
296
300
parser .add_argument ("--out_dtype" , default = "int16" )
301
+ parser .add_argument ("--deploy_config_file" , type = bool , default = True )
297
302
298
303
# parser.set_defaults({"trainer.enable_checkpointing": False})
299
304
@@ -324,6 +329,10 @@ def instantiate_classes(self) -> None:
324
329
if hasattr (config , "out_dtype" ):
325
330
self .trainer .out_dtype = config .out_dtype
326
331
332
+ if hasattr (config , "deploy_config_file" ):
333
+ self .trainer .deploy_config = config .deploy_config_file
334
+
335
+
327
336
def build_lightning_cli (
328
337
args : ArgsType = None ,
329
338
run = True , # noqa: FBT002
0 commit comments