@@ -130,9 +130,10 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): #
130
130
save_prediction (prediction , file_name , output_dir , dtype = trainer .out_dtype )
131
131
132
132
133
- def clean_config_for_deployment_and_dump (config : dict [str , Any ], clean : bool = False ):
133
+ def clean_config_for_deployment_and_dump (config : dict [str , Any ]):
134
134
deploy_config = deepcopy (config )
135
- if clean :
135
+
136
+ if config ["clean_config" ]:
136
137
## General
137
138
# drop ckpt_path
138
139
deploy_config .pop ("ckpt_path" , None )
@@ -176,7 +177,7 @@ def __init__(
176
177
):
177
178
super ().__init__ (parser , config , config_filename , overwrite , multifile , save_to_log_dir )
178
179
set_dumper ("deploy_config" , clean_config_for_deployment_and_dump )
179
-
180
+
180
181
def setup (self , trainer : Trainer , pl_module : LightningModule , stage : str ) -> None :
181
182
if self .already_saved :
182
183
return
@@ -285,6 +286,7 @@ class MyLightningCLI(LightningCLI):
285
286
def add_arguments_to_parser (self , parser : LightningArgumentParser ) -> None :
286
287
parser .add_argument ("--predict_output_dir" , default = None )
287
288
parser .add_argument ("--out_dtype" , default = "int16" )
289
+ parser .add_argument ("--clean_config" , type = bool , default = False )
288
290
289
291
# parser.set_defaults({"trainer.enable_checkpointing": False})
290
292
@@ -315,6 +317,9 @@ def instantiate_classes(self) -> None:
315
317
if hasattr (config , "out_dtype" ):
316
318
self .trainer .out_dtype = config .out_dtype
317
319
320
+ if hasattr (config , "clean_config" ):
321
+ self .trainer .clean_config = config .clean_config
322
+
318
323
def build_lightning_cli (
319
324
args : ArgsType = None ,
320
325
run = True , # noqa: FBT002
0 commit comments