Skip to content

Commit d219dd1

Browse files
The flag --deploy_config_file False eliminates the generation of config.yaml and config_deploy.yaml
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 4aafdf6 commit d219dd1

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

terratorch/cli_tools.py

+27-18
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def __init__(
181181
config_dict = config.as_dict()
182182
self.config_path_original = str(config_dict["config"][0])
183183
_, self.config_file_original = os.path.split(self.config_path_original)
184+
185+
self.deploy_config_file = config_dict["deploy_config_file"]
184186

185187
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
186188
if self.already_saved:
@@ -212,26 +214,28 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non
212214
# save only on rank zero to avoid race conditions.
213215
# the `log_dir` needs to be created as we rely on the logger to do it usually
214216
# but it hasn't logged anything at this point
215-
fs.makedirs(log_dir, exist_ok=True)
216-
self.parser.save(
217-
self.config, config_path, skip_none=True, overwrite=self.overwrite, multifile=self.multifile
218-
)
217+
if self.deploy_config_file:
218+
fs.makedirs(log_dir, exist_ok=True)
219+
self.parser.save(
220+
self.config, config_path, skip_none=True, overwrite=self.overwrite, multifile=self.multifile
221+
)
219222

220223
if trainer.is_global_zero:
221-
# also save the config that will be deployed
222-
config_name, config_ext = os.path.splitext(self.config_filename)
223-
config_name += "_deploy"
224-
config_name += config_ext
225-
config_path = os.path.join(log_dir, config_name)
226-
self.parser.save(
227-
self.config,
228-
config_path,
229-
format="deploy_config",
230-
skip_none=True,
231-
overwrite=self.overwrite,
232-
multifile=self.multifile,
233-
)
234-
self.already_saved = True
224+
if self.deploy_config_file:
225+
# also save the config that will be deployed
226+
config_name, config_ext = os.path.splitext(self.config_filename)
227+
config_name += "_deploy"
228+
config_name += config_ext
229+
config_path = os.path.join(log_dir, config_name)
230+
self.parser.save(
231+
self.config,
232+
config_path,
233+
format="deploy_config",
234+
skip_none=True,
235+
overwrite=self.overwrite,
236+
multifile=self.multifile,
237+
)
238+
self.already_saved = True
235239

236240
config_path_dir, config_path_file = os.path.split(config_path)
237241
self.config_path_new = os.path.join(config_path_dir, self.config_file_original)
@@ -294,6 +298,7 @@ class MyLightningCLI(LightningCLI):
294298
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
295299
parser.add_argument("--predict_output_dir", default=None)
296300
parser.add_argument("--out_dtype", default="int16")
301+
parser.add_argument("--deploy_config_file", type=bool, default=True)
297302

298303
# parser.set_defaults({"trainer.enable_checkpointing": False})
299304

@@ -324,6 +329,10 @@ def instantiate_classes(self) -> None:
324329
if hasattr(config, "out_dtype"):
325330
self.trainer.out_dtype = config.out_dtype
326331

332+
if hasattr(config, "deploy_config_file"):
333+
self.trainer.deploy_config = config.deploy_config_file
334+
335+
327336
def build_lightning_cli(
328337
args: ArgsType = None,
329338
run=True, # noqa: FBT002

0 commit comments

Comments
 (0)