Skip to content

Commit dce5152

Browse files
Avoiding to remove info, but the ordering is not preserved.
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 3b76391 commit dce5152

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

terratorch/cli_tools.py

+24-23
Original file line numberDiff line numberDiff line change
@@ -130,30 +130,31 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): #
130130
save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
131131

132132

133-
def clean_config_for_deployment_and_dump(config: dict[str, Any]):
133+
def clean_config_for_deployment_and_dump(config: dict[str, Any], clean:bool=False):
134134
deploy_config = deepcopy(config)
135-
## General
136-
# drop ckpt_path
137-
deploy_config.pop("ckpt_path", None)
138-
# drop checkpoints
139-
deploy_config.pop("ModelCheckpoint", None)
140-
deploy_config.pop("StateDictModelCheckpoint", None)
141-
# drop optimizer and lr sheduler
142-
deploy_config.pop("optimizer", None)
143-
deploy_config.pop("lr_scheduler", None)
144-
## Trainer
145-
# remove logging
146-
deploy_config["trainer"]["logger"] = False
147-
# remove callbacks
148-
deploy_config["trainer"].pop("callbacks", None)
149-
# remove default_root_dir
150-
deploy_config["trainer"].pop("default_root_dir", None)
151-
# set mixed precision by default for inference
152-
deploy_config["trainer"]["precision"] = "16-mixed"
153-
## Model
154-
# set pretrained to false
155-
if "model_args" in deploy_config["model"]["init_args"]:
156-
deploy_config["model"]["init_args"]["model_args"]["pretrained"] = False
135+
if clean:
136+
## General
137+
# drop ckpt_path
138+
deploy_config.pop("ckpt_path", None)
139+
# drop checkpoints
140+
deploy_config.pop("ModelCheckpoint", None)
141+
deploy_config.pop("StateDictModelCheckpoint", None)
142+
# drop optimizer and lr sheduler
143+
deploy_config.pop("optimizer", None)
144+
deploy_config.pop("lr_scheduler", None)
145+
## Trainer
146+
# remove logging
147+
deploy_config["trainer"]["logger"] = False
148+
# remove callbacks
149+
deploy_config["trainer"].pop("callbacks", None)
150+
# remove default_root_dir
151+
deploy_config["trainer"].pop("default_root_dir", None)
152+
# set mixed precision by default for inference
153+
deploy_config["trainer"]["precision"] = "16-mixed"
154+
## Model
155+
# set pretrained to false
156+
if "model_args" in deploy_config["model"]["init_args"]:
157+
deploy_config["model"]["init_args"]["model_args"]["pretrained"] = False
157158

158159
return yaml.safe_dump(deploy_config)
159160

0 commit comments

Comments
 (0)