Skip to content

Commit cd319cb

Browse files
The input argument --clean_config allows to control if the outputtd config files will be more or less verbose
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent dce5152 commit cd319cb

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

terratorch/cli_tools.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,10 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): #
130130
save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
131131

132132

133-
def clean_config_for_deployment_and_dump(config: dict[str, Any], clean:bool=False):
133+
def clean_config_for_deployment_and_dump(config: dict[str, Any]):
134134
deploy_config = deepcopy(config)
135-
if clean:
135+
136+
if config["clean_config"]:
136137
## General
137138
# drop ckpt_path
138139
deploy_config.pop("ckpt_path", None)
@@ -176,7 +177,7 @@ def __init__(
176177
):
177178
super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir)
178179
set_dumper("deploy_config", clean_config_for_deployment_and_dump)
179-
180+
180181
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
181182
if self.already_saved:
182183
return
@@ -285,6 +286,7 @@ class MyLightningCLI(LightningCLI):
285286
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
286287
parser.add_argument("--predict_output_dir", default=None)
287288
parser.add_argument("--out_dtype", default="int16")
289+
parser.add_argument("--clean_config", type=bool, default=False)
288290

289291
# parser.set_defaults({"trainer.enable_checkpointing": False})
290292

@@ -315,6 +317,9 @@ def instantiate_classes(self) -> None:
315317
if hasattr(config, "out_dtype"):
316318
self.trainer.out_dtype = config.out_dtype
317319

320+
if hasattr(config, "clean_config"):
321+
self.trainer.clean_config = config.clean_config
322+
318323
def build_lightning_cli(
319324
args: ArgsType = None,
320325
run=True, # noqa: FBT002

0 commit comments

Comments
 (0)