This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit 03b932e

suchenzang and davides authored
Auto-delete checkpoints (#679)

* create symlink for different checkpoint names
* add back checkpoint deletion logic, configurable via keep_last_updates or keep_last_epochs
* fix symlink
* test
* stub in where symlinking/copying might happen
* Add symlinks when checkpointing
* Fix up handling of remote paths
* More fixes for azcopy; fix broken gpu_tests; run linters
* More linter fixes
* Update gpu_tests to account for new symlink behavior
* Lint
* Lint

Co-authored-by: davides <[email protected]>
1 parent bc29113 commit 03b932e
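
The retention behavior added in this commit hinges on two `CheckpointConfig` fields, `keep_last_updates` and `keep_last_epochs`, both of which default to -1 (deletion disabled). Below is a minimal sketch of enabling update-based deletion; the import location and direct construction are illustrative assumptions, not code from this commit:

```python
# Sketch only: field names come from the diffs below; the import path and
# direct construction are assumptions for illustration.
from metaseq.dataclass.configs import CheckpointConfig  # assumed location

cfg = CheckpointConfig(
    save_dir="/checkpoints/run1",  # directory scanned by checkpoint_paths()
    keep_last_updates=3,           # keep only the 3 newest update checkpoints
    keep_last_epochs=-1,           # default: no epoch-based deletion
)
# After each mid-epoch save, delete_old_checkpoint_files(cfg, end_of_epoch=False,
# suffix="") removes everything older than the 3 newest checkpoint_*_<updates>.pt files.
```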

File tree

5 files changed: +300 −51 lines changed

gpu_tests/test_checkpoint_saving.py

Lines changed: 28 additions & 16 deletions
```diff
@@ -12,6 +12,7 @@
 from functools import partial, partialmethod
 import unittest
 from unittest.mock import patch, Mock, MagicMock
+from urllib.parse import urlparse
 import torch
 from metaseq.dataclass.configs import DistributedTrainingConfig
 from metaseq.launcher.opt_baselines import cli_main as sweep_cli_main
@@ -81,15 +82,17 @@ def test_checkpoint_saving_and_uploading(self):
         self.assertEqual(file_names_saved_azure, expected_file_names)
         for worker_cmd in upload_events:
             self.assertEqual(
-                worker_cmd["command"],
+                worker_cmd["command"][:4],
                 [
                     "azcopy",
                     "copy",
                     "--cap-mbps",
                     "96.0",
-                    "https://myaccount.blob.core.windows.net/test",
                 ],
             )
+            self.assertTrue(
+                os.path.basename(worker_cmd["command"][-1]) in expected_file_names
+            )
             self.assertEqual(
                 worker_cmd["checkpoint_model_dir"], common_checkpoint_model_dir
             )
@@ -235,26 +238,35 @@ def download_checkpoint_mock(blob_url, checkpoint_path, suffix, events):


 def subprocess_run_mock(cmd, stdout, stderr, events):
-    # replaces subprocess.run azcopy command that uploads to azure
-    _, checkpoint_dir, checkpoint_model_dir, checkpoint_file = cmd[4].split("/")
-    events.append(
-        {
-            "type": "upload",
-            "command": cmd[:4] + cmd[5:],
-            "checkpoint_dir": checkpoint_dir,
-            "checkpoint_model_dir": checkpoint_model_dir,
-            "checkpoint_file": checkpoint_file,
-            "file_saved_locally": os.path.exists(cmd[4]),
-        }
-    )
+    source = cmd[4]
+    dest = cmd[5]
+
+    # Only interested in local -> remote transfers (not asserting remote copies/aliases)
+    if urlparse(source).scheme == "" and urlparse(dest).scheme == "https":
+        _, checkpoint_dir, checkpoint_model_dir, checkpoint_file = source.split("/")
+        events.append(
+            {
+                "type": "upload",
+                "command": cmd[:4] + cmd[5:],
+                "checkpoint_dir": checkpoint_dir,
+                "checkpoint_model_dir": checkpoint_model_dir,
+                "checkpoint_file": checkpoint_file,
+                "file_saved_locally": os.path.exists(source),
+            }
+        )

     res = Mock()
     res.returncode = 0
     return res


 def save_checkpoint_mock(
-    self, filename, extra_state, training_finished=False, async_callback_fn=None
+    self,
+    filename,
+    extra_state,
+    training_finished=False,
+    async_callback_fn=None,
+    files_to_symlink_to=None,
 ):
     logger = logging.getLogger("metaseq.trainer")
     """Save all training state in a checkpoint file."""
@@ -278,7 +290,7 @@ def perform_save():
             logger.info(f"Beginning asynchronous torch.save to {filename}")
             torch.save(state_dict, filename)
             if async_callback_fn is not None:
-                async_callback_fn(filename)
+                async_callback_fn(filename, files_to_symlink_to)
             logger.info(f"Asynchronous torch.save to {filename} complete.")
         except Exception as e:
             logger.exception(f"Asynchronous save failed: {e}")
```
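
The reworked `subprocess_run_mock` records only local-to-remote uploads, using `urlparse` to tell the two apart: a plain filesystem path parses with an empty scheme, while an Azure Blob destination parses as `https`, so the remote-to-remote copies that stand in for symlinks are skipped. A standalone illustration (the local path is invented):

```python
from urllib.parse import urlparse

# Local paths carry no URL scheme; Azure Blob destinations parse as https.
print(urlparse("checkpoints/model0/checkpoint_5.pt").scheme)            # ""
print(urlparse("https://myaccount.blob.core.windows.net/test").scheme)  # "https"
```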

metaseq/checkpoint_utils.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -95,6 +95,7 @@ def save_checkpoint(
         extra_state,
         training_finished=training_finished,
         async_callback_fn=async_callback_fn if save_to_NFS else None,
+        files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None,
     )

     write_timer.stop()
@@ -103,6 +104,50 @@ def save_checkpoint(
         f"(writing took {write_timer.sum} seconds)"
     )

+    # See if there are any older checkpoints to delete after saving a new one.
+    # Only deletes if keep_last_updates > 0 or keep_last_epochs > 0 (default -1 for both).
+    delete_old_checkpoint_files(cfg, end_of_epoch, suffix)
+
+
+def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str):
+    if not end_of_epoch and cfg.keep_last_updates > 0:
+        # remove old checkpoints; checkpoints are sorted in descending order
+        checkpoints = checkpoint_paths(
+            cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
+        )
+        for old_chk in checkpoints[cfg.keep_last_updates :]:
+            if os.path.lexists(old_chk):
+                os.remove(old_chk)
+
+    if cfg.keep_last_epochs > 0:
+        # remove old epoch checkpoints; checkpoints are sorted in descending order
+        checkpoints = checkpoint_paths(
+            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
+        )
+        for old_chk in checkpoints[cfg.keep_last_epochs :]:
+            if os.path.lexists(old_chk):
+                os.remove(old_chk)
+
+
+# Reference:
+# https://github.com/facebookresearch/fairseq/blob/0338cdc3094ca7d29ff4d36d64791f7b4e4b5e6e/fairseq/checkpoint_utils.py#L538
+def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
+    """Retrieves all checkpoints found in `path` directory.
+    Checkpoints are identified by matching filename to the specified pattern. If
+    the pattern contains groups, the result will be sorted by the first group in
+    descending order.
+    """
+    pt_regexp = re.compile(pattern)
+    files = os.listdir(path)
+
+    entries = []
+    for i, f in enumerate(files):
+        m = pt_regexp.fullmatch(f)
+        if m is not None:
+            idx = float(m.group(1)) if len(m.groups()) > 0 else i
+            entries.append((idx, m.group(0)))
+    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
+

 def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
     """
```

metaseq/cli/train.py

Lines changed: 90 additions & 33 deletions
```diff
@@ -8,6 +8,7 @@
 """

 import argparse
+from datetime import datetime
 import functools
 import logging
 import math
@@ -18,6 +19,7 @@
 import socket
 import re
 from typing import Dict, Optional, Any, List, Tuple, Callable
+from urllib.parse import urlparse
 import warnings

 import numpy as np
@@ -512,37 +514,51 @@ def _checkpoint_add_directory(basename):
     return m[1], f"checkpoint{m[3]}"


-def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
+def _get_basename(path):
+    res = urlparse(path)
+    if res.scheme:
+        return os.path.basename(res.path)
+    else:
+        return os.path.basename(path)
+
+
+def _get_destination_path(path, destination):
+    """Calculates the destination path with handling for remote paths."""
+    basename = _get_basename(path)
+    res = urlparse(destination)
+    if res.scheme:
+        new_path = os.path.join(res.path, basename)
+        res = res._replace(path=new_path)
+        return res.geturl()
+    else:
+        return os.path.join(destination, basename)
+
+
+def post_checkpoint_callback(
+    cfg, num_updates, training_finished, filename, files_to_symlink_to
+):
     if cfg.checkpoint.cloud_upload_path is not None:
         if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path:
-            azcopy_logs = filename + "_azcopy_logs"
-            os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10"
-            os.environ["AZCOPY_LOG_LOCATION"] = azcopy_logs
-            os.makedirs(azcopy_logs, exist_ok=True)
-            logger.info(
-                f"preparing to azcopy {filename} to {cfg.checkpoint.cloud_upload_path}; logs in {azcopy_logs}"
+            azcopy_log_dir = os.path.dirname(filename)
+            final_path = _get_destination_path(
+                filename, cfg.checkpoint.cloud_upload_path
             )
-            cmd = [
-                "azcopy",  # TODO(susanz): require azcopy to be installed.
-                "copy",
-                "--cap-mbps",
-                "96.0",
-                filename,
-                cfg.checkpoint.cloud_upload_path,
-            ]
-            res = _run_azcopy(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-            if res.returncode != 0:
-                print("Error: {}, azcopy failed".format(res.returncode))
-                print("Azcopy stdout = {}".format(res.stdout))
-                sys.exit(1)
+            _copy_to_azure(filename, final_path, azcopy_log_dir)
+
             # Delete original checkpoint on local storage
             # TODO make this configurable
-            logger.info(
-                f"Successfully copied {filename} to {cfg.checkpoint.cloud_upload_path}"
-            )
             os.remove(filename)
+
+            # Azure Blob doesn't support symlinks so make full copies
+            if files_to_symlink_to:
+                for other_checkpoint in files_to_symlink_to:
+                    dest = _get_destination_path(
+                        other_checkpoint, cfg.checkpoint.cloud_upload_path
+                    )
+                    _copy_to_azure(final_path, dest, azcopy_log_dir)
+
         elif cfg.checkpoint.cloud_upload_path.startswith("nfs:"):
-            path, basename = os.path.split(filename)
+            basename = os.path.basename(filename)
             checkpoint_dir, checkpoint_file = _checkpoint_add_directory(basename)
             destination_checkpoints_dir = cfg.checkpoint.cloud_upload_path[4:]
             temporary_checkpoint_file = f"_{checkpoint_file}"
@@ -566,6 +582,9 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
             )

             logger.info(f"Renaming {temporary_checkpoint_file} -> {checkpoint_file}")
+            final_path = os.path.join(
+                destination_checkpoints_dir, checkpoint_dir, checkpoint_file
+            )
             # atomic rename _checkpointfile -> checkpointfile
             # this way we know that if present the checkpoint file is complete
             os.rename(
@@ -574,12 +593,20 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
                    checkpoint_dir,
                    temporary_checkpoint_file,
                ),
-                os.path.join(
-                    destination_checkpoints_dir, checkpoint_dir, checkpoint_file
-                ),
+                final_path,
            )
            os.remove(filename)

+            if files_to_symlink_to:
+                dest_dir = os.path.dirname(final_path)
+                for other_checkpoint in files_to_symlink_to:
+                    dest = _get_destination_path(other_checkpoint, dest_dir)
+                    if PathManager.islink(dest):
+                        PathManager.rm(dest)
+                    assert PathManager.symlink(
+                        final_path, dest
+                    ), f"Failed to symlink {final_path} to {dest}"
+
            # Start running evals on uploaded checkpoint
            nfs_evaluation(
                cfg,
@@ -593,13 +620,18 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
            try:
                # PathManager only supports writing to S3, but this function call
                # can be replaced with other APIs for copying checkpoints.
-                PathManager.copy_from_local(
-                    filename,
-                    os.path.join(
-                        cfg.checkpoint.cloud_upload_path, os.path.basename(filename)
-                    ),
-                    overwrite=True,
+                final_path = _get_destination_path(
+                    filename, cfg.checkpoint.cloud_upload_path
                )
+                PathManager.copy_from_local(filename, final_path, overwrite=True)
+
+                # Some non-native PathHandlers don't support symlinks so default to full copies
+                if files_to_symlink_to:
+                    for other_checkpoint in files_to_symlink_to:
+                        dest = _get_destination_path(
+                            other_checkpoint, cfg.checkpoint.cloud_upload_path
+                        )
+                        PathManager.copy(final_path, dest, overwrite=True)
            except (FileNotFoundError, AssertionError) as e:
                logger.info(f"could not upload {filename}: {e}")

@@ -665,6 +697,31 @@ def nfs_evaluation(
     )


+def _copy_to_azure(source, destination, log_dir):
+    # /dir/checkpoint_last.pt -> /dir/checkpoint_last.pt_azcopy_logs_2000-01-01T00_00_00
+    basename = _get_basename(destination)
+    timestamp = datetime.utcnow().isoformat().replace(":", "_")[:-7]
+    azcopy_logs = os.path.join(log_dir, f"{basename}_azcopy_logs_{timestamp}")
+    os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10"
+    os.environ["AZCOPY_LOG_LOCATION"] = azcopy_logs
+    os.makedirs(azcopy_logs, exist_ok=True)
+    logger.info(f"preparing to azcopy {source} to {destination}; logs in {azcopy_logs}")
+    cmd = [
+        "azcopy",  # TODO(susanz): require azcopy to be installed.
+        "copy",
+        "--cap-mbps",
+        "96.0",
+        source,
+        destination,
+    ]
+    res = _run_azcopy(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    if res.returncode != 0:
+        print("Error: {}, azcopy failed".format(res.returncode))
+        print("Azcopy stdout = {}".format(res.stdout))
+        sys.exit(1)
+    logger.info(f"Successfully copied {source} to {destination}")
+
+
 def _run_azcopy(cmd, stdout, stderr):
     return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
```
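
The new `_get_destination_path` helper is what lets the Azure, NFS, and PathManager branches share one naming scheme: it joins a checkpoint's basename onto either a local directory or a remote URL. Here is a standalone copy of its logic, with an invented local path (the blob account name matches the gpu_tests fixture):

```python
import os
from urllib.parse import urlparse

# Standalone copy of the helper's logic, for illustration.
def get_destination_path(path, destination):
    parsed = urlparse(path)
    basename = os.path.basename(parsed.path if parsed.scheme else path)
    res = urlparse(destination)
    if res.scheme:  # remote destination: splice the basename into the URL path
        return res._replace(path=os.path.join(res.path, basename)).geturl()
    return os.path.join(destination, basename)  # local destination: plain join

print(get_destination_path(
    "/local/checkpoints/checkpoint_5.pt",
    "https://myaccount.blob.core.windows.net/test",
))
# -> https://myaccount.blob.core.windows.net/test/checkpoint_5.pt
```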

metaseq/trainer.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -421,7 +421,12 @@ def state_dict(self, filename, training_finished=False) -> Dict[str, Dict]:
         return state_dicts

     def save_checkpoint(
-        self, filename, extra_state, training_finished=False, async_callback_fn=None
+        self,
+        filename,
+        extra_state,
+        training_finished=False,
+        async_callback_fn=None,
+        files_to_symlink_to=None,
     ):
         """Save all training state in a checkpoint file."""

@@ -445,7 +450,7 @@ def perform_save():
         def perform_save():
             try:
                 logger.info(f"Beginning asynchronous torch.save to {filename}")
-                async_callback_fn(filename)
+                async_callback_fn(filename, files_to_symlink_to)
                 logger.info(f"Asynchronous torch.save to {filename} complete.")
             except Exception as e:
                 logger.exception(f"Asynchronous save failed: {e}")
```
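
Taken together, the commit threads `files_to_symlink_to` from `checkpoint_utils.save_checkpoint` through `Trainer.save_checkpoint` into the async callback, whose signature now takes two arguments. A runnable sketch of the plumbing; the `functools.partial` binding is an assumption about how train.py constructs the callback, and only the argument order is taken from the diffs:

```python
import functools

def post_checkpoint_callback(cfg, num_updates, training_finished, filename,
                             files_to_symlink_to):
    # Stub standing in for the real callback in metaseq/cli/train.py.
    print(f"upload {filename}; alias to {files_to_symlink_to}")

# Bind the leading arguments so the trainer can later invoke the callback as
# async_callback_fn(filename, files_to_symlink_to), matching the diffs above.
async_callback_fn = functools.partial(post_checkpoint_callback, {}, 100, False)
async_callback_fn("checkpoint_1_100.pt", ["checkpoint_last.pt"])
```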
