44 changes: 28 additions & 16 deletions gpu_tests/test_checkpoint_saving.py
@@ -12,6 +12,7 @@
from functools import partial, partialmethod
import unittest
from unittest.mock import patch, Mock, MagicMock
+from urllib.parse import urlparse
import torch
from metaseq.dataclass.configs import DistributedTrainingConfig
from metaseq.launcher.opt_baselines import cli_main as sweep_cli_main
@@ -81,15 +82,17 @@ def test_checkpoint_saving_and_uploading(self):
self.assertEqual(file_names_saved_azure, expected_file_names)
for worker_cmd in upload_events:
self.assertEqual(
worker_cmd["command"],
worker_cmd["command"][:4],
[
"azcopy",
"copy",
"--cap-mbps",
"96.0",
"https://myaccount.blob.core.windows.net/test",
],
)
+            self.assertTrue(
+                os.path.basename(worker_cmd["command"][-1]) in expected_file_names
+            )
self.assertEqual(
worker_cmd["checkpoint_model_dir"], common_checkpoint_model_dir
)
@@ -235,26 +238,35 @@ def download_checkpoint_mock(blob_url, checkpoint_path, suffix, events):


def subprocess_run_mock(cmd, stdout, stderr, events):
# replaces subprocess.run azcopy command that uploads to azure
-    _, checkpoint_dir, checkpoint_model_dir, checkpoint_file = cmd[4].split("/")
-    events.append(
-        {
-            "type": "upload",
-            "command": cmd[:4] + cmd[5:],
-            "checkpoint_dir": checkpoint_dir,
-            "checkpoint_model_dir": checkpoint_model_dir,
-            "checkpoint_file": checkpoint_file,
-            "file_saved_locally": os.path.exists(cmd[4]),
-        }
-    )
+    source = cmd[4]
+    dest = cmd[5]
+
+    # Only interested in local -> remote transfers (not asserting remote copies/aliases)
+    if urlparse(source).scheme == "" and urlparse(dest).scheme == "https":
+        _, checkpoint_dir, checkpoint_model_dir, checkpoint_file = source.split("/")
+        events.append(
+            {
+                "type": "upload",
+                "command": cmd[:4] + cmd[5:],
+                "checkpoint_dir": checkpoint_dir,
+                "checkpoint_model_dir": checkpoint_model_dir,
+                "checkpoint_file": checkpoint_file,
+                "file_saved_locally": os.path.exists(source),
+            }
+        )

res = Mock()
res.returncode = 0
return res


def save_checkpoint_mock(
-    self, filename, extra_state, training_finished=False, async_callback_fn=None
+    self,
+    filename,
+    extra_state,
+    training_finished=False,
+    async_callback_fn=None,
+    files_to_symlink_to=None,
):
logger = logging.getLogger("metaseq.trainer")
"""Save all training state in a checkpoint file."""
@@ -278,7 +290,7 @@ def perform_save():
logger.info(f"Beginning asynchronous torch.save to {filename}")
torch.save(state_dict, filename)
if async_callback_fn is not None:
-                async_callback_fn(filename)
+                async_callback_fn(filename, files_to_symlink_to)
logger.info(f"Asynchronous torch.save to {filename} complete.")
except Exception as e:
logger.exception(f"Asynchronous save failed: {e}")
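
The rewritten `subprocess_run_mock` only records local-to-remote uploads, since the new Azure code path also issues remote-to-remote `azcopy` copies in place of symlinks. A minimal standalone sketch of the scheme check it relies on (paths are illustrative):

```python
from urllib.parse import urlparse

def is_local_to_remote(source: str, dest: str) -> bool:
    # A bare filesystem path parses with an empty scheme; Azure Blob URLs use https.
    return urlparse(source).scheme == "" and urlparse(dest).scheme == "https"

# Recorded: a local checkpoint uploaded to blob storage.
assert is_local_to_remote(
    "checkpoints/model/checkpoint_last.pt",
    "https://myaccount.blob.core.windows.net/test/checkpoint_last.pt",
)
# Ignored: a remote-to-remote copy made in place of a symlink.
assert not is_local_to_remote(
    "https://myaccount.blob.core.windows.net/test/checkpoint_last.pt",
    "https://myaccount.blob.core.windows.net/test/checkpoint_best.pt",
)
```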
45 changes: 45 additions & 0 deletions metaseq/checkpoint_utils.py
@@ -95,6 +95,7 @@ def save_checkpoint(
extra_state,
training_finished=training_finished,
async_callback_fn=async_callback_fn if save_to_NFS else None,
+            files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None,
)

write_timer.stop()
@@ -103,6 +104,50 @@
f"(writing took {write_timer.sum} seconds)"
)

+    # See if there are any older checkpoints to delete after saving a new one.
+    # Only deletes if keep_last_updates > 0 or keep_last_epochs > 0 (default -1 for both).
+    delete_old_checkpoint_files(cfg, end_of_epoch, suffix)
+
+
+def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str):
+    if not end_of_epoch and cfg.keep_last_updates > 0:
+        # remove old checkpoints; checkpoints are sorted in descending order
+        checkpoints = checkpoint_paths(
+            cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
+        )
+        for old_chk in checkpoints[cfg.keep_last_updates :]:
+            if os.path.lexists(old_chk):
+                os.remove(old_chk)
+
+    if cfg.keep_last_epochs > 0:
+        # remove old epoch checkpoints; checkpoints are sorted in descending order
+        checkpoints = checkpoint_paths(
+            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
+        )
+        for old_chk in checkpoints[cfg.keep_last_epochs :]:
+            if os.path.lexists(old_chk):
+                os.remove(old_chk)
+
+
+# Reference:
+# https://github.com/facebookresearch/fairseq/blob/0338cdc3094ca7d29ff4d36d64791f7b4e4b5e6e/fairseq/checkpoint_utils.py#L538
+def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
+    """Retrieves all checkpoints found in `path` directory.
+    Checkpoints are identified by matching filename to the specified pattern. If
+    the pattern contains groups, the result will be sorted by the first group in
+    descending order.
+    """
+    pt_regexp = re.compile(pattern)
+    files = os.listdir(path)
+
+    entries = []
+    for i, f in enumerate(files):
+        m = pt_regexp.fullmatch(f)
+        if m is not None:
+            idx = float(m.group(1)) if len(m.groups()) > 0 else i
+            entries.append((idx, m.group(0)))
+    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
+

def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
"""
123 changes: 90 additions & 33 deletions metaseq/cli/train.py
@@ -8,6 +8,7 @@
"""

import argparse
+from datetime import datetime
import functools
import logging
import math
@@ -18,6 +19,7 @@
import socket
import re
from typing import Dict, Optional, Any, List, Tuple, Callable
+from urllib.parse import urlparse
import warnings

import numpy as np
@@ -511,37 +513,51 @@ def _checkpoint_add_directory(basename):
return m[1], f"checkpoint{m[3]}"


-def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
+def _get_basename(path):
+    res = urlparse(path)
+    if res.scheme:
+        return os.path.basename(res.path)
+    else:
+        return os.path.basename(path)
+
+
+def _get_destination_path(path, destination):
+    """Calculates the destination path with handling for remote paths."""
+    basename = _get_basename(path)
+    res = urlparse(destination)
+    if res.scheme:
+        new_path = os.path.join(res.path, basename)
+        res = res._replace(path=new_path)
+        return res.geturl()
+    else:
+        return os.path.join(destination, basename)
+
+
+def post_checkpoint_callback(
+    cfg, num_updates, training_finished, filename, files_to_symlink_to
+):
if cfg.checkpoint.cloud_upload_path is not None:
if "blob.core.windows.net" in cfg.checkpoint.cloud_upload_path:
-            azcopy_logs = filename + "_azcopy_logs"
-            os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10"
-            os.environ["AZCOPY_LOG_LOCATION"] = azcopy_logs
-            os.makedirs(azcopy_logs, exist_ok=True)
-            logger.info(
-                f"preparing to azcopy {filename} to {cfg.checkpoint.cloud_upload_path}; logs in {azcopy_logs}"
+            azcopy_log_dir = os.path.dirname(filename)
+            final_path = _get_destination_path(
+                filename, cfg.checkpoint.cloud_upload_path
)
-            cmd = [
-                "azcopy",  # TODO(susanz): require azcopy to be installed.
-                "copy",
-                "--cap-mbps",
-                "96.0",
-                filename,
-                cfg.checkpoint.cloud_upload_path,
-            ]
-            res = _run_azcopy(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-            if res.returncode != 0:
-                print("Error: {}, azcopy failed".format(res.returncode))
-                print("Azcopy stdout = {}".format(res.stdout))
-                sys.exit(1)
+            _copy_to_azure(filename, final_path, azcopy_log_dir)

# Delete original checkpoint on local storage
# TODO make this configurable
-            logger.info(
-                f"Successfully copied {filename} to {cfg.checkpoint.cloud_upload_path}"
-            )
os.remove(filename)

+            # Azure Blob doesn't support symlinks, so make full copies
+            if files_to_symlink_to:
+                for other_checkpoint in files_to_symlink_to:
+                    dest = _get_destination_path(
+                        other_checkpoint, cfg.checkpoint.cloud_upload_path
+                    )
+                    _copy_to_azure(final_path, dest, azcopy_log_dir)

elif cfg.checkpoint.cloud_upload_path.startswith("nfs:"):
-            path, basename = os.path.split(filename)
+            basename = os.path.basename(filename)
checkpoint_dir, checkpoint_file = _checkpoint_add_directory(basename)
destination_checkpoints_dir = cfg.checkpoint.cloud_upload_path[4:]
temporary_checkpoint_file = f"_{checkpoint_file}"
@@ -565,6 +581,9 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
)

logger.info(f"Renaming {temporary_checkpoint_file} -> {checkpoint_file}")
+            final_path = os.path.join(
+                destination_checkpoints_dir, checkpoint_dir, checkpoint_file
+            )
# atomic rename _checkpointfile -> checkpointfile
# this way we know that if present the checkpoint file is complete
os.rename(
@@ -573,12 +592,20 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
checkpoint_dir,
temporary_checkpoint_file,
),
-                os.path.join(
-                    destination_checkpoints_dir, checkpoint_dir, checkpoint_file
-                ),
+                final_path,
)
os.remove(filename)

+            if files_to_symlink_to:
+                dest_dir = os.path.dirname(final_path)
+                for other_checkpoint in files_to_symlink_to:
+                    dest = _get_destination_path(other_checkpoint, dest_dir)
+                    if PathManager.islink(dest):
+                        PathManager.rm(dest)
+                    assert PathManager.symlink(
+                        final_path, dest
+                    ), f"Failed to symlink {final_path} to {dest}"

# Start running evals on uploaded checkpoint
nfs_evaluation(
cfg,
@@ -592,13 +619,18 @@ def post_checkpoint_callback(cfg, num_updates, training_finished, filename):
try:
# PathManager only supports writing to S3, but this function call
# can be replaced with other APIs for copying checkpoints.
-                PathManager.copy_from_local(
-                    filename,
-                    os.path.join(
-                        cfg.checkpoint.cloud_upload_path, os.path.basename(filename)
-                    ),
-                    overwrite=True,
+                final_path = _get_destination_path(
+                    filename, cfg.checkpoint.cloud_upload_path
)
+                PathManager.copy_from_local(filename, final_path, overwrite=True)
+
+                # Some non-native PathHandlers don't support symlinks, so default to full copies
+                if files_to_symlink_to:
+                    for other_checkpoint in files_to_symlink_to:
+                        dest = _get_destination_path(
+                            other_checkpoint, cfg.checkpoint.cloud_upload_path
+                        )
+                        PathManager.copy(final_path, dest, overwrite=True)
except (FileNotFoundError, AssertionError) as e:
logger.info(f"could not upload {filename}: {e}")

@@ -664,6 +696,31 @@ def nfs_evaluation(
)


+def _copy_to_azure(source, destination, log_dir):
+    # /dir/checkpoint_last.pt -> /dir/checkpoint_last.pt_azcopy_logs_2000-01-01T00_00_00
+    basename = _get_basename(destination)
+    timestamp = datetime.utcnow().isoformat().replace(":", "_")[:-7]
+    azcopy_logs = os.path.join(log_dir, f"{basename}_azcopy_logs_{timestamp}")
+    os.environ["AZCOPY_CONCURRENCY_VALUE"] = "10"
+    os.environ["AZCOPY_LOG_LOCATION"] = azcopy_logs
+    os.makedirs(azcopy_logs, exist_ok=True)
+    logger.info(f"preparing to azcopy {source} to {destination}; logs in {azcopy_logs}")
+    cmd = [
+        "azcopy",  # TODO(susanz): require azcopy to be installed.
+        "copy",
+        "--cap-mbps",
+        "96.0",
+        source,
+        destination,
+    ]
+    res = _run_azcopy(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    if res.returncode != 0:
+        print("Error: {}, azcopy failed".format(res.returncode))
+        print("Azcopy stdout = {}".format(res.stdout))
+        sys.exit(1)
+    logger.info(f"Successfully copied {source} to {destination}")


def _run_azcopy(cmd, stdout, stderr):
return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
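
The new `_get_destination_path` helper is the common thread through all three upload branches. A condensed restatement with sanity checks (account, container, and directory names are illustrative); going through `urlparse` rather than naive string concatenation keeps the scheme, host, and any query component of the destination URL intact:

```python
import os
from urllib.parse import urlparse

def get_destination_path(path, destination):
    # Join the checkpoint's basename onto either a remote URL or a local directory.
    basename = os.path.basename(urlparse(path).path)
    res = urlparse(destination)
    if res.scheme:
        # Rewrite only the path component; scheme, netloc, and query survive.
        return res._replace(path=os.path.join(res.path, basename)).geturl()
    return os.path.join(destination, basename)

assert (
    get_destination_path(
        "/local/checkpoint_last.pt", "https://myaccount.blob.core.windows.net/test"
    )
    == "https://myaccount.blob.core.windows.net/test/checkpoint_last.pt"
)
assert (
    get_destination_path("/local/checkpoint_last.pt", "/mnt/nfs/checkpoints")
    == "/mnt/nfs/checkpoints/checkpoint_last.pt"
)
```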

Expand Down
9 changes: 7 additions & 2 deletions metaseq/trainer.py
@@ -421,7 +421,12 @@ def state_dict(self, filename, training_finished=False) -> Dict[str, Dict]:
return state_dicts

def save_checkpoint(
-        self, filename, extra_state, training_finished=False, async_callback_fn=None
+        self,
+        filename,
+        extra_state,
+        training_finished=False,
+        async_callback_fn=None,
+        files_to_symlink_to=None,
):
"""Save all training state in a checkpoint file."""

@@ -445,7 +450,7 @@ def save_checkpoint(
def perform_save():
try:
logger.info(f"Beginning asynchronous torch.save to {filename}")
-                async_callback_fn(filename)
+                async_callback_fn(filename, files_to_symlink_to)
logger.info(f"Asynchronous torch.save to {filename} complete.")
except Exception as e:
logger.exception(f"Asynchronous save failed: {e}")
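
For context on the signature change: the trainer now calls `async_callback_fn(filename, files_to_symlink_to)`, so the callback bound by `train.py` must accept both save-time arguments. A hypothetical sketch of that wiring (the actual call site is outside this diff):

```python
import functools

# Hypothetical binding; everything except the save-time values is pre-applied.
async_callback_fn = functools.partial(
    post_checkpoint_callback, cfg, num_updates, training_finished
)

# Trainer.save_checkpoint then invokes
#     async_callback_fn(filename, files_to_symlink_to)
# which expands to
#     post_checkpoint_callback(
#         cfg, num_updates, training_finished, filename, files_to_symlink_to
#     )
```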