Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit e8428c1

Browse files
committed
add back checkpoint deletion logic, configurable via keep_last_updates or keep_last_epochs
1 parent 35bab2c commit e8428c1

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

metaseq/checkpoint_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,50 @@ def save_checkpoint(
110110
f"(writing took {write_timer.sum} seconds)"
111111
)
112112

113+
# See if there's any older checkpoints to delete after saving a new one.
114+
# Only deletes if keep_last_updates > 0 or keep_last_epochs > 0 (default -1 for both).
115+
delete_old_checkpoint_files(cfg, end_of_epoch, suffix)
116+
117+
118+
def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str):
119+
if not end_of_epoch and cfg.keep_last_updates > 0:
120+
# remove old checkpoints; checkpoints are sorted in descending order
121+
checkpoints = checkpoint_paths(
122+
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
123+
)
124+
for old_chk in checkpoints[cfg.keep_last_updates :]:
125+
if os.path.lexists(old_chk):
126+
os.remove(old_chk)
127+
128+
if cfg.keep_last_epochs > 0:
129+
# remove old epoch checkpoints; checkpoints are sorted in descending order
130+
checkpoints = checkpoint_paths(
131+
cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
132+
)
133+
for old_chk in checkpoints[cfg.keep_last_epochs :]:
134+
if os.path.lexists(old_chk):
135+
os.remove(old_chk)
136+
137+
138+
# Reference:
139+
# https://github.com/facebookresearch/fairseq/blob/0338cdc3094ca7d29ff4d36d64791f7b4e4b5e6e/fairseq/checkpoint_utils.py#L538
140+
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
141+
"""Retrieves all checkpoints found in `path` directory.
142+
Checkpoints are identified by matching filename to the specified pattern. If
143+
the pattern contains groups, the result will be sorted by the first group in
144+
descending order.
145+
"""
146+
pt_regexp = re.compile(pattern)
147+
files = os.listdir(path)
148+
149+
entries = []
150+
for i, f in enumerate(files):
151+
m = pt_regexp.fullmatch(f)
152+
if m is not None:
153+
idx = float(m.group(1)) if len(m.groups()) > 0 else i
154+
entries.append((idx, m.group(0)))
155+
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
156+
113157

114158
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
115159
"""

0 commit comments

Comments
 (0)