Skip to content

Commit 4aafdf6

Browse files
Directly copying the YAML config to the logdir
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent bbef2dc commit 4aafdf6

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

terratorch/cli_tools.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging # noqa: I001
44
import os
5+
import shutil
56
import warnings
67
from datetime import timedelta
78
from pathlib import Path
@@ -176,6 +177,11 @@ def __init__(
176177
super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir)
177178
set_dumper("deploy_config", clean_config_for_deployment_and_dump)
178179

180+
# Preparing information to save config file to log dir
181+
config_dict = config.as_dict()
182+
self.config_path_original = str(config_dict["config"][0])
183+
_, self.config_file_original = os.path.split(self.config_path_original)
184+
179185
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
180186
if self.already_saved:
181187
return
@@ -227,9 +233,13 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non
227233
)
228234
self.already_saved = True
229235

236+
config_path_dir, config_path_file = os.path.split(config_path)
237+
self.config_path_new = os.path.join(config_path_dir, self.config_file_original)
238+
230239
# broadcast so that all ranks are in sync on future calls to .setup()
231240
self.already_saved = trainer.strategy.broadcast(self.already_saved)
232-
241+
# Copying config file to log dir
242+
shutil.copyfile(self.config_path_original, self.config_path_new)
233243

234244
class StateDictAwareModelCheckpoint(ModelCheckpoint):
235245
# necessary as we wish to have one model checkpoint with only state dict and one with standard lightning checkpoints
@@ -467,4 +477,4 @@ def inference(self, file_path: Path) -> torch.Tensor:
467477
prediction, file_name = self.inference_on_dir(
468478
tmpdir,
469479
)
470-
return prediction.squeeze(0)
480+
return prediction.squeeze(0)

0 commit comments

Comments
 (0)