Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions metaseq/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,56 @@ def save_checkpoint(
async_callback_fn=async_callback_fn if save_to_NFS else None,
)

for cp in checkpoints[1:]:
assert PathManager.copy(checkpoints[0], cp, overwrite=True), f"Failed to copy {checkpoints[0]} to {cp}"

write_timer.stop()
logger.info(
f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) "
f"(writing took {write_timer.sum} seconds)"
)

delete_old_checkpoint_files(cfg, end_of_epoch, suffix, trainer.is_data_parallel_master)


def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str, is_data_parallel_master: bool):
if not end_of_epoch and cfg.keep_last_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
)
for old_chk in checkpoints[cfg.keep_last_updates:]:
if os.path.lexists(old_chk):
os.remove(old_chk)

if cfg.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
)
for old_chk in checkpoints[cfg.keep_last_epochs:]:
if os.path.lexists(old_chk):
os.remove(old_chk)


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
"""Retrieves all checkpoints found in `path` directory.

Checkpoints are identified by matching filename to the specified pattern. If
the pattern contains groups, the result will be sorted by the first group in
descending order.
"""
pt_regexp = re.compile(pattern)
files = os.listdir(path)

entries = []
for i, f in enumerate(files):
m = pt_regexp.fullmatch(f)
if m is not None:
idx = float(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]


def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
"""
Expand Down Expand Up @@ -547,16 +591,16 @@ def load_model_ensemble_and_task(
f"!!! cfg does not exist in state keys = {state.keys()} !!!"
)

# Load 175B model trained on megatron (model parallel) branch
# "cfg.common.model_parallel_size == 1" checks if model parallel is
# enabled at load time. If it's not, fall back to non-MP
# transformer code path.
if (
getattr(cfg.model, "arch", None) == "transformer_lm_megatron"
and cfg.common.model_parallel_size == 1
):
cfg.model.arch = "transformer_lm_gpt"
cfg.model._name = "transformer_lm_gpt"
# # Load 175B model trained on megatron (model parallel) branch
# # "cfg.common.model_parallel_size == 1" checks if model parallel is
# # enabled at load time. If it's not, fall back to non-MP
# # transformer code path.
# if (
# getattr(cfg.model, "arch", None) == "transformer_lm_megatron"
# and cfg.common.model_parallel_size == 1
# ):
# cfg.model.arch = "transformer_lm_gpt"
# cfg.model._name = "transformer_lm_gpt"

# We now copy embed_tokens over to output_proj (if its missing) for all arches (only OPT here so far).
oproj_key = "decoder.output_projection.weight"
Expand Down
25 changes: 17 additions & 8 deletions metaseq/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,33 @@ def main(cfg: DictConfig) -> None:
# Print args
logger.info(cfg)

# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(cfg.task)

assert cfg.criterion, "Please specify criterion to train a model"

# Build model and criterion
if cfg.distributed_training.ddp_backend == "fully_sharded":
extra = {
"use_sharded_state": cfg.distributed_training.use_sharded_state,
}
# Build task, model and criterion
extra = {"use_sharded_state": cfg.distributed_training.use_sharded_state,}
if cfg.distributed_training.task_ddp_backend == "fully_sharded":
# As the task is non-trainable, we witch flags to more optimized ones.
memory_efficient_fp16 = cfg.distributed_training.memory_efficient_fp16
fp32_reduce_scatter = cfg.distributed_training.fp32_reduce_scatter
cfg.distributed_training.memory_efficient_fp16 = cfg.distributed_training.fp16
cfg.distributed_training.fp32_reduce_scatter = not cfg.distributed_training.fp16
with fsdp_enable_wrap(cfg.distributed_training, **extra):
# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(cfg.task)
cfg.distributed_training.memory_efficient_fp16 = memory_efficient_fp16
cfg.distributed_training.fp32_reduce_scatter = fp32_reduce_scatter
else:
task = tasks.setup_task(cfg.task)

if cfg.distributed_training.ddp_backend == "fully_sharded":
with fsdp_enable_wrap(cfg.distributed_training, **extra):
model = fsdp_wrap(
task.build_model(cfg.model),
process_group=distributed_utils.get_data_parallel_group(),
)
else:
model = task.build_model(cfg.model)

# TODO[Susan]: FSDP on criterion?
criterion = task.build_criterion(cfg.criterion)

Expand Down
3 changes: 3 additions & 0 deletions metaseq/dataclass/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ class DistributedTrainingConfig(MetaseqDataclass):
ddp_backend: DDP_BACKEND_CHOICES = field(
default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
)
task_ddp_backend: DDP_BACKEND_CHOICES = field(
default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend for task"}
)
bucket_cap_mb: int = field(
default=25, metadata={"help": "bucket size for reduction"}
)
Expand Down
3 changes: 1 addition & 2 deletions metaseq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def _build_ema(self):
if self.is_fsdp:
# Build FSDP model
extra = {
"is_moe": getattr(self.cfg.model, "moe_freq", 0) > 0,
"use_sharded_state": self.use_sharded_state,
}
with fsdp_enable_wrap(self.cfg.distributed_training, **extra):
Expand Down Expand Up @@ -1212,7 +1211,7 @@ def _prepare_sample(self, sample, is_dummy=False):
def lower_precision(t):
"""Converts a tensor to the desired dtype based on our cfg."""
if t.dtype is torch.float32:
if self.cfg.common.bf16 or self.cfg.bf16:
if self.cfg.common.bf16:
return t.bfloat16()
return t.half()
return t
Expand Down
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,13 @@ def do_setup(package_data):
"torch",
"tqdm",
"typing_extensions",
"einops",
"webdataset==0.1.103",
"matplotlib",
"pytorchvideo==0.1.5",
"wandb",
"albumentations",
"dalle_pytorch",
],
dependency_links=dependency_links,
packages=find_packages(
Expand Down