From f294e507a2aea1904a28b5f3e2ce9b0511e0355e Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 19:32:07 -0400 Subject: [PATCH 01/12] create symlink for different checkpoint names --- metaseq/checkpoint_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index faf2f0320..76dfb7e0a 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -97,6 +97,13 @@ def save_checkpoint( async_callback_fn=async_callback_fn if save_to_NFS else None, ) + if len(checkpoints) > 1: + # Create symlink between identical checkpoints (differing in naming for epoch/update/last). + for other_checkpoint in checkpoints[1:]: + assert PathManager.symlink( + checkpoints[0], other_checkpoint, overwrite=True + ), f"Failed to symlink {checkpoints[0]} to {other_checkpoint}" + write_timer.stop() logger.info( f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) " From 6dc36736b70f89b48cd6b65989cd0009ee9364c3 Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 19:40:20 -0400 Subject: [PATCH 02/12] add back checkpoint deletion logic, configurable via keep_last_updates or keep_last_epochs --- metaseq/checkpoint_utils.py | 44 +++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 76dfb7e0a..d2f86698d 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -110,6 +110,50 @@ def save_checkpoint( f"(writing took {write_timer.sum} seconds)" ) + # See if there's any older checkpoints to delete after saving a new one. + # Only deletes if keep_last_updates > 0 or keep_last_epochs > 0 (default -1 for both). + delete_old_checkpoint_files(cfg, end_of_epoch, suffix) + + +def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str): + if not end_of_epoch and cfg.keep_last_updates > 0: + # remove old checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_updates :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + + if cfg.keep_last_epochs > 0: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_epochs :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + + +# Reference: +# https://github.com/facebookresearch/fairseq/blob/0338cdc3094ca7d29ff4d36d64791f7b4e4b5e6e/fairseq/checkpoint_utils.py#L538 +def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): + """Retrieves all checkpoints found in `path` directory. + Checkpoints are identified by matching filename to the specified pattern. If + the pattern contains groups, the result will be sorted by the first group in + descending order. + """ + pt_regexp = re.compile(pattern) + files = os.listdir(path) + + entries = [] + for i, f in enumerate(files): + m = pt_regexp.fullmatch(f) + if m is not None: + idx = float(m.group(1)) if len(m.groups()) > 0 else i + entries.append((idx, m.group(0))) + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): """ From 7db60a8867f786bdf24c76d0af83ea4f318b3337 Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 19:55:15 -0400 Subject: [PATCH 03/12] fix symlink --- metaseq/checkpoint_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index d2f86698d..452453229 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -100,8 +100,10 @@ def save_checkpoint( if len(checkpoints) > 1: # Create symlink between identical checkpoints (differing in naming for epoch/update/last). for other_checkpoint in checkpoints[1:]: + if PathManager.islink(other_checkpoint): + PathManager.rm(other_checkpoint) assert PathManager.symlink( - checkpoints[0], other_checkpoint, overwrite=True + checkpoints[0], other_checkpoint ), f"Failed to symlink {checkpoints[0]} to {other_checkpoint}" write_timer.stop() From da2b1206cb49f48d81d86787488e34cd1cc45ea7 Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 21:01:48 -0400 Subject: [PATCH 04/12] test --- metaseq/checkpoint_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 452453229..99aa3fb16 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -97,14 +97,14 @@ def save_checkpoint( async_callback_fn=async_callback_fn if save_to_NFS else None, ) - if len(checkpoints) > 1: - # Create symlink between identical checkpoints (differing in naming for epoch/update/last). - for other_checkpoint in checkpoints[1:]: - if PathManager.islink(other_checkpoint): - PathManager.rm(other_checkpoint) - assert PathManager.symlink( - checkpoints[0], other_checkpoint - ), f"Failed to symlink {checkpoints[0]} to {other_checkpoint}" + # if len(checkpoints) > 1: + # # Create symlink between identical checkpoints (differing in naming for epoch/update/last). + # for other_checkpoint in checkpoints[1:]: + # if PathManager.islink(other_checkpoint): + # PathManager.rm(other_checkpoint) + # assert PathManager.symlink( + # checkpoints[0], other_checkpoint + # ), f"Failed to symlink {checkpoints[0]} to {other_checkpoint}" write_timer.stop() logger.info( From 9c92e4ecb3308d485e34406323bf76daa5170cf1 Mon Sep 17 00:00:00 2001 From: suchenzang Date: Sun, 12 Mar 2023 21:30:02 -0400 Subject: [PATCH 05/12] stub in where symlinking/copying might happen --- metaseq/checkpoint_utils.py | 10 +--------- metaseq/cli/train.py | 17 ++++++++++++++++- metaseq/trainer.py | 4 ++-- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py index 99aa3fb16..48b78357f 100644 --- a/metaseq/checkpoint_utils.py +++ b/metaseq/checkpoint_utils.py @@ -95,17 +95,9 @@ def save_checkpoint( extra_state, training_finished=training_finished, async_callback_fn=async_callback_fn if save_to_NFS else None, + files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None, ) - # if len(checkpoints) > 1: - # # Create symlink between identical checkpoints (differing in naming for epoch/update/last). - # for other_checkpoint in checkpoints[1:]: - # if PathManager.islink(other_checkpoint): - # PathManager.rm(other_checkpoint) - # assert PathManager.symlink( - # checkpoints[0], other_checkpoint - # ), f"Failed to symlink {checkpoints[0]} to {other_checkpoint}" - write_timer.stop() logger.info( f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) " diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index c131a270c..b7b8a3ba9 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -511,7 +511,7 @@ def _checkpoint_add_directory(basename): return m[1], f"checkpoint{m[3]}" -def post_checkpoint_callback(cfg, num_updates, training_finished, filename): +def post_checkpoint_callback(cfg, num_updates, training_finished, filename, files_to_symlink_to): if cfg.checkpoint.cloud_upload_path is not None: if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path: azcopy_logs = filename + "_azcopy_logs" @@ -540,6 +540,9 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename): f"Successfully copied {filename} to {cfg.checkpoint.cloud_upload_path}" ) os.remove(filename) + + # TODO[Susan]: Add symlink logic here? Check what cloud_upload_path is being used for Uriel's jobs. + elif cfg.checkpoint.cloud_upload_path.startswith("nfs:"): path, basename = os.path.split(filename) checkpoint_dir, checkpoint_file = _checkpoint_add_directory(basename) @@ -579,6 +582,8 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename): ) os.remove(filename) + # TODO[Susan]: Add symlink logic here. + # Start running evals on uploaded checkpoint nfs_evaluation( cfg, @@ -602,6 +607,16 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename): except (FileNotFoundError, AssertionError) as e: logger.info(f"could not upload {filename}: {e}") + # TODO[Susan]: Add symlink logic here. + + # if files_to_symlink_to is not None and len(files_to_symlink_to) > 1: + # for other_checkpoint in files_to_symlink_to: + # if PathManager.islink(other_checkpoint): + # PathManager.rm(other_checkpoint) + # assert PathManager.symlink( + # filename, other_checkpoint + # ), f"Failed to symlink {filename} to {other_checkpoint}" + def nfs_evaluation( cfg, num_updates, training_finished, checkpoint_dir, destination_checkpoints_dir diff --git a/metaseq/trainer.py b/metaseq/trainer.py index faa888d11..d28b3b5cb 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -421,7 +421,7 @@ def state_dict(self, filename, training_finished=False) -> Dict[str, Dict]: return state_dicts def save_checkpoint( - self, filename, extra_state, training_finished=False, async_callback_fn=None + self, filename, extra_state, training_finished=False, async_callback_fn=None, files_to_symlink_to=None ): """Save all training state in a checkpoint file.""" @@ -445,7 +445,7 @@ def save_checkpoint( def perform_save(): try: logger.info(f"Beginning asynchronous torch.save to {filename}") - async_callback_fn(filename) + async_callback_fn(filename, files_to_symlink_to) logger.info(f"Asynchronous torch.save to {filename} complete.") except Exception as e: logger.exception(f"Asynchronous save failed: {e}") From 668357751c2920430ba4fb24d75b144cebb97b71 Mon Sep 17 00:00:00 2001 From: davides Date: Wed, 29 Mar 2023 08:29:40 -0700 Subject: [PATCH 06/12] Add symlinks when checkpointing --- metaseq/cli/train.py | 115 ++++++++++++++++++++++++---------------- tests/test_cli_train.py | 85 +++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 46 deletions(-) create mode 100644 tests/test_cli_train.py diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index b7b8a3ba9..5aed7bb72 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -8,6 +8,7 @@ """ import argparse +from datetime import datetime import functools import logging import math @@ -511,40 +512,33 @@ def _checkpoint_add_directory(basename): return m[1], f"checkpoint{m[3]}" -def post_checkpoint_callback(cfg, num_updates, training_finished, filename, files_to_symlink_to): +def _get_destination_path(path, destination): + return os.path.join(destination, os.path.basename(path)) + + +def post_checkpoint_callback( + cfg, num_updates, training_finished, filename, files_to_symlink_to +): if cfg.checkpoint.cloud_upload_path is not None: if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path: - azcopy_logs = filename + "_azcopy_logs" - os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10" - os.environ["AZCOPY_LOG_LOCATION"] = azcopy_logs - os.makedirs(azcopy_logs, exist_ok=True) - logger.info( - f"preparing to azcopy {filename} to {cfg.checkpoint.cloud_upload_path}; logs in {azcopy_logs}" - ) - cmd = [ - "azcopy", # TODO(susanz): require azcopy to be installed. - "copy", - "--cap-mbps", - "96.0", - filename, - cfg.checkpoint.cloud_upload_path, - ] - res = _run_azcopy(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - if res.returncode != 0: - print("Error: {}, azcopy failed".format(res.returncode)) - print("Azcopy stdout = {}".format(res.stdout)) - sys.exit(1) + azcopy_log_dir = os.path.dirname(filename) + _copy_to_azure(filename, cfg.checkpoint.cloud_upload_path, azcopy_log_dir) + # Delete original checkpoint on local storage # TODO make this configurable - logger.info( - f"Successfully copied {filename} to {cfg.checkpoint.cloud_upload_path}" - ) os.remove(filename) - # TODO[Susan]: Add symlink logic here? Check what cloud_upload_path is being used for Uriel's jobs. + # Azure Blob doesn't support symlinks so make full copies + source = _get_destination_path(filename, cfg.checkpoint.cloud_upload_path) + if files_to_symlink_to: + for other_checkpoint in files_to_symlink_to: + dest = _get_destination_path( + other_checkpoint, cfg.checkpoint.cloud_upload_path + ) + _copy_to_azure(source, dest, azcopy_log_dir) elif cfg.checkpoint.cloud_upload_path.startswith("nfs:"): - path, basename = os.path.split(filename) + basename = os.path.basename(filename) checkpoint_dir, checkpoint_file = _checkpoint_add_directory(basename) destination_checkpoints_dir = cfg.checkpoint.cloud_upload_path[4:] temporary_checkpoint_file = f"_{checkpoint_file}" @@ -568,6 +562,9 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename, file ) logger.info(f"Renaming {temporary_checkpoint_file} -> {checkpoint_file}") + final_path = os.path.join( + destination_checkpoints_dir, checkpoint_dir, checkpoint_file + ) # atomic rename _checkpointfile -> checkpointfile # this way we know that if present the checkpoint file is complete os.rename( @@ -576,13 +573,19 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename, file checkpoint_dir, temporary_checkpoint_file, ), - os.path.join( - destination_checkpoints_dir, checkpoint_dir, checkpoint_file - ), + final_path, ) os.remove(filename) - # TODO[Susan]: Add symlink logic here. + if files_to_symlink_to: + dest_dir = os.path.dirname(final_path) + for other_checkpoint in files_to_symlink_to: + dest = _get_destination_path(other_checkpoint, dest_dir) + if PathManager.islink(dest): + PathManager.rm(dest) + assert PathManager.symlink( + final_path, dest + ), f"Failed to symlink {final_path} to {dest}" # Start running evals on uploaded checkpoint nfs_evaluation( @@ -597,26 +600,21 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename, file try: # PathManager only supports writing to S3, but this function call # can be replaced with other APIs for copying checkpoints. - PathManager.copy_from_local( - filename, - os.path.join( - cfg.checkpoint.cloud_upload_path, os.path.basename(filename) - ), - overwrite=True, + final_path = _get_destination_path( + filename, cfg.checkpoint.cloud_upload_path ) + PathManager.copy_from_local(filename, final_path, overwrite=True) + + # Some non-native PathHandlers don't support symlinks so default to full copies + if files_to_symlink_to: + for other_checkpoint in files_to_symlink_to: + dest = _get_destination_path( + other_checkpoint, cfg.checkpoint.cloud_upload_path + ) + PathManager.copy(final_path, dest, overwrite=True) except (FileNotFoundError, AssertionError) as e: logger.info(f"could not upload {filename}: {e}") - # TODO[Susan]: Add symlink logic here. - - # if files_to_symlink_to is not None and len(files_to_symlink_to) > 1: - # for other_checkpoint in files_to_symlink_to: - # if PathManager.islink(other_checkpoint): - # PathManager.rm(other_checkpoint) - # assert PathManager.symlink( - # filename, other_checkpoint - # ), f"Failed to symlink {filename} to {other_checkpoint}" - def nfs_evaluation( cfg, num_updates, training_finished, checkpoint_dir, destination_checkpoints_dir @@ -679,6 +677,31 @@ def nfs_evaluation( ) +def _copy_to_azure(source, destination, log_dir): + # /dir/checkpoint_last.pt -> /dir/checkpoint_last.pt_azcopy_logs_2000-01-01T00_00_00 + basename = os.path.basename(destination) + timestamp = datetime.utcnow().isoformat().replace(":", "_")[:-7] + azcopy_logs = os.path.join(log_dir, f"{basename}_azcopy_logs_{timestamp}") + os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10" + os.environ["AZCOPY_LOG_LOCATION"] = azcopy_logs + os.makedirs(azcopy_logs, exist_ok=True) + logger.info(f"preparing to azcopy {source} to {destination}; logs in {azcopy_logs}") + cmd = [ + "azcopy", # TODO(susanz): require azcopy to be installed. + "copy", + "--cap-mbps", + "96.0", + source, + destination, + ] + res = _run_azcopy(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if res.returncode != 0: + print("Error: {}, azcopy failed".format(res.returncode)) + print("Azcopy stdout = {}".format(res.stdout)) + sys.exit(1) + logger.info(f"Successfully copied {source} to {destination}") + + def _run_azcopy(cmd, stdout, stderr): return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) diff --git a/tests/test_cli_train.py b/tests/test_cli_train.py new file mode 100644 index 000000000..27e72a94f --- /dev/null +++ b/tests/test_cli_train.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import array +import random +import os +import tempfile +import unittest + +from metaseq.cli.train import post_checkpoint_callback +from metaseq.dataclass.configs import MetaseqConfig + + +def create_local_test_file(path, length=4096): + value = random.randint(0, 16) + with open(path, "wb") as f: + array.array("b", [value] * length).tofile(f) + + +class TestPostCheckpointCallback(unittest.TestCase): + def test_nfs_copy(self): + with ( + tempfile.TemporaryDirectory() as local_dir, + tempfile.TemporaryDirectory() as nfs_dir, + ): + checkpoint_path = os.path.join( + local_dir, "checkpoint_100-model_part-0-shard0.pt" + ) + create_local_test_file(checkpoint_path) + + cfg = MetaseqConfig() + cfg.checkpoint.cloud_upload_path = f"nfs:{nfs_dir}" + # Prevent evals + cfg.checkpoint.nfs_eval_frequency = 0 + + post_checkpoint_callback( + cfg=cfg, + num_updates=10, + training_finished=False, + filename=checkpoint_path, + files_to_symlink_to=None, + ) + + expected_path = os.path.join( + nfs_dir, "checkpoint_100/checkpoint-model_part-0-shard0.pt" + ) + self.assertTrue( + os.path.exists(expected_path), f"File should exist: {expected_path}" + ) + + def test_nfs_copy_with_symlinks(self): + with ( + tempfile.TemporaryDirectory() as local_dir, + tempfile.TemporaryDirectory() as nfs_dir, + ): + checkpoint_path = os.path.join(local_dir, "checkpoint_10.pt") + create_local_test_file(checkpoint_path) + + cfg = MetaseqConfig() + cfg.checkpoint.cloud_upload_path = f"nfs:{nfs_dir}" + # Prevent evals + cfg.checkpoint.nfs_eval_frequency = 0 + + post_checkpoint_callback( + cfg=cfg, + num_updates=10, + training_finished=False, + filename=checkpoint_path, + files_to_symlink_to=[os.path.join(local_dir, "checkpoint_last.pt")], + ) + + self.assertTrue( + os.path.exists(os.path.join(nfs_dir, "checkpoint_10/checkpoint.pt")) + ) + self.assertTrue( + os.path.islink( + os.path.join(nfs_dir, "checkpoint_10/checkpoint_last.pt") + ) + ) + + +if __name__ == "__main__": + unittest.main() From 007be47daecd1f52c5ec5ae2d928802c88ffc91c Mon Sep 17 00:00:00 2001 From: davides Date: Wed, 29 Mar 2023 11:53:14 -0700 Subject: [PATCH 07/12] Fix up handling of remote paths --- metaseq/cli/train.py | 21 ++++++++++++++-- tests/test_cli_train.py | 53 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 5aed7bb72..288879420 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -19,6 +19,7 @@ import socket import re from typing import Dict, Optional, Any, List, Tuple, Callable +from urllib.parse import urlparse import warnings import numpy as np @@ -512,8 +513,24 @@ def _checkpoint_add_directory(basename): return m[1], f"checkpoint{m[3]}" +def _get_basename(path): + res = urlparse(path) + if res.scheme: + return os.path.basename(res.path) + else: + return os.path.basename(path) + + def _get_destination_path(path, destination): - return os.path.join(destination, os.path.basename(path)) + """Calculates the destination path with handling for remote paths.""" + basename = _get_basename(path) + res = urlparse(destination) + if res.scheme: + new_path = os.path.join(res.path, basename) + res = res._replace(path=new_path) + return res.geturl() + else: + return os.path.join(destination, basename) def post_checkpoint_callback( @@ -679,7 +696,7 @@ def nfs_evaluation( def _copy_to_azure(source, destination, log_dir): # /dir/checkpoint_last.pt -> /dir/checkpoint_last.pt_azcopy_logs_2000-01-01T00_00_00 - basename = os.path.basename(destination) + basename = _get_basename(destination) timestamp = datetime.utcnow().isoformat().replace(":", "_")[:-7] azcopy_logs = os.path.join(log_dir, f"{basename}_azcopy_logs_{timestamp}") os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10" diff --git a/tests/test_cli_train.py b/tests/test_cli_train.py index 27e72a94f..e3a7af813 100644 --- a/tests/test_cli_train.py +++ b/tests/test_cli_train.py @@ -8,8 +8,9 @@ import os import tempfile import unittest +from unittest.mock import patch, MagicMock -from metaseq.cli.train import post_checkpoint_callback +from metaseq.cli.train import post_checkpoint_callback, _get_destination_path from metaseq.dataclass.configs import MetaseqConfig @@ -20,6 +21,24 @@ def create_local_test_file(path, length=4096): class TestPostCheckpointCallback(unittest.TestCase): + def test_destination_path(self): + self.assertEqual( + _get_destination_path("/path/ckpt.pt", "/other"), + "/other/ckpt.pt", + ) + self.assertEqual( + _get_destination_path( + "/path/ckpt.pt", "https://acc.blob.core.windows.net/other?q=1" + ), + "https://acc.blob.core.windows.net/other/ckpt.pt?q=1", + ) + self.assertEqual( + _get_destination_path( + "https://acc.blob.core.windows.net/path/ckpt.pt?q=1", "/other" + ), + "/other/ckpt.pt", + ) + def test_nfs_copy(self): with ( tempfile.TemporaryDirectory() as local_dir, @@ -80,6 +99,38 @@ def test_nfs_copy_with_symlinks(self): ) ) + def assert_azcopy(self, mock, src, dst): + def _match(c): + _, args, _ = c + cmd = args[0] + return cmd[-2] == src and cmd[-1] == dst + + self.assertTrue(any([_match(c) for c in mock.mock_calls])) + + def test_azure_blob_with_symlinks(self): + mock_azcopy = MagicMock(return_value=MagicMock(returncode=0)) + with patch("metaseq.cli.train._run_azcopy", mock_azcopy): + with tempfile.TemporaryDirectory() as local_dir: + checkpoint_path = os.path.join(local_dir, "checkpoint_10.pt") + create_local_test_file(checkpoint_path) + + upload_path = "https://testaccount.blob.core.windows.net/dest?q=1" + cfg = MetaseqConfig() + cfg.checkpoint.cloud_upload_path = upload_path + + post_checkpoint_callback( + cfg=cfg, + num_updates=10, + training_finished=False, + filename=checkpoint_path, + files_to_symlink_to=[os.path.join(local_dir, "checkpoint_last.pt")], + ) + + upload_src = "https://testaccount.blob.core.windows.net/dest/checkpoint_10.pt?q=1" + upload_dst = "https://testaccount.blob.core.windows.net/dest/checkpoint_last.pt?q=1" + self.assert_azcopy(mock_azcopy, checkpoint_path, upload_path) + self.assert_azcopy(mock_azcopy, upload_src, upload_dst) + if __name__ == "__main__": unittest.main() From c76c4ecbe24ac44d134a659eef69ac38ec2a714c Mon Sep 17 00:00:00 2001 From: davides Date: Wed, 29 Mar 2023 12:23:49 -0700 Subject: [PATCH 08/12] More fixes for azcopy; fix broken gpu_tests; run linters --- gpu_tests/test_checkpoint_saving.py | 7 ++++++- metaseq/cli/train.py | 8 +++++--- metaseq/trainer.py | 7 ++++++- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/gpu_tests/test_checkpoint_saving.py b/gpu_tests/test_checkpoint_saving.py index e2803920a..c84e78ef1 100644 --- a/gpu_tests/test_checkpoint_saving.py +++ b/gpu_tests/test_checkpoint_saving.py @@ -254,7 +254,12 @@ def subprocess_run_mock(cmd, stdout, stderr, events): def save_checkpoint_mock( - self, filename, extra_state, training_finished=False, async_callback_fn=None + self, + filename, + extra_state, + training_finished=False, + async_callback_fn=None, + files_to_symlink_to=None, ): logger = logging.getLogger("metaseq.trainer") """Save all training state in a checkpoint file.""" diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 288879420..86eb3107d 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -539,20 +539,22 @@ def post_checkpoint_callback( if cfg.checkpoint.cloud_upload_path is not None: if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path: azcopy_log_dir = os.path.dirname(filename) - _copy_to_azure(filename, cfg.checkpoint.cloud_upload_path, azcopy_log_dir) + final_path = _get_destination_path( + filename, cfg.checkpoint.cloud_upload_path + ) + _copy_to_azure(filename, final_path, azcopy_log_dir) # Delete original checkpoint on local storage # TODO make this configurable os.remove(filename) # Azure Blob doesn't support symlinks so make full copies - source = _get_destination_path(filename, cfg.checkpoint.cloud_upload_path) if files_to_symlink_to: for other_checkpoint in files_to_symlink_to: dest = _get_destination_path( other_checkpoint, cfg.checkpoint.cloud_upload_path ) - _copy_to_azure(source, dest, azcopy_log_dir) + _copy_to_azure(final_path, dest, azcopy_log_dir) elif cfg.checkpoint.cloud_upload_path.startswith("nfs:"): basename = os.path.basename(filename) diff --git a/metaseq/trainer.py b/metaseq/trainer.py index d28b3b5cb..eacc46c25 100644 --- a/metaseq/trainer.py +++ b/metaseq/trainer.py @@ -421,7 +421,12 @@ def state_dict(self, filename, training_finished=False) -> Dict[str, Dict]: return state_dicts def save_checkpoint( - self, filename, extra_state, training_finished=False, async_callback_fn=None, files_to_symlink_to=None + self, + filename, + extra_state, + training_finished=False, + async_callback_fn=None, + files_to_symlink_to=None, ): """Save all training state in a checkpoint file.""" From 925d20ba5885bce8e01e5c9336c4b4bd747b978c Mon Sep 17 00:00:00 2001 From: davides Date: Thu, 30 Mar 2023 04:47:39 -0700 Subject: [PATCH 09/12] More linter fixes --- tests/test_cli_train.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_cli_train.py b/tests/test_cli_train.py index e3a7af813..b8eac55f8 100644 --- a/tests/test_cli_train.py +++ b/tests/test_cli_train.py @@ -40,10 +40,7 @@ def test_destination_path(self): ) def test_nfs_copy(self): - with ( - tempfile.TemporaryDirectory() as local_dir, - tempfile.TemporaryDirectory() as nfs_dir, - ): + with tempfile.TemporaryDirectory() as local_dir, tempfile.TemporaryDirectory() as nfs_dir: checkpoint_path = os.path.join( local_dir, "checkpoint_100-model_part-0-shard0.pt" ) From 2af88072f9982d5c19b0633c2659ef57519dd594 Mon Sep 17 00:00:00 2001 From: David Esiobu Date: Thu, 30 Mar 2023 10:30:31 -0700 Subject: [PATCH 10/12] Update gpu_tests to account for new symlink behavior --- gpu_tests/test_checkpoint_saving.py | 35 ++++++++++++++++------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/gpu_tests/test_checkpoint_saving.py b/gpu_tests/test_checkpoint_saving.py index c84e78ef1..bec481a0f 100644 --- a/gpu_tests/test_checkpoint_saving.py +++ b/gpu_tests/test_checkpoint_saving.py @@ -12,6 +12,7 @@ from functools import partial, partialmethod import unittest from unittest.mock import patch, Mock, MagicMock +from urllib.parse import urlparse import torch from metaseq.dataclass.configs import DistributedTrainingConfig from metaseq.launcher.opt_baselines import cli_main as sweep_cli_main @@ -81,15 +82,15 @@ def test_checkpoint_saving_and_uploading(self): self.assertEqual(file_names_saved_azure, expected_file_names) for worker_cmd in upload_events: self.assertEqual( - worker_cmd["command"], + worker_cmd["command"][:4], [ "azcopy", "copy", "--cap-mbps", "96.0", - "https://myaccount.blob.core.windows.net/test", ], ) + self.assertTrue(os.path.basename(worker_cmd["command"][-1]) in expected_file_names) self.assertEqual( worker_cmd["checkpoint_model_dir"], common_checkpoint_model_dir ) @@ -235,18 +236,22 @@ def download_checkpoint_mock(blob_url, checkpoint_path, suffix, events): def subprocess_run_mock(cmd, stdout, stderr, events): - # replaces subprocess.run azcopy command that uploads to azure - _, checkpoint_dir, checkpoint_model_dir, checkpoint_file = cmd[4].split("/") - events.append( - { - "type": "upload", - "command": cmd[:4] + cmd[5:], - "checkpoint_dir": checkpoint_dir, - "checkpoint_model_dir": checkpoint_model_dir, - "checkpoint_file": checkpoint_file, - "file_saved_locally": os.path.exists(cmd[4]), - } - ) + source = cmd[4] + dest = cmd[5] + + # Only interested in local -> remote transfers (not asserting remote copies/aliases) + if urlparse(source).scheme == "" and urlparse(dest).scheme == "https": + _, checkpoint_dir, checkpoint_model_dir, checkpoint_file = source.split("/") + events.append( + { + "type": "upload", + "command": cmd[:4] + cmd[5:], + "checkpoint_dir": checkpoint_dir, + "checkpoint_model_dir": checkpoint_model_dir, + "checkpoint_file": checkpoint_file, + "file_saved_locally": os.path.exists(source), + } + ) res = Mock() res.returncode = 0 @@ -283,7 +288,7 @@ def perform_save(): logger.info(f"Beginning asynchronous torch.save to {filename}") torch.save(state_dict, filename) if async_callback_fn is not None: - async_callback_fn(filename) + async_callback_fn(filename, files_to_symlink_to) logger.info(f"Asynchronous torch.save to {filename} complete.") except Exception as e: logger.exception(f"Asynchronous save failed: {e}") From b417ced533a2ea509a72e42e0a12b62fa5d7a3e9 Mon Sep 17 00:00:00 2001 From: David Esiobu Date: Thu, 30 Mar 2023 10:31:31 -0700 Subject: [PATCH 11/12] Lint --- gpu_tests/test_checkpoint_saving.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gpu_tests/test_checkpoint_saving.py b/gpu_tests/test_checkpoint_saving.py index bec481a0f..94a1ecc41 100644 --- a/gpu_tests/test_checkpoint_saving.py +++ b/gpu_tests/test_checkpoint_saving.py @@ -90,7 +90,9 @@ def test_checkpoint_saving_and_uploading(self): "96.0", ], ) - self.assertTrue(os.path.basename(worker_cmd["command"][-1]) in expected_file_names) + self.assertTrue( + os.path.basename(worker_cmd["command"][-1]) in expected_file_names + ) self.assertEqual( worker_cmd["checkpoint_model_dir"], common_checkpoint_model_dir ) From 32b32b0791ac492f0e56464f98d6ca42dcffa827 Mon Sep 17 00:00:00 2001 From: David Esiobu Date: Thu, 30 Mar 2023 10:59:41 -0700 Subject: [PATCH 12/12] Lint --- tests/test_cli_train.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_cli_train.py b/tests/test_cli_train.py index b8eac55f8..89d2ceac1 100644 --- a/tests/test_cli_train.py +++ b/tests/test_cli_train.py @@ -67,10 +67,7 @@ def test_nfs_copy(self): ) def test_nfs_copy_with_symlinks(self): - with ( - tempfile.TemporaryDirectory() as local_dir, - tempfile.TemporaryDirectory() as nfs_dir, - ): + with tempfile.TemporaryDirectory() as local_dir, tempfile.TemporaryDirectory() as nfs_dir: checkpoint_path = os.path.join(local_dir, "checkpoint_10.pt") create_local_test_file(checkpoint_path)