diff --git a/gpu_tests/test_checkpoint_saving.py b/gpu_tests/test_checkpoint_saving.py
index e2803920a..94a1ecc41 100644
--- a/gpu_tests/test_checkpoint_saving.py
+++ b/gpu_tests/test_checkpoint_saving.py
@@ -12,6 +12,7 @@
 from functools import partial, partialmethod
 import unittest
 from unittest.mock import patch, Mock, MagicMock
+from urllib.parse import urlparse
 import torch
 from metaseq.dataclass.configs import DistributedTrainingConfig
 from metaseq.launcher.opt_baselines import cli_main as sweep_cli_main
@@ -81,15 +82,17 @@ def test_checkpoint_saving_and_uploading(self):
         self.assertEqual(file_names_saved_azure, expected_file_names)
         for worker_cmd in upload_events:
             self.assertEqual(
-                worker_cmd["command"],
+                worker_cmd["command"][:4],
                 [
                     "azcopy",
                     "copy",
                     "--cap-mbps",
                     "96.0",
-                    "https://myaccount.blob.core.windows.net/test",
                 ],
             )
+            self.assertTrue(
+                os.path.basename(worker_cmd["command"][-1]) in expected_file_names
+            )
             self.assertEqual(
                 worker_cmd["checkpoint_model_dir"], common_checkpoint_model_dir
             )
@@ -235,18 +238,22 @@ def download_checkpoint_mock(blob_url, checkpoint_path, suffix, events):
 
 
 def subprocess_run_mock(cmd, stdout, stderr, events):
-    # replaces subprocess.run azcopy command that uploads to azure
-    _, checkpoint_dir, checkpoint_model_dir, checkpoint_file = cmd[4].split("/")
-    events.append(
-        {
-            "type": "upload",
-            "command": cmd[:4] + cmd[5:],
-            "checkpoint_dir": checkpoint_dir,
-            "checkpoint_model_dir": checkpoint_model_dir,
-            "checkpoint_file": checkpoint_file,
-            "file_saved_locally": os.path.exists(cmd[4]),
-        }
-    )
+    source = cmd[4]
+    dest = cmd[5]
+
+    # Only interested in local -> remote transfers (not asserting remote copies/aliases)
+    if urlparse(source).scheme == "" and urlparse(dest).scheme == "https":
+        _, checkpoint_dir, checkpoint_model_dir, checkpoint_file = source.split("/")
+        events.append(
+            {
+                "type": "upload",
+                "command": cmd[:4] + cmd[5:],
+                "checkpoint_dir": checkpoint_dir,
+                "checkpoint_model_dir": checkpoint_model_dir,
+                "checkpoint_file": checkpoint_file,
+                "file_saved_locally": os.path.exists(source),
+            }
+        )
 
     res = Mock()
     res.returncode = 0
@@ -254,7 +261,12 @@ def subprocess_run_mock(cmd, stdout, stderr, events):
 
 
 def save_checkpoint_mock(
-    self, filename, extra_state, training_finished=False, async_callback_fn=None
+    self,
+    filename,
+    extra_state,
+    training_finished=False,
+    async_callback_fn=None,
+    files_to_symlink_to=None,
 ):
     logger = logging.getLogger("metaseq.trainer")
     """Save all training state in a checkpoint file."""
@@ -278,7 +290,7 @@ def perform_save():
             logger.info(f"Beginning asynchronous torch.save to {filename}")
             torch.save(state_dict, filename)
             if async_callback_fn is not None:
-                async_callback_fn(filename)
+                async_callback_fn(filename, files_to_symlink_to)
             logger.info(f"Asynchronous torch.save to {filename} complete.")
         except Exception as e:
             logger.exception(f"Asynchronous save failed: {e}")
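Reviewer note: the test now asserts only the stable `azcopy` command prefix and checks that each upload's destination basename is an expected file, because the callback can also issue remote-to-remote copies for the files that used to be symlinks. The mock tells the two cases apart purely by URL scheme. A minimal standalone sketch of that check (paths are made up for illustration):

```python
from urllib.parse import urlparse


def is_local_to_remote_upload(source: str, dest: str) -> bool:
    # A plain filesystem path parses with an empty scheme, while an Azure
    # blob URL parses with scheme "https"; this mirrors the mock's filter.
    return urlparse(source).scheme == "" and urlparse(dest).scheme == "https"


# Local checkpoint -> blob container: recorded as an upload event.
assert is_local_to_remote_upload(
    "checkpoints/test/checkpoint_10.pt",
    "https://myaccount.blob.core.windows.net/test",
)

# Blob -> blob alias copy: ignored by the mock.
assert not is_local_to_remote_upload(
    "https://myaccount.blob.core.windows.net/test/checkpoint_10.pt",
    "https://myaccount.blob.core.windows.net/test/checkpoint_last.pt",
)
```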
diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py
index faf2f0320..48b78357f 100644
--- a/metaseq/checkpoint_utils.py
+++ b/metaseq/checkpoint_utils.py
@@ -95,6 +95,7 @@ def save_checkpoint(
             extra_state,
             training_finished=training_finished,
             async_callback_fn=async_callback_fn if save_to_NFS else None,
+            files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None,
         )
         write_timer.stop()
 
@@ -103,6 +104,50 @@ def save_checkpoint(
         f"(writing took {write_timer.sum} seconds)"
     )
 
+    # See if there are any older checkpoints to delete after saving a new one.
+    # Deletion only happens if keep_last_updates > 0 or keep_last_epochs > 0 (both default to -1).
+    delete_old_checkpoint_files(cfg, end_of_epoch, suffix)
+
+
+def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str):
+    if not end_of_epoch and cfg.keep_last_updates > 0:
+        # remove old checkpoints; checkpoints are sorted in descending order
+        checkpoints = checkpoint_paths(
+            cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
+        )
+        for old_chk in checkpoints[cfg.keep_last_updates :]:
+            if os.path.lexists(old_chk):
+                os.remove(old_chk)
+
+    if cfg.keep_last_epochs > 0:
+        # remove old epoch checkpoints; checkpoints are sorted in descending order
+        checkpoints = checkpoint_paths(
+            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
+        )
+        for old_chk in checkpoints[cfg.keep_last_epochs :]:
+            if os.path.lexists(old_chk):
+                os.remove(old_chk)
+
+
+# Reference:
+# https://github.com/facebookresearch/fairseq/blob/0338cdc3094ca7d29ff4d36d64791f7b4e4b5e6e/fairseq/checkpoint_utils.py#L538
+def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
+    """Retrieves all checkpoints found in `path` directory.
+    Checkpoints are identified by matching filename to the specified pattern. If
+    the pattern contains groups, the result will be sorted by the first group in
+    descending order.
+    """
+    pt_regexp = re.compile(pattern)
+    files = os.listdir(path)
+
+    entries = []
+    for i, f in enumerate(files):
+        m = pt_regexp.fullmatch(f)
+        if m is not None:
+            idx = float(m.group(1)) if len(m.groups()) > 0 else i
+            entries.append((idx, m.group(0)))
+    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
+
 
 def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
     """
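Reviewer note: `checkpoint_paths` sorts by the first regex group numerically, newest first, so the retention slices `checkpoints[cfg.keep_last_updates:]` and `checkpoints[cfg.keep_last_epochs:]` always delete the oldest files. A small self-contained sketch of that behaviour, using made-up file names and assuming the default empty model-parallel suffix:

```python
import re

# Hypothetical directory listing; suffix is assumed to be "".
files = [
    "checkpoint_1_500.pt",
    "checkpoint_3_1500.pt",
    "checkpoint_2_1000.pt",
    "checkpoint_last.pt",  # no match, never considered for deletion
]
pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")

entries = []
for i, f in enumerate(files):
    m = pt_regexp.fullmatch(f)
    if m is not None:
        # Sort key is the update count captured by the first group.
        idx = float(m.group(1)) if len(m.groups()) > 0 else i
        entries.append((idx, m.group(0)))
newest_first = [x[1] for x in sorted(entries, reverse=True)]

assert newest_first == [
    "checkpoint_3_1500.pt",
    "checkpoint_2_1000.pt",
    "checkpoint_1_500.pt",
]
# With keep_last_updates = 2, only the oldest checkpoint is removed:
assert newest_first[2:] == ["checkpoint_1_500.pt"]
```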
diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py
index c131a270c..86eb3107d 100644
--- a/metaseq/cli/train.py
+++ b/metaseq/cli/train.py
@@ -8,6 +8,7 @@
 """
 
 import argparse
+from datetime import datetime
 import functools
 import logging
 import math
@@ -18,6 +19,7 @@
 import socket
 import re
 from typing import Dict, Optional, Any, List, Tuple, Callable
+from urllib.parse import urlparse
 import warnings
 
 import numpy as np
@@ -511,37 +513,51 @@ def _checkpoint_add_directory(basename):
     return m[1], f"checkpoint{m[3]}"
 
 
-def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
+def _get_basename(path):
+    res = urlparse(path)
+    if res.scheme:
+        return os.path.basename(res.path)
+    else:
+        return os.path.basename(path)
+
+
+def _get_destination_path(path, destination):
+    """Calculates the destination path, with handling for remote paths."""
+    basename = _get_basename(path)
+    res = urlparse(destination)
+    if res.scheme:
+        new_path = os.path.join(res.path, basename)
+        res = res._replace(path=new_path)
+        return res.geturl()
+    else:
+        return os.path.join(destination, basename)
+
+
+def post_checkpoint_callback(
+    cfg, num_updates, training_finished, filename, files_to_symlink_to
+):
     if cfg.checkpoint.cloud_upload_path is not None:
         if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path:
-            azcopy_logs = filename + "_azcopy_logs"
-            os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10"
-            os.environ["AZCOPY_LOG_LOCATION"] = azcopy_logs
-            os.makedirs(azcopy_logs, exist_ok=True)
-            logger.info(
-                f"preparing to azcopy {filename} to {cfg.checkpoint.cloud_upload_path}; logs in {azcopy_logs}"
+            azcopy_log_dir = os.path.dirname(filename)
+            final_path = _get_destination_path(
+                filename, cfg.checkpoint.cloud_upload_path
             )
-            cmd = [
-                "azcopy",  # TODO(susanz): require azcopy to be installed.
-                "copy",
-                "--cap-mbps",
-                "96.0",
-                filename,
-                cfg.checkpoint.cloud_upload_path,
-            ]
-            res = _run_azcopy(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-            if res.returncode != 0:
-                print("Error: {}, azcopy failed".format(res.returncode))
-                print("Azcopy stdout = {}".format(res.stdout))
-                sys.exit(1)
+            _copy_to_azure(filename, final_path, azcopy_log_dir)
+
             # Delete original checkpoint on local storage
             # TODO make this configurable
-            logger.info(
-                f"Successfully copied {filename} to {cfg.checkpoint.cloud_upload_path}"
-            )
             os.remove(filename)
+
+            # Azure Blob doesn't support symlinks, so make full copies
+            if files_to_symlink_to:
+                for other_checkpoint in files_to_symlink_to:
+                    dest = _get_destination_path(
+                        other_checkpoint, cfg.checkpoint.cloud_upload_path
+                    )
+                    _copy_to_azure(final_path, dest, azcopy_log_dir)
+
         elif cfg.checkpoint.cloud_upload_path.startswith("nfs:"):
-            path, basename = os.path.split(filename)
+            basename = os.path.basename(filename)
             checkpoint_dir, checkpoint_file = _checkpoint_add_directory(basename)
             destination_checkpoints_dir = cfg.checkpoint.cloud_upload_path[4:]
             temporary_checkpoint_file = f"_{checkpoint_file}"
@@ -565,6 +581,9 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
         )
 
         logger.info(f"Renaming {temporary_checkpoint_file} -> {checkpoint_file}")
+        final_path = os.path.join(
+            destination_checkpoints_dir, checkpoint_dir, checkpoint_file
+        )
         # atomic rename _checkpointfile -> checkpointfile
         # this way we know that if present the checkpoint file is complete
         os.rename(
@@ -573,12 +592,20 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
                 checkpoint_dir,
                 temporary_checkpoint_file,
             ),
-            os.path.join(
-                destination_checkpoints_dir, checkpoint_dir, checkpoint_file
-            ),
+            final_path,
         )
         os.remove(filename)
 
+        if files_to_symlink_to:
+            dest_dir = os.path.dirname(final_path)
+            for other_checkpoint in files_to_symlink_to:
+                dest = _get_destination_path(other_checkpoint, dest_dir)
+                if PathManager.islink(dest):
+                    PathManager.rm(dest)
+                assert PathManager.symlink(
+                    final_path, dest
+                ), f"Failed to symlink {final_path} to {dest}"
+
         # Start running evals on uploaded checkpoint
         nfs_evaluation(
             cfg,
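Reviewer note: the NFS branch publishes checkpoints atomically. The file is written under a `_`-prefixed temporary name, renamed into place (atomic on POSIX within one filesystem), and only then aliased via symlinks, so a reader never observes a partially written `checkpoint_last.pt`. A minimal sketch of the same pattern, independent of metaseq and its `PathManager`:

```python
import os


def publish_atomically(data: bytes, final_path: str, aliases=None):
    # Write to a "_"-prefixed temporary name in the destination directory.
    directory = os.path.dirname(final_path)
    tmp_path = os.path.join(directory, "_" + os.path.basename(final_path))
    with open(tmp_path, "wb") as f:
        f.write(data)
    # Atomic on POSIX when source and destination share a filesystem:
    # readers see either the old state or the complete new file.
    os.rename(tmp_path, final_path)
    # Point alias names (e.g. checkpoint_last.pt) at the completed file.
    for alias in aliases or []:
        if os.path.islink(alias):
            os.remove(alias)
        os.symlink(final_path, alias)
```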
@@ -592,13 +619,18 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
             try:
                 # PathManager only supports writing to S3, but this function call
                 # can be replaced with other APIs for copying checkpoints.
-                PathManager.copy_from_local(
-                    filename,
-                    os.path.join(
-                        cfg.checkpoint.cloud_upload_path, os.path.basename(filename)
-                    ),
-                    overwrite=True,
+                final_path = _get_destination_path(
+                    filename, cfg.checkpoint.cloud_upload_path
                 )
+                PathManager.copy_from_local(filename, final_path, overwrite=True)
+
+                # Some non-native PathHandlers don't support symlinks, so default to full copies
+                if files_to_symlink_to:
+                    for other_checkpoint in files_to_symlink_to:
+                        dest = _get_destination_path(
+                            other_checkpoint, cfg.checkpoint.cloud_upload_path
+                        )
+                        PathManager.copy(final_path, dest, overwrite=True)
             except (FileNotFoundError, AssertionError) as e:
                 logger.info(f"could not upload {filename}: {e}")
@@ -664,6 +696,31 @@ def nfs_evaluation(
     )
 
 
+def _copy_to_azure(source, destination, log_dir):
+    # /dir/checkpoint_last.pt -> /dir/checkpoint_last.pt_azcopy_logs_2000-01-01T00_00_00
+    basename = _get_basename(destination)
+    timestamp = datetime.utcnow().isoformat().replace(":", "_")[:-7]
+    azcopy_logs = os.path.join(log_dir, f"{basename}_azcopy_logs_{timestamp}")
+    os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10"
+    os.environ["AZCOPY_LOG_LOCATION"] = azcopy_logs
+    os.makedirs(azcopy_logs, exist_ok=True)
+    logger.info(f"preparing to azcopy {source} to {destination}; logs in {azcopy_logs}")
+    cmd = [
+        "azcopy",  # TODO(susanz): require azcopy to be installed.
+        "copy",
+        "--cap-mbps",
+        "96.0",
+        source,
+        destination,
+    ]
+    res = _run_azcopy(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    if res.returncode != 0:
+        print("Error: {}, azcopy failed".format(res.returncode))
+        print("Azcopy stdout = {}".format(res.stdout))
+        sys.exit(1)
+    logger.info(f"Successfully copied {source} to {destination}")
+
+
 def _run_azcopy(cmd, stdout, stderr):
     return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
diff --git a/metaseq/trainer.py b/metaseq/trainer.py
index faa888d11..eacc46c25 100644
--- a/metaseq/trainer.py
+++ b/metaseq/trainer.py
@@ -421,7 +421,12 @@ def state_dict(self, filename, training_finished=False) -> Dict[str, Dict]:
         return state_dicts
 
     def save_checkpoint(
-        self, filename, extra_state, training_finished=False, async_callback_fn=None
+        self,
+        filename,
+        extra_state,
+        training_finished=False,
+        async_callback_fn=None,
+        files_to_symlink_to=None,
     ):
         """Save all training state in a checkpoint file."""
 
@@ -445,7 +450,7 @@ def save_checkpoint(
         def perform_save():
             try:
                 logger.info(f"Beginning asynchronous torch.save to {filename}")
-                async_callback_fn(filename)
+                async_callback_fn(filename, files_to_symlink_to)
                 logger.info(f"Asynchronous torch.save to {filename} complete.")
             except Exception as e:
                 logger.exception(f"Asynchronous save failed: {e}")
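Reviewer note: `_copy_to_azure` above derives a distinct log directory per transfer by appending a UTC timestamp with `:` replaced (colons are awkward in paths) and the microsecond field sliced off. One caveat worth recording: the `[:-7]` slice relies on `isoformat()` emitting a `.ffffff` microsecond suffix, which it omits only when microseconds are exactly zero; `datetime.utcnow()` produces a nonzero value essentially always. A quick check of the arithmetic:

```python
from datetime import datetime

ts = datetime(2000, 1, 1, 0, 0, 0, microsecond=123456)
stamp = ts.isoformat().replace(":", "_")[:-7]  # drops ".123456"
assert stamp == "2000-01-01T00_00_00"

# So uploading /dir/checkpoint_last.pt writes its azcopy logs under e.g.
# /dir/checkpoint_last.pt_azcopy_logs_2000-01-01T00_00_00
```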
diff --git a/tests/test_cli_train.py b/tests/test_cli_train.py
new file mode 100644
index 000000000..89d2ceac1
--- /dev/null
+++ b/tests/test_cli_train.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import array
+import random
+import os
+import tempfile
+import unittest
+from unittest.mock import patch, MagicMock
+
+from metaseq.cli.train import post_checkpoint_callback, _get_destination_path
+from metaseq.dataclass.configs import MetaseqConfig
+
+
+def create_local_test_file(path, length=4096):
+    value = random.randint(0, 16)
+    with open(path, "wb") as f:
+        array.array("b", [value] * length).tofile(f)
+
+
+class TestPostCheckpointCallback(unittest.TestCase):
+    def test_destination_path(self):
+        self.assertEqual(
+            _get_destination_path("/path/ckpt.pt", "/other"),
+            "/other/ckpt.pt",
+        )
+        self.assertEqual(
+            _get_destination_path(
+                "/path/ckpt.pt", "https://acc.blob.core.windows.net/other?q=1"
+            ),
+            "https://acc.blob.core.windows.net/other/ckpt.pt?q=1",
+        )
+        self.assertEqual(
+            _get_destination_path(
+                "https://acc.blob.core.windows.net/path/ckpt.pt?q=1", "/other"
+            ),
+            "/other/ckpt.pt",
+        )
+
+    def test_nfs_copy(self):
+        with tempfile.TemporaryDirectory() as local_dir, tempfile.TemporaryDirectory() as nfs_dir:
+            checkpoint_path = os.path.join(
+                local_dir, "checkpoint_100-model_part-0-shard0.pt"
+            )
+            create_local_test_file(checkpoint_path)
+
+            cfg = MetaseqConfig()
+            cfg.checkpoint.cloud_upload_path = f"nfs:{nfs_dir}"
+            # Prevent evals
+            cfg.checkpoint.nfs_eval_frequency = 0
+
+            post_checkpoint_callback(
+                cfg=cfg,
+                num_updates=10,
+                training_finished=False,
+                filename=checkpoint_path,
+                files_to_symlink_to=None,
+            )
+
+            expected_path = os.path.join(
+                nfs_dir, "checkpoint_100/checkpoint-model_part-0-shard0.pt"
+            )
+            self.assertTrue(
+                os.path.exists(expected_path), f"File should exist: {expected_path}"
+            )
+
+    def test_nfs_copy_with_symlinks(self):
+        with tempfile.TemporaryDirectory() as local_dir, tempfile.TemporaryDirectory() as nfs_dir:
+            checkpoint_path = os.path.join(local_dir, "checkpoint_10.pt")
+            create_local_test_file(checkpoint_path)
+
+            cfg = MetaseqConfig()
+            cfg.checkpoint.cloud_upload_path = f"nfs:{nfs_dir}"
+            # Prevent evals
+            cfg.checkpoint.nfs_eval_frequency = 0
+
+            post_checkpoint_callback(
+                cfg=cfg,
+                num_updates=10,
+                training_finished=False,
+                filename=checkpoint_path,
+                files_to_symlink_to=[os.path.join(local_dir, "checkpoint_last.pt")],
+            )
+
+            self.assertTrue(
+                os.path.exists(os.path.join(nfs_dir, "checkpoint_10/checkpoint.pt"))
+            )
+            self.assertTrue(
+                os.path.islink(
+                    os.path.join(nfs_dir, "checkpoint_10/checkpoint_last.pt")
+                )
+            )
+
+    def assert_azcopy(self, mock, src, dst):
+        def _match(c):
+            _, args, _ = c
+            cmd = args[0]
+            return cmd[-2] == src and cmd[-1] == dst
+
+        self.assertTrue(any([_match(c) for c in mock.mock_calls]))
+
+    def test_azure_blob_with_symlinks(self):
+        mock_azcopy = MagicMock(return_value=MagicMock(returncode=0))
+        with patch("metaseq.cli.train._run_azcopy", mock_azcopy):
+            with tempfile.TemporaryDirectory() as local_dir:
+                checkpoint_path = os.path.join(local_dir, "checkpoint_10.pt")
+                create_local_test_file(checkpoint_path)
+
+                upload_path = "https://testaccount.blob.core.windows.net/dest?q=1"
+                cfg = MetaseqConfig()
+                cfg.checkpoint.cloud_upload_path = upload_path
+
+                post_checkpoint_callback(
+                    cfg=cfg,
+                    num_updates=10,
+                    training_finished=False,
+                    filename=checkpoint_path,
+                    files_to_symlink_to=[os.path.join(local_dir, "checkpoint_last.pt")],
+                )
+
+                upload_src = "https://testaccount.blob.core.windows.net/dest/checkpoint_10.pt?q=1"
+                upload_dst = "https://testaccount.blob.core.windows.net/dest/checkpoint_last.pt?q=1"
+                self.assert_azcopy(mock_azcopy, checkpoint_path, upload_path)
+                self.assert_azcopy(mock_azcopy, upload_src, upload_dst)
+
+
+if __name__ == "__main__":
+    unittest.main()
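Reviewer note: `assert_azcopy` leans on the fact that each entry of `MagicMock.mock_calls` unpacks as a `(name, args, kwargs)` triple, so `args[0]` recovers the command list handed to `_run_azcopy`, whose last two elements are always the source and destination. A standalone illustration with a made-up command:

```python
from unittest.mock import MagicMock

mock = MagicMock()
mock(
    ["azcopy", "copy", "--cap-mbps", "96.0", "src.pt", "https://example/dst"],
    stdout=None,
    stderr=None,
)

# Each recorded call unpacks to (name, positional args, keyword args),
# exactly what assert_azcopy's `_, args, _ = c` relies on.
name, args, kwargs = mock.mock_calls[0]
cmd = args[0]
assert (cmd[-2], cmd[-1]) == ("src.pt", "https://example/dst")
```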