From 671b017551efb9e326baad386b3559843b2546f7 Mon Sep 17 00:00:00 2001 From: Matt Mazzola Date: Wed, 7 Jun 2023 12:12:25 -0700 Subject: [PATCH 1/5] Update checkpoint_utils --- metaseq/checkpoint_utils.py | 98 +++++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 30 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 9c751a215..e340097a0 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -11,6 +11,7 @@ import socket from typing import Any, Dict, List, Optional, Tuple import math +from concurrent.futures import ThreadPoolExecutor import torch from omegaconf import OmegaConf @@ -72,61 +73,72 @@ def save_checkpoint( save_for_updates = not end_of_epoch and (save_to_NFS or save_locally) - checkpoint_conds[f"checkpoint{epoch}{suffix}.pt"] = save_for_epoch + checkpoint_conds[f"checkpoint{epoch}_{updates}{suffix}.pt"] = save_for_epoch checkpoint_conds[f"checkpoint_{updates}{suffix}.pt"] = save_for_updates - checkpoint_conds[f"checkpoint_last{suffix}.pt"] = ( - (training_finished and cfg.save_last_checkpoint) - or save_for_epoch - or save_for_updates - ) + checkpoint_last_file_name = f"checkpoint_last{suffix}.pt" extra_state = {"train_iterator": epoch_itr.state_dict()} - checkpoints = [ - os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + checkpoint_file_paths = [ + os.path.join(cfg.save_dir, checkpoint_file_name) for checkpoint_file_name, cond in checkpoint_conds.items() if cond ] - if len(checkpoints) > 0: - if PathManager.islink(checkpoints[0]): - PathManager.rm(checkpoints[0]) + def _save_checkpoint(checkpoint_file_path: str): + if PathManager.islink(checkpoint_file_path): + PathManager.rm(checkpoint_file_path) trainer.save_checkpoint( - checkpoints[0], + checkpoint_file_path, extra_state, training_finished=training_finished, - async_callback_fn=async_callback_fn if save_to_NFS else None, - files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None, + async_callback_fn=async_callback_fn, ) write_timer.stop() logger.info( - f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) " + f"Saved checkpoint {checkpoint_file_path} (epoch {epoch} @ {updates} updates) " f"(writing took {write_timer.sum} seconds)" ) # See if there's any older checkpoints to delete after saving a new one. - # Only deletes if keep_last_updates > 0 or keep_last_epochs > 0 (default -1 for both). - delete_old_checkpoint_files(cfg, end_of_epoch, suffix) + # Only deletes if keep_last_updates > 0 or keep_last_epochs > 0. + async_executer = ThreadPoolExecutor(max_workers=1) + async_executer.submit(delete_old_checkpoint_files, cfg, end_of_epoch, suffix) + + if len(checkpoint_file_paths) > 0: + _save_checkpoint(checkpoint_file_paths[0]) + + if training_finished and cfg.save_last_checkpoint: + checkpoint_last_file_path = os.path.join(cfg.save_dir, checkpoint_last_file_name) + _save_checkpoint(checkpoint_last_file_path) def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str): + removed_checkpoint = False if not end_of_epoch and cfg.keep_last_updates > 0: # remove old checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + cfg.save_dir, pattern=r"checkpoint_(\d+){}\.pt".format(suffix) ) - for old_chk in checkpoints[cfg.keep_last_updates :]: - if os.path.lexists(old_chk): - os.remove(old_chk) + for old_checkpoint in checkpoints[cfg.keep_last_updates:]: + if os.path.lexists(old_checkpoint): + logger.warning(f"Removing checkpoint {old_checkpoint} because it's older than {cfg.keep_last_updates} updates") + os.remove(old_checkpoint) + removed_checkpoint = True if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + cfg.save_dir, pattern=r"checkpoint\d+_(\d+){}\.pt".format(suffix) ) - for old_chk in checkpoints[cfg.keep_last_epochs :]: - if os.path.lexists(old_chk): - os.remove(old_chk) + for old_checkpoint in checkpoints[cfg.keep_last_epochs:]: + if os.path.lexists(old_checkpoint): + logger.warning(f"Removing checkpoint {old_checkpoint} because it's older than {cfg.keep_last_epochs} epochs") + os.remove(old_checkpoint) + removed_checkpoint = True + + if removed_checkpoint: + logger.info("Done removing old checkpoints on worker {}".format(distributed_utils.get_global_rank())) # Reference: @@ -207,11 +219,37 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): else: checkpoint_path_to_load = cfg.restore_file - if cfg.restore_file != default_restore_file and cfg.finetune_from_model: - raise ValueError( - "--finetune-from-model and --restore-file (non-default value) " - "can not be specified together: " + str(cfg) - ) + last_checkpoint = None + default_restore_file = "checkpoint_last.pt" + if PathManager.exists(cfg.save_dir): + if PathManager.exists(os.path.join(cfg.save_dir, default_restore_file.replace(".pt", suffix + ".pt"))): + last_checkpoint = os.path.join(cfg.save_dir, default_restore_file.replace(".pt", suffix + ".pt")) + elif PathManager.exists(os.path.join(cfg.save_dir, default_restore_file)): + last_checkpoint = os.path.join(cfg.save_dir, default_restore_file) + else: + checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint\d*_(\d+){}\.pt".format(suffix)) + if len(checkpoints) > 0 and PathManager.exists(checkpoints[0]): + last_checkpoint = checkpoints[0] + + if last_checkpoint is None: + checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint\d*_(\d+)\.pt") + if len(checkpoints) > 0 and PathManager.exists(checkpoints[0]): + last_checkpoint = checkpoints[0] + + if last_checkpoint is not None and last_checkpoint != checkpoint_path_to_load: + if PathManager.exists(last_checkpoint): + logger.warning(f"Overriding restore-file-path to load with {last_checkpoint}") + checkpoint_path_to_load = last_checkpoint + logger.info( + f"Disregarding --reset-dataloader, --reset-lr-scheduler, --reset-optimizer, --reset-meters\ + flags since we are resuming from a intermediate output checkpoint" + ) + reset_optimizer = False + reset_lr_scheduler = False + reset_meters = False + reset_dataloader = False + else: + raise ValueError(f"Checkpoint {last_checkpoint} does not exist. Incorrect value found in save_checkpoint_log.txt") # Azure logic try: From 3b924a181c4a6de4577ea037a994c58be492c945 Mon Sep 17 00:00:00 2001 From: Matt Mazzola Date: Wed, 7 Jun 2023 12:29:39 -0700 Subject: [PATCH 2/5] Improve training stop conditions and always save last checkpoint --- metaseq/cli/train.py | 58 +++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 011ca0e8a..d1e5b54ae 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -203,10 +203,10 @@ def main(cfg: DictConfig) -> None: disable_iterator_cache=True, ) - max_epoch = cfg.optimization.max_epoch or math.inf train_meter = meters.StopwatchMeter() train_meter.start() - while epoch_itr.next_epoch_idx <= max_epoch: + + while True: # train for one epoch valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: @@ -221,6 +221,9 @@ def main(cfg: DictConfig) -> None: disable_iterator_cache=True, ) train_meter.stop() + + # make sure every process finishes before exiting... + distributed_utils.global_barrier() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) @@ -433,6 +436,8 @@ def validate_and_save( end_of_epoch: bool, was_successful_step: bool, ) -> Tuple[List[Optional[float]], bool]: + num_epoch = epoch_itr.epoch + max_epoch = cfg.optimization.max_epoch or math.inf num_updates = trainer.get_num_updates() max_update = cfg.optimization.max_update or math.inf @@ -444,44 +449,47 @@ def validate_and_save( # Stopping conditions (and an additional one based on validation loss later # on) should_stop = False - if num_updates >= max_update: + + if num_epoch > max_epoch: should_stop = True - logger.info( - f"Stopping training due to " - f"num_updates: {num_updates} >= max_update: {max_update}" - ) + logger.info(f"Stopping training due to " + f"num_epoch: {num_epoch} > max_epoch: {max_epoch}") + elif num_updates > max_update: + should_stop = True + logger.info(f"Stopping training due to " + f"num_updates: {num_updates} > max_update: {max_update}") - save_locally = ( + is_epoch_save_interval = ( + end_of_epoch + and cfg.checkpoint.save_interval_epochs > 0 + and num_epoch % cfg.checkpoint.save_interval_epochs == 0 + ) + is_successful_update_local_save_interval = ( cfg.checkpoint.local_save_interval_updates > 0 and num_updates > 0 - and num_updates % cfg.checkpoint.local_save_interval_updates == 0 + and num_updates % cfg.checkpoint.local_save_interval_updates == 0 and was_successful_step ) - save_to_NFS = ( + is_successful_update_save_interval = ( cfg.checkpoint.save_interval_updates > 0 and num_updates > 0 - and num_updates % cfg.checkpoint.save_interval_updates == 0 + and num_updates % cfg.checkpoint.save_interval_updates == 0 and was_successful_step + ) + is_successful_update_validate_interval = ( + cfg.checkpoint.validate_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.checkpoint.validate_interval_updates == 0 and was_successful_step ) do_save = ( - ( - end_of_epoch - and cfg.checkpoint.save_interval_epochs > 0 - and epoch_itr.epoch % cfg.checkpoint.save_interval_epochs == 0 - ) - or ( - (save_locally or save_to_NFS) - and num_updates >= cfg.dataset.validate_after_updates - and was_successful_step - ) + is_epoch_save_interval + or (is_successful_update_local_save_interval or is_successful_update_save_interval) or should_stop ) do_validate = ( should_stop or ( - cfg.dataset.validate_interval_updates > 0 - and num_updates > 0 - and num_updates % cfg.dataset.validate_interval_updates == 0 - and was_successful_step + is_successful_update_validate_interval + and num_updates >= cfg.dataset.validate_after_updates ) ) and not cfg.dataset.disable_validation From 9f748d8ac2138632955ff6a3fc282fb5feab1b8b Mon Sep 17 00:00:00 2001 From: Matt Mazzola Date: Wed, 7 Jun 2023 14:05:06 -0700 Subject: [PATCH 3/5] Remove default pattern and add comment --- metaseq/checkpoint_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index e340097a0..5e430ff07 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -105,6 +105,7 @@ def _save_checkpoint(checkpoint_file_path: str): async_executer = ThreadPoolExecutor(max_workers=1) async_executer.submit(delete_old_checkpoint_files, cfg, end_of_epoch, suffix) + # If there are checkpoints to save, save the first in the list if len(checkpoint_file_paths) > 0: _save_checkpoint(checkpoint_file_paths[0]) @@ -143,7 +144,7 @@ def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffi # Reference: # https://github.com/facebookresearch/fairseq/blob/0338cdc3094ca7d29ff4d36d64791f7b4e4b5e6e/fairseq/checkpoint_utils.py#L538 -def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): +def checkpoint_paths(path, pattern=None): """Retrieves all checkpoints found in `path` directory. Checkpoints are identified by matching filename to the specified pattern. If the pattern contains groups, the result will be sorted by the first group in From 105b714d5e49e7f3985221327788d9b74a274aa0 Mon Sep 17 00:00:00 2001 From: Matt Mazzola Date: Thu, 8 Jun 2023 09:18:07 -0700 Subject: [PATCH 4/5] Revert changes to deleting checkpoints and checkpoint filenames --- metaseq/checkpoint_utils.py | 70 +++++++++---------------------------- 1 file changed, 17 insertions(+), 53 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 5e430ff07..2903f3edc 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -11,7 +11,6 @@ import socket from typing import Any, Dict, List, Optional, Tuple import math -from concurrent.futures import ThreadPoolExecutor import torch from omegaconf import OmegaConf @@ -73,7 +72,7 @@ def save_checkpoint( save_for_updates = not end_of_epoch and (save_to_NFS or save_locally) - checkpoint_conds[f"checkpoint{epoch}_{updates}{suffix}.pt"] = save_for_epoch + checkpoint_conds[f"checkpoint{epoch}{suffix}.pt"] = save_for_epoch checkpoint_conds[f"checkpoint_{updates}{suffix}.pt"] = save_for_updates checkpoint_last_file_name = f"checkpoint_last{suffix}.pt" @@ -101,9 +100,8 @@ def _save_checkpoint(checkpoint_file_path: str): ) # See if there's any older checkpoints to delete after saving a new one. - # Only deletes if keep_last_updates > 0 or keep_last_epochs > 0. - async_executer = ThreadPoolExecutor(max_workers=1) - async_executer.submit(delete_old_checkpoint_files, cfg, end_of_epoch, suffix) + # Only deletes if keep_last_updates > 0 or keep_last_epochs > 0 (default -1 for both). + delete_old_checkpoint_files(cfg, end_of_epoch, suffix) # If there are checkpoints to save, save the first in the list if len(checkpoint_file_paths) > 0: @@ -115,36 +113,28 @@ def _save_checkpoint(checkpoint_file_path: str): def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str): - removed_checkpoint = False if not end_of_epoch and cfg.keep_last_updates > 0: # remove old checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint_(\d+){}\.pt".format(suffix) + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) ) - for old_checkpoint in checkpoints[cfg.keep_last_updates:]: - if os.path.lexists(old_checkpoint): - logger.warning(f"Removing checkpoint {old_checkpoint} because it's older than {cfg.keep_last_updates} updates") - os.remove(old_checkpoint) - removed_checkpoint = True + for old_chk in checkpoints[cfg.keep_last_updates :]: + if os.path.lexists(old_chk): + os.remove(old_chk) if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( - cfg.save_dir, pattern=r"checkpoint\d+_(\d+){}\.pt".format(suffix) + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) ) - for old_checkpoint in checkpoints[cfg.keep_last_epochs:]: - if os.path.lexists(old_checkpoint): - logger.warning(f"Removing checkpoint {old_checkpoint} because it's older than {cfg.keep_last_epochs} epochs") - os.remove(old_checkpoint) - removed_checkpoint = True - - if removed_checkpoint: - logger.info("Done removing old checkpoints on worker {}".format(distributed_utils.get_global_rank())) + for old_chk in checkpoints[cfg.keep_last_epochs :]: + if os.path.lexists(old_chk): + os.remove(old_chk) # Reference: # https://github.com/facebookresearch/fairseq/blob/0338cdc3094ca7d29ff4d36d64791f7b4e4b5e6e/fairseq/checkpoint_utils.py#L538 -def checkpoint_paths(path, pattern=None): +def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): """Retrieves all checkpoints found in `path` directory. Checkpoints are identified by matching filename to the specified pattern. If the pattern contains groups, the result will be sorted by the first group in @@ -220,37 +210,11 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): else: checkpoint_path_to_load = cfg.restore_file - last_checkpoint = None - default_restore_file = "checkpoint_last.pt" - if PathManager.exists(cfg.save_dir): - if PathManager.exists(os.path.join(cfg.save_dir, default_restore_file.replace(".pt", suffix + ".pt"))): - last_checkpoint = os.path.join(cfg.save_dir, default_restore_file.replace(".pt", suffix + ".pt")) - elif PathManager.exists(os.path.join(cfg.save_dir, default_restore_file)): - last_checkpoint = os.path.join(cfg.save_dir, default_restore_file) - else: - checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint\d*_(\d+){}\.pt".format(suffix)) - if len(checkpoints) > 0 and PathManager.exists(checkpoints[0]): - last_checkpoint = checkpoints[0] - - if last_checkpoint is None: - checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint\d*_(\d+)\.pt") - if len(checkpoints) > 0 and PathManager.exists(checkpoints[0]): - last_checkpoint = checkpoints[0] - - if last_checkpoint is not None and last_checkpoint != checkpoint_path_to_load: - if PathManager.exists(last_checkpoint): - logger.warning(f"Overriding restore-file-path to load with {last_checkpoint}") - checkpoint_path_to_load = last_checkpoint - logger.info( - f"Disregarding --reset-dataloader, --reset-lr-scheduler, --reset-optimizer, --reset-meters\ - flags since we are resuming from a intermediate output checkpoint" - ) - reset_optimizer = False - reset_lr_scheduler = False - reset_meters = False - reset_dataloader = False - else: - raise ValueError(f"Checkpoint {last_checkpoint} does not exist. Incorrect value found in save_checkpoint_log.txt") + if cfg.restore_file != default_restore_file and cfg.finetune_from_model: + raise ValueError( + "--finetune-from-model and --restore-file (non-default value) " + "can not be specified together: " + str(cfg) + ) # Azure logic try: From 5e6c7e405e891b1f94db9bff4ad9ce5da32c61f9 Mon Sep 17 00:00:00 2001 From: Matt Mazzola Date: Fri, 9 Jun 2023 07:53:17 -0700 Subject: [PATCH 5/5] Manually fix formatting to align with Black standards --- metaseq/checkpoint_utils.py | 8 ++++++-- metaseq/cli/train.py | 28 +++++++++++++++++++--------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 2903f3edc..6a15450a5 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -79,7 +79,9 @@ def save_checkpoint( extra_state = {"train_iterator": epoch_itr.state_dict()} checkpoint_file_paths = [ - os.path.join(cfg.save_dir, checkpoint_file_name) for checkpoint_file_name, cond in checkpoint_conds.items() if cond + os.path.join(cfg.save_dir, checkpoint_file_name) + for checkpoint_file_name, cond in checkpoint_conds.items() + if cond ] def _save_checkpoint(checkpoint_file_path: str): @@ -108,7 +110,9 @@ def _save_checkpoint(checkpoint_file_path: str): _save_checkpoint(checkpoint_file_paths[0]) if training_finished and cfg.save_last_checkpoint: - checkpoint_last_file_path = os.path.join(cfg.save_dir, checkpoint_last_file_name) + checkpoint_last_file_path = os.path.join( + cfg.save_dir, checkpoint_last_file_name + ) _save_checkpoint(checkpoint_last_file_path) diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index d1e5b54ae..f3d4ffbed 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -452,12 +452,16 @@ def validate_and_save( if num_epoch > max_epoch: should_stop = True - logger.info(f"Stopping training due to " - f"num_epoch: {num_epoch} > max_epoch: {max_epoch}") + logger.info( + f"Stopping training due to " + f"num_epoch: {num_epoch} > max_epoch: {max_epoch}" + ) elif num_updates > max_update: should_stop = True - logger.info(f"Stopping training due to " - f"num_updates: {num_updates} > max_update: {max_update}") + logger.info( + f"Stopping training due to " + f"num_updates: {num_updates} > max_update: {max_update}" + ) is_epoch_save_interval = ( end_of_epoch @@ -467,22 +471,28 @@ def validate_and_save( is_successful_update_local_save_interval = ( cfg.checkpoint.local_save_interval_updates > 0 and num_updates > 0 - and num_updates % cfg.checkpoint.local_save_interval_updates == 0 and was_successful_step + and num_updates % cfg.checkpoint.local_save_interval_updates == 0 + and was_successful_step ) is_successful_update_save_interval = ( cfg.checkpoint.save_interval_updates > 0 and num_updates > 0 - and num_updates % cfg.checkpoint.save_interval_updates == 0 and was_successful_step + and num_updates % cfg.checkpoint.save_interval_updates == 0 + and was_successful_step ) is_successful_update_validate_interval = ( cfg.checkpoint.validate_interval_updates > 0 and num_updates > 0 - and num_updates % cfg.checkpoint.validate_interval_updates == 0 and was_successful_step + and num_updates % cfg.checkpoint.validate_interval_updates == 0 + and was_successful_step ) do_save = ( is_epoch_save_interval - or (is_successful_update_local_save_interval or is_successful_update_save_interval) + or ( + is_successful_update_local_save_interval + or is_successful_update_save_interval + ) or should_stop ) do_validate = ( @@ -502,7 +512,7 @@ def validate_and_save( training_finished=should_stop, async_callback_fn=functools.partial( post_checkpoint_callback, cfg, num_updates, should_stop - ) + ), ) valid_losses = [None]