Merged
metaseq/cli/train.py (16 additions, 15 deletions)
@@ -315,23 +315,24 @@ def train(
     end_of_epoch = not itr.has_next()
     if end_of_epoch:
         grank = distributed_utils.get_global_rank()
+
+        log_seq = [f"End of Epoch on rank {grank}:"]
+        if hasattr(itr, "sequences_consumed"):
+            log_seq += [f"sequences_consumed={itr.sequences_consumed}"]
+        log_seq += [f"n={itr.n}"]
+
         dataset = epoch_itr.dataset
-        while not hasattr(dataset, "len_cache"):
+        while not hasattr(dataset, "len_cache") and hasattr(dataset, "dataset"):
             dataset = dataset.dataset
-        len_cache = tuple(dataset.len_cache.data)
-        cache_hash = hash(len_cache)
-        contains_zero = any([x == 0 for x in len_cache])
-        logger.warning(
-            " ".join(
-                [
-                    f"End of Epoch on rank {grank}:",
-                    f"sequences_consumed={itr.sequences_consumed}",
-                    f"n={itr.n}",
-                    f"len_cache_hash={cache_hash}",
-                    f"len_cache_has_zeros={contains_zero}",
-                ]
-            )
-        )
+        if hasattr(dataset, "len_cache"):
+            len_cache = tuple(dataset.len_cache.data)
+            cache_hash = hash(len_cache)
+            contains_zero = any([x == 0 for x in len_cache])
+            log_seq += [
+                f"len_cache_hash={cache_hash}",
+                f"len_cache_has_zeros={contains_zero}",
+            ]
+        logger.warning(" ".join(log_seq))
 
     valid_losses, should_stop = validate_and_save(
         cfg,
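Note on this hunk: the new end-of-epoch logging only reports attributes that actually exist, walking nested dataset wrappers until one exposes len_cache or the chain runs out. Below is a minimal standalone sketch of that pattern; Wrapped, Lengths, and log_end_of_epoch are hypothetical stand-ins, not classes or functions from the repo.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("end_of_epoch_sketch")


class Lengths:
    """Innermost dataset: the only layer that carries a length cache."""
    def __init__(self):
        self.len_cache = [3, 0, 7]  # toy cache; a zero would signal a suspicious entry


class Wrapped:
    """Outer wrapper that only forwards to .dataset, like the epoch iterator's dataset."""
    def __init__(self, dataset):
        self.dataset = dataset


def log_end_of_epoch(n, dataset, rank=0):
    log_seq = [f"End of Epoch on rank {rank}:", f"n={n}"]
    # Walk wrappers until a len_cache shows up or there is nothing left to unwrap.
    while not hasattr(dataset, "len_cache") and hasattr(dataset, "dataset"):
        dataset = dataset.dataset
    if hasattr(dataset, "len_cache"):
        len_cache = tuple(dataset.len_cache)
        log_seq += [
            f"len_cache_hash={hash(len_cache)}",
            f"len_cache_has_zeros={any(x == 0 for x in len_cache)}",
        ]
    logger.warning(" ".join(log_seq))


log_end_of_epoch(n=128, dataset=Wrapped(Lengths()))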
metaseq/distributed/stitch_fsdp_ckpt.py (12 additions, 2 deletions)
@@ -36,6 +36,7 @@ def consolidate_fsdp_shards(
     new_arch_name=None,
     no_stitch_megatron=False,
     megatron_part=None,
+    is_ema=False,
 ) -> str:
     if pth_prefix.endswith(".pt"):
         pth_prefix = pth_prefix[:-3]
@@ -68,7 +69,16 @@ def consolidate_fsdp_shards(
             expert_dest_paths.append(f"{save_prefix}-rank-{r}.pt")
         else:
             ckpt = load_and_pop_last_optimizer_state(p)
-            weights.append(ckpt["model"])
+            if "ema_fp32_params" in ckpt["extra_state"]:
+                ema_key = "ema_fp32_params"
+            elif "ema" in ckpt["extra_state"]:
+                ema_key = "ema"
+            else:
+                ema_key = None
+            if is_ema and ema_key is not None:
+                weights.append(ckpt["extra_state"][ema_key])
+            else:
+                weights.append(ckpt["model"])
             metadata.append(ckpt["shard_metadata"])
     assert weights, f"all files were considered experts: {all_ckpt_files}"
     do_consolidate = True
@@ -185,7 +195,7 @@ def consolidate_model_parallel(
         all_parts_consolidated[k] = part_weights
     if no_stitch_megatron:
         return all_parts_consolidated
-    # glue to be a single megatron mdoel part
+    # glue to be a single megatron model part
    model = reshard_megatron_parts(all_parts_consolidated, new_model_part_count=1)[0]
    return model
 
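Note on the is_ema path: it prefers the full-precision EMA weights when present, falls back to the "ema" entry, and otherwise uses the regular model weights. A small self-contained sketch of that precedence on a toy checkpoint dict; select_weights and toy_ckpt are illustrative only, not metaseq code.

import torch


def select_weights(ckpt, is_ema):
    # Same fallback order as the diff above: ema_fp32_params > ema > model.
    extra = ckpt.get("extra_state", {})
    if is_ema:
        for key in ("ema_fp32_params", "ema"):
            if key in extra:
                return extra[key]
    return ckpt["model"]


toy_ckpt = {
    "model": {"w": torch.zeros(2)},
    "extra_state": {"ema": {"w": torch.ones(2)}},  # no fp32 EMA in this toy shard
}
print(select_weights(toy_ckpt, is_ema=True))   # picks extra_state["ema"]
print(select_weights(toy_ckpt, is_ema=False))  # picks the regular model weights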
metaseq/models/ema/ema.py (1 addition, 1 deletion)
@@ -76,7 +76,7 @@ def __init__(self, model, config, device=None):
         self.decay = config.ema_decay
         if isinstance(model, FullyShardedDataParallel):
             self.model = model
-            logger.warning("EMA got FSDP model, assuming assigned model is a " "copy")
+            logger.info("EMA got FSDP model, assuming assigned model is a " "copy")
         else:
             self.model = copy.deepcopy(model)
         self.model.requires_grad_(False)
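For context on why a frozen copy matters here: an EMA wrapper keeps shadow parameters that are blended toward the live model after each update rather than trained directly. A generic sketch of that update follows; ema_step is an illustration of the usual rule, not metaseq's actual step implementation.

import copy

import torch
import torch.nn as nn


def ema_step(ema_model, model, decay):
    # p_ema <- decay * p_ema + (1 - decay) * p, applied in place without gradients.
    with torch.no_grad():
        for p_ema, p in zip(ema_model.parameters(), model.parameters()):
            p_ema.mul_(decay).add_(p, alpha=1.0 - decay)


model = nn.Linear(4, 4)
ema_model = copy.deepcopy(model).requires_grad_(False)  # frozen shadow copy, as in the diff
ema_step(ema_model, model, decay=0.999)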
metaseq/optim/base_optimizer.py (3 additions, 1 deletion)
@@ -108,7 +108,9 @@ def multiply_grads(self, c):
                 c = c.to(p.grad.device)
             p.grad.data.mul_(c)
 
-    def clip_grad_norm(self, max_norm, norm_type="l2", aggregate_norm_fn=None):
+    def clip_grad_norm(
+        self, max_norm, norm_type="l2", aggregate_norm_fn=None, **kwargs
+    ):
         """Clips gradient norm."""
         return utils.clip_grad_norm_(
             self.params, max_norm, norm_type, aggregate_norm_fn
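A note on the widened signature: accepting **kwargs lets one call site pass clipping options to every optimizer class, and implementations that do not recognize an option simply ignore it. Below is a toy sketch of that pattern; BaseOpt, FancyOpt, and the extra_option keyword are made up for illustration and are not metaseq API.

class BaseOpt:
    def clip_grad_norm(self, max_norm, norm_type="l2", aggregate_norm_fn=None, **kwargs):
        # The base implementation tolerates and ignores options it does not use.
        return f"base clip to {max_norm} ({norm_type})"


class FancyOpt(BaseOpt):
    def clip_grad_norm(self, max_norm, norm_type="l2", aggregate_norm_fn=None, **kwargs):
        # A subclass may consume an option the base signature never named.
        extra = kwargs.get("extra_option", False)
        return f"fancy clip to {max_norm}, extra_option={extra}"


# One call site works for both optimizers, with no special-casing per class.
for opt in (BaseOpt(), FancyOpt()):
    print(opt.clip_grad_norm(1.0, extra_option=True))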
metaseq/scripts/reshard_fsdp.py (1 addition, 1 deletion)
@@ -220,7 +220,7 @@ def reshard_fsdp_optim_state(
                 [_maybe_type(s["state"][idx][key], dtype) for s in shard_optim_states]
             )
             unpadded_value = _unpad_tensor(
-                tensor=unsharded_value,
+                shard=unsharded_value,
                 pad=shard_optim_padding.get(key, 0) if shard_optim_padding else 0,
             )
             chunks, _ = _shard_and_pad_tensor(unpadded_value, num_output_shards)
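The keyword rename matches the callee's parameter name; the helper presumably strips the zero padding that was appended to give every shard equal length. A hedged sketch of what such a pad/unpad pair might look like; these _sketch functions are illustrations only, not the actual _unpad_tensor and _shard_and_pad_tensor from the script.

import torch


def _unpad_tensor_sketch(shard: torch.Tensor, pad: int) -> torch.Tensor:
    # Drop the trailing padding that was added to equalize shard sizes.
    return shard[: shard.numel() - pad] if pad > 0 else shard


def _shard_and_pad_tensor_sketch(tensor: torch.Tensor, num_shards: int):
    # Split a flat tensor into num_shards equal chunks, zero-padding the tail.
    chunk_size = -(-tensor.numel() // num_shards)  # ceiling division
    pad = chunk_size * num_shards - tensor.numel()
    padded = torch.cat([tensor, tensor.new_zeros(pad)])
    return list(padded.split(chunk_size)), pad


flat = torch.arange(10.0)
chunks, pad = _shard_and_pad_tensor_sketch(flat, num_shards=4)
restored = _unpad_tensor_sketch(torch.cat(chunks), pad=pad)
assert torch.equal(restored, flat)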
setup.py (1 addition, 1 deletion)
@@ -196,7 +196,7 @@ def do_setup(package_data):
             "albumentations",
             "dalle_pytorch",
             "einops",
-            "matplotlib",
+            "matplotlib==3.5.0",
             "pytorchvideo==0.1.5",
             "wandb",
             "webdataset==0.1.103",