Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit 0c5db28

Browse files
committed
stub in where symlinking/copying might happen
1 parent 536b07c commit 0c5db28

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

metaseq/checkpoint_utils.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,9 @@ def save_checkpoint(
9595
extra_state,
9696
training_finished=training_finished,
9797
async_callback_fn=async_callback_fn if save_to_NFS else None,
98+
files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None,
9899
)
99100

100-
# if len(checkpoints) > 1:
101-
# # Create symlink between identical checkpoints (differing in naming for epoch/update/last).
102-
# for other_checkpoint in checkpoints[1:]:
103-
# if PathManager.islink(other_checkpoint):
104-
# PathManager.rm(other_checkpoint)
105-
# assert PathManager.symlink(
106-
# checkpoints[0], other_checkpoint
107-
# ), f"Failed to symlink {checkpoints[0]} to {other_checkpoint}"
108-
109101
write_timer.stop()
110102
logger.info(
111103
f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) "

metaseq/cli/train.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def _checkpoint_add_directory(basename):
492492
return m[1], f"checkpoint{m[3]}"
493493

494494

495-
def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
495+
def post_checkpoint_callback(cfg, num_updates, training_finished, filename, files_to_symlink_to):
496496
if cfg.checkpoint.cloud_upload_path is not None:
497497
if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path:
498498
azcopy_logs = filename + "_azcopy_logs"
@@ -521,6 +521,9 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
521521
f"Successfully copied {filename} to {cfg.checkpoint.cloud_upload_path}"
522522
)
523523
os.remove(filename)
524+
525+
# TODO[Susan]: Add symlink logic here? Check what cloud_upload_path is being used for Uriel's jobs.
526+
524527
elif cfg.checkpoint.cloud_upload_path.startswith("nfs:"):
525528
path, basename = os.path.split(filename)
526529
checkpoint_dir, checkpoint_file = _checkpoint_add_directory(basename)
@@ -560,6 +563,8 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
560563
)
561564
os.remove(filename)
562565

566+
# TODO[Susan]: Add symlink logic here.
567+
563568
# Start running evals on uploaded checkpoint
564569
nfs_evaluation(
565570
cfg,
@@ -583,6 +588,16 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
583588
except (FileNotFoundError, AssertionError) as e:
584589
logger.info(f"could not upload {filename}: {e}")
585590

591+
# TODO[Susan]: Add symlink logic here.
592+
593+
# if files_to_symlink_to is not None and len(files_to_symlink_to) > 0:
594+
# for other_checkpoint in files_to_symlink_to:
595+
# if PathManager.islink(other_checkpoint):
596+
# PathManager.rm(other_checkpoint)
597+
# assert PathManager.symlink(
598+
# filename, other_checkpoint
599+
# ), f"Failed to symlink {filename} to {other_checkpoint}"
600+
586601

587602
def nfs_evaluation(
588603
cfg, num_updates, training_finished, checkpoint_dir, destination_checkpoints_dir

metaseq/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def state_dict(self, filename, training_finished=False) -> Dict[str, Dict]:
422422
return state_dicts
423423

424424
def save_checkpoint(
425-
self, filename, extra_state, training_finished=False, async_callback_fn=None
425+
self, filename, extra_state, training_finished=False, async_callback_fn=None, files_to_symlink_to=None
426426
):
427427
"""Save all training state in a checkpoint file."""
428428

@@ -446,7 +446,7 @@ def save_checkpoint(
446446
def perform_save():
447447
try:
448448
logger.info(f"Beginning asynchronous torch.save to {filename}")
449-
async_callback_fn(filename)
449+
async_callback_fn(filename, files_to_symlink_to)
450450
logger.info(f"Asynchronous torch.save to {filename} complete.")
451451
except Exception as e:
452452
logger.exception(f"Asynchronous save failed: {e}")

0 commit comments

Comments
 (0)