 import logging  # noqa: I001
 import os
+import shutil
 import warnings
 from datetime import timedelta
 from pathlib import Path
@@ -176,6 +177,11 @@ def __init__(
         super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir)
         set_dumper("deploy_config", clean_config_for_deployment_and_dump)
 
+        # Preparing information to save config file to log dir
+        config_dict = config.as_dict()
+        self.config_path_original = str(config_dict["config"][0])
+        _, self.config_file_original = os.path.split(self.config_path_original)
+
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         if self.already_saved:
             return
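The new __init__ lines only record where the user-supplied config file lives and its bare filename; nothing is copied yet. A minimal sketch of the filename extraction, with a hypothetical path standing in for the value read from config.as_dict()["config"][0]:

import os

# Hypothetical command-line config path (the real value comes from config.as_dict()["config"][0]).
config_path_original = "configs/experiment.yaml"
_, config_file_original = os.path.split(config_path_original)
print(config_file_original)  # prints "experiment.yaml"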
@@ -227,9 +233,13 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         )
         self.already_saved = True
 
+        config_path_dir, config_path_file = os.path.split(config_path)
+        self.config_path_new = os.path.join(config_path_dir, self.config_file_original)
+
         # broadcast so that all ranks are in sync on future calls to .setup()
         self.already_saved = trainer.strategy.broadcast(self.already_saved)
-
+        # Copying config file to log dir
+        shutil.copyfile(self.config_path_original, self.config_path_new)
 
 class StateDictAwareModelCheckpoint(ModelCheckpoint):
     # necessary as we wish to have one model checkpoint with only state dict and one with standard lightning checkpoints
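Inside setup(), the callback now builds a destination path that keeps the original filename but sits in the same directory as the config Lightning has just saved, then copies the user's file there after the broadcast. A self-contained sketch of that path handling under hypothetical paths and names (copy_original_config is illustrative, not part of the callback API):

import os
import shutil
import tempfile


def copy_original_config(config_path_original: str, saved_config_path: str) -> str:
    # Keep the original file name, but place the copy in the log directory
    # next to the config that the callback has just written.
    log_dir, _ = os.path.split(saved_config_path)
    _, original_name = os.path.split(config_path_original)
    config_path_new = os.path.join(log_dir, original_name)
    shutil.copyfile(config_path_original, config_path_new)
    return config_path_new


# Throwaway demo with temporary files.
with tempfile.TemporaryDirectory() as tmp:
    original = os.path.join(tmp, "experiment.yaml")
    log_dir = os.path.join(tmp, "version_0")
    os.makedirs(log_dir)
    with open(original, "w") as f:
        f.write("trainer:\n  max_epochs: 1\n")
    saved_config = os.path.join(log_dir, "config.yaml")
    print(copy_original_config(original, saved_config))  # .../version_0/experiment.yaml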
@@ -467,4 +477,4 @@ def inference(self, file_path: Path) -> torch.Tensor:
         prediction, file_name = self.inference_on_dir(
             tmpdir,
         )
-        return prediction.squeeze(0)
+        return prediction.squeeze(0)