@@ -130,30 +130,31 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): #
130
130
save_prediction (prediction , file_name , output_dir , dtype = trainer .out_dtype )
131
131
132
132
133
def clean_config_for_deployment_and_dump(config: dict[str, Any], clean: bool = False) -> str:
    """Serialize a training configuration to YAML for deployment.

    Works on a deep copy, so the caller's ``config`` is never mutated.

    Args:
        config: Full training configuration (Lightning-CLI-style nested dict).
            When ``clean=True`` it is assumed to contain ``config["trainer"]``
            and ``config["model"]["init_args"]`` — a KeyError on those keys
            indicates a malformed config.
        clean: When True, strip training-only entries (checkpoint paths,
            checkpoint callbacks, optimizer, LR scheduler, logger, callbacks,
            default_root_dir) and force inference-friendly trainer settings
            before dumping. Defaults to False, which dumps the config as-is.

    Returns:
        The (optionally cleaned) configuration as a YAML string.
    """
    deploy_config = deepcopy(config)
    if clean:
        ## General
        # drop ckpt_path
        deploy_config.pop("ckpt_path", None)
        # drop checkpoints
        deploy_config.pop("ModelCheckpoint", None)
        deploy_config.pop("StateDictModelCheckpoint", None)
        # drop optimizer and lr scheduler
        deploy_config.pop("optimizer", None)
        deploy_config.pop("lr_scheduler", None)
        ## Trainer
        # hoist the repeated subscript lookup; raises KeyError for malformed configs
        trainer_cfg = deploy_config["trainer"]
        # remove logging
        trainer_cfg["logger"] = False
        # remove callbacks
        trainer_cfg.pop("callbacks", None)
        # remove default_root_dir
        trainer_cfg.pop("default_root_dir", None)
        # set mixed precision by default for inference
        trainer_cfg["precision"] = "16-mixed"
        ## Model
        # set pretrained to false — presumably avoids re-fetching pretrained
        # weights at deployment load time; confirm against the model factory
        if "model_args" in deploy_config["model"]["init_args"]:
            deploy_config["model"]["init_args"]["model_args"]["pretrained"] = False
    return yaml.safe_dump(deploy_config)
159
160
0 commit comments