@@ -110,6 +110,50 @@ def save_checkpoint(
110110 f"(writing took { write_timer .sum } seconds)"
111111 )
112112
113+ # See if there's any older checkpoints to delete after saving a new one.
114+ # Only deletes if keep_last_updates > 0 or keep_last_epochs > 0 (default -1 for both).
115+ delete_old_checkpoint_files (cfg , end_of_epoch , suffix )
116+
117+
def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str):
    """Delete checkpoint files in ``cfg.save_dir`` beyond the retention limits.

    Two independent limits are applied:

    * ``cfg.keep_last_updates`` -- files named ``checkpoint_<n>_<m><suffix>.pt``;
      only pruned when ``end_of_epoch`` is false.
    * ``cfg.keep_last_epochs`` -- files named ``checkpoint<n><suffix>.pt``;
      pruned on every call.

    A limit that is not greater than zero disables the corresponding pruning.

    Args:
        cfg: Checkpoint settings; ``save_dir``, ``keep_last_updates`` and
            ``keep_last_epochs`` are read.
        end_of_epoch: When true, the ``keep_last_updates`` pruning is skipped.
        suffix: Literal text expected between the numeric part of a checkpoint
            file name and the ``.pt`` extension.
    """
    # The suffix is file-name text, not a regex fragment: escape it so that
    # characters such as "." or "+" cannot widen (or break) the match and
    # cause unrelated files to be selected for deletion.
    literal_suffix = re.escape(suffix)

    if not end_of_epoch and cfg.keep_last_updates > 0:
        _remove_stale_checkpoints(
            cfg.save_dir,
            r"checkpoint_\d+_(\d+){}\.pt".format(literal_suffix),
            cfg.keep_last_updates,
        )

    if cfg.keep_last_epochs > 0:
        _remove_stale_checkpoints(
            cfg.save_dir,
            r"checkpoint(\d+){}\.pt".format(literal_suffix),
            cfg.keep_last_epochs,
        )


def _remove_stale_checkpoints(save_dir, pattern, keep):
    """Remove every file in ``save_dir`` matching ``pattern`` except the first ``keep``."""
    # checkpoint_paths returns matches sorted in descending order of the first
    # pattern group, so everything past the first `keep` entries is stale.
    for stale_path in checkpoint_paths(save_dir, pattern=pattern)[keep:]:
        try:
            os.remove(stale_path)
        except FileNotFoundError:
            # The file vanished between listing and removal (for example a
            # concurrent worker deleted it first); nothing left to do.
            pass
136+
137+
138+ # Reference:
139+ # https://github.com/facebookresearch/fairseq/blob/0338cdc3094ca7d29ff4d36d64791f7b4e4b5e6e/fairseq/checkpoint_utils.py#L538
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
    """Return the checkpoint files located directly inside ``path``.

    A directory entry is treated as a checkpoint when its entire name matches
    ``pattern``. When the pattern defines at least one group, the first group
    is converted to ``float`` and the results are ordered by that value, largest
    first. Without groups, the position of the entry in the directory listing
    is used as the ordering key instead.

    Each returned element is ``path`` joined with the matching file name.
    """
    matcher = re.compile(pattern)
    # Whether the pattern has any group is a property of the pattern itself,
    # so it is determined once rather than for every match.
    has_group = matcher.groups > 0

    ranked = []
    for position, filename in enumerate(os.listdir(path)):
        found = matcher.fullmatch(filename)
        if found is None:
            continue
        sort_key = float(found.group(1)) if has_group else position
        ranked.append((sort_key, found.group(0)))

    # Descending order: highest key first (ties fall back to the file name).
    ranked.sort(reverse=True)
    return [os.path.join(path, name) for _, name in ranked]
156+
113157
114158def load_checkpoint (cfg : CheckpointConfig , trainer , ** passthrough_args ):
115159 """
0 commit comments