diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py
index 9c751a215..6a15450a5 100644
--- a/metaseq/checkpoint_utils.py
+++ b/metaseq/checkpoint_utils.py
@@ -74,33 +74,30 @@ def save_checkpoint(
     checkpoint_conds[f"checkpoint{epoch}{suffix}.pt"] = save_for_epoch
     checkpoint_conds[f"checkpoint_{updates}{suffix}.pt"] = save_for_updates
-    checkpoint_conds[f"checkpoint_last{suffix}.pt"] = (
-        (training_finished and cfg.save_last_checkpoint)
-        or save_for_epoch
-        or save_for_updates
-    )
+    checkpoint_last_file_name = f"checkpoint_last{suffix}.pt"
 
     extra_state = {"train_iterator": epoch_itr.state_dict()}
 
-    checkpoints = [
-        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
+    checkpoint_file_paths = [
+        os.path.join(cfg.save_dir, checkpoint_file_name)
+        for checkpoint_file_name, cond in checkpoint_conds.items()
+        if cond
     ]
 
-    if len(checkpoints) > 0:
-        if PathManager.islink(checkpoints[0]):
-            PathManager.rm(checkpoints[0])
+    def _save_checkpoint(checkpoint_file_path: str):
+        if PathManager.islink(checkpoint_file_path):
+            PathManager.rm(checkpoint_file_path)
 
         trainer.save_checkpoint(
-            checkpoints[0],
+            checkpoint_file_path,
             extra_state,
             training_finished=training_finished,
-            async_callback_fn=async_callback_fn if save_to_NFS else None,
-            files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None,
+            async_callback_fn=async_callback_fn,
         )
 
         write_timer.stop()
         logger.info(
-            f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) "
+            f"Saved checkpoint {checkpoint_file_path} (epoch {epoch} @ {updates} updates) "
             f"(writing took {write_timer.sum} seconds)"
         )
@@ -108,6 +105,16 @@ def save_checkpoint(
     # Only deletes if keep_last_updates > 0 or keep_last_epochs > 0 (default -1 for both).
     delete_old_checkpoint_files(cfg, end_of_epoch, suffix)
 
+    # If there are checkpoints to save, save the first in the list
+    if len(checkpoint_file_paths) > 0:
+        _save_checkpoint(checkpoint_file_paths[0])
+
+    if training_finished and cfg.save_last_checkpoint:
+        checkpoint_last_file_path = os.path.join(
+            cfg.save_dir, checkpoint_last_file_name
+        )
+        _save_checkpoint(checkpoint_last_file_path)
+
 
 def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str):
     if not end_of_epoch and cfg.keep_last_updates > 0:
diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py
index 011ca0e8a..f3d4ffbed 100644
--- a/metaseq/cli/train.py
+++ b/metaseq/cli/train.py
@@ -203,10 +203,10 @@ def main(cfg: DictConfig) -> None:
         disable_iterator_cache=True,
     )
 
-    max_epoch = cfg.optimization.max_epoch or math.inf
     train_meter = meters.StopwatchMeter()
    train_meter.start()
-    while epoch_itr.next_epoch_idx <= max_epoch:
+
+    while True:
         # train for one epoch
         valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
         if should_stop:
@@ -221,6 +221,9 @@ def main(cfg: DictConfig) -> None:
         disable_iterator_cache=True,
     )
     train_meter.stop()
+
+    # make sure every process finishes before exiting...
+    distributed_utils.global_barrier()
 
     logger.info("done training in {:.1f} seconds".format(train_meter.sum))
 
@@ -433,6 +436,8 @@ def validate_and_save(
     end_of_epoch: bool,
     was_successful_step: bool,
 ) -> Tuple[List[Optional[float]], bool]:
+    num_epoch = epoch_itr.epoch
+    max_epoch = cfg.optimization.max_epoch or math.inf
     num_updates = trainer.get_num_updates()
     max_update = cfg.optimization.max_update or math.inf
 
@@ -444,44 +449,57 @@ def validate_and_save(
     # Stopping conditions (and an additional one based on validation loss later
     # on)
     should_stop = False
-    if num_updates >= max_update:
+
+    if num_epoch > max_epoch:
+        should_stop = True
+        logger.info(
+            f"Stopping training due to "
+            f"num_epoch: {num_epoch} > max_epoch: {max_epoch}"
+        )
+    elif num_updates > max_update:
         should_stop = True
         logger.info(
             f"Stopping training due to "
-            f"num_updates: {num_updates} >= max_update: {max_update}"
+            f"num_updates: {num_updates} > max_update: {max_update}"
         )
 
-    save_locally = (
+    is_epoch_save_interval = (
+        end_of_epoch
+        and cfg.checkpoint.save_interval_epochs > 0
+        and num_epoch % cfg.checkpoint.save_interval_epochs == 0
+    )
+    is_successful_update_local_save_interval = (
         cfg.checkpoint.local_save_interval_updates > 0
         and num_updates > 0
         and num_updates % cfg.checkpoint.local_save_interval_updates == 0
+        and was_successful_step
     )
-    save_to_NFS = (
+    is_successful_update_save_interval = (
         cfg.checkpoint.save_interval_updates > 0
         and num_updates > 0
         and num_updates % cfg.checkpoint.save_interval_updates == 0
+        and was_successful_step
+    )
+    is_successful_update_validate_interval = (
+        cfg.dataset.validate_interval_updates > 0
+        and num_updates > 0
+        and num_updates % cfg.dataset.validate_interval_updates == 0
+        and was_successful_step
     )
 
     do_save = (
-        (
-            end_of_epoch
-            and cfg.checkpoint.save_interval_epochs > 0
-            and epoch_itr.epoch % cfg.checkpoint.save_interval_epochs == 0
-        )
+        is_epoch_save_interval
         or (
-            (save_locally or save_to_NFS)
-            and num_updates >= cfg.dataset.validate_after_updates
-            and was_successful_step
+            is_successful_update_local_save_interval
+            or is_successful_update_save_interval
         )
         or should_stop
     )
     do_validate = (
         should_stop
         or (
-            cfg.dataset.validate_interval_updates > 0
-            and num_updates > 0
-            and num_updates % cfg.dataset.validate_interval_updates == 0
-            and was_successful_step
+            is_successful_update_validate_interval
+            and num_updates >= cfg.dataset.validate_after_updates
        )
     ) and not cfg.dataset.disable_validation
@@ -494,7 +512,7 @@ def validate_and_save(
             training_finished=should_stop,
             async_callback_fn=functools.partial(
                 post_checkpoint_callback, cfg, num_updates, should_stop
-            )
+            ),
         )
 
     valid_losses = [None]
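
---

Reviewer note: below is a minimal, self-contained sketch of the refactored `do_save`/`do_validate` gating that the `validate_and_save` hunks above introduce. The `SimpleNamespace` config stub, its interval values, and the `gating`/`hits` helpers are hypothetical, for illustration only; they are not part of the patch.

```python
import math
from types import SimpleNamespace

# Hypothetical config stub mirroring only the fields that validate_and_save
# reads; the interval values below are made up for the example.
cfg = SimpleNamespace(
    checkpoint=SimpleNamespace(
        save_interval_epochs=1,
        save_interval_updates=2000,
        local_save_interval_updates=500,
    ),
    dataset=SimpleNamespace(
        validate_interval_updates=1000,
        validate_after_updates=4000,
        disable_validation=False,
    ),
    optimization=SimpleNamespace(max_epoch=0, max_update=100000),
)


def gating(num_epoch, num_updates, end_of_epoch, was_successful_step):
    """Recomputes the patch's do_save/do_validate booleans (illustration only)."""
    max_epoch = cfg.optimization.max_epoch or math.inf
    max_update = cfg.optimization.max_update or math.inf
    should_stop = num_epoch > max_epoch or num_updates > max_update

    def hits(interval):
        # True when num_updates lands exactly on a positive interval.
        return interval > 0 and num_updates > 0 and num_updates % interval == 0

    is_epoch_save_interval = (
        end_of_epoch
        and cfg.checkpoint.save_interval_epochs > 0
        and num_epoch % cfg.checkpoint.save_interval_epochs == 0
    )
    is_successful_update_save = was_successful_step and (
        hits(cfg.checkpoint.local_save_interval_updates)
        or hits(cfg.checkpoint.save_interval_updates)
    )
    do_save = is_epoch_save_interval or is_successful_update_save or should_stop
    do_validate = (
        should_stop
        or (
            was_successful_step
            and hits(cfg.dataset.validate_interval_updates)
            and num_updates >= cfg.dataset.validate_after_updates
        )
    ) and not cfg.dataset.disable_validation
    return do_save, do_validate


# Update 2000 hits a save interval but validation still waits for
# validate_after_updates=4000; update 4000 does both.
assert gating(1, 2000, end_of_epoch=False, was_successful_step=True) == (True, False)
assert gating(1, 4000, end_of_epoch=False, was_successful_step=True) == (True, True)
```

The first assert captures the behavioral change: the `num_updates >= cfg.dataset.validate_after_updates` guard moves from `do_save` into `do_validate`, so a run can now save on its update intervals before the validation warm-up threshold is reached, while validation alone keeps waiting for it.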