Add grad_norm metrics #1143

Open · wants to merge 2 commits into main

5 changes: 4 additions & 1 deletion torchtitan/components/metrics.py
@@ -12,6 +12,7 @@
 import torch
 from torch.utils.tensorboard import SummaryWriter

 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.config_manager import JobConfig
@@ -347,6 +348,7 @@ def log(
         step: int,
         global_avg_loss: float,
         global_max_loss: float,
+        grad_norm: float,
         extra_metrics: dict[str, Any] | None = None,
     ):
         assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
@@ -372,6 +374,7 @@
         metrics = {
             "loss_metrics/global_avg_loss": global_avg_loss,
             "loss_metrics/global_max_loss": global_max_loss,
+            "optimizer/grad_norm": grad_norm,
             "throughput(tps)": tps,
             "tflops": tflops,
             "mfu(%)": mfu,
@@ -394,7 +397,7 @@
         color = self.color
         logger.info(
             f"{color.red}step: {step:2} "
-            f"{color.green}loss: {global_avg_loss:7.4f} "
+            f"{color.green}loss: {global_avg_loss:7.4f} gnorm: {grad_norm:.2f} "

Review comment (Contributor):

I agree this metric would be helpful. Do you think it has to be in the terminal print, instead of staying in TB/W&B only for now?
If so, could you invent a new color in the Color class -- otherwise it seems visually harder to recognize the loss. Also, nit: keep two spaces before/after each column.
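
For illustration, a minimal sketch of what a dedicated color for gnorm could look like, assuming the Color class is a plain container of ANSI escape codes (the actual class layout in torchtitan's tools may differ, and the choice of magenta is hypothetical):

```python
# Hypothetical sketch -- the real Color class in torchtitan may differ.
from dataclasses import dataclass


@dataclass(frozen=True)
class Color:
    red: str = "\033[31m"
    green: str = "\033[32m"
    yellow: str = "\033[33m"
    blue: str = "\033[34m"
    magenta: str = "\033[35m"  # e.g. a new color reserved for grad_norm
    reset: str = "\033[0m"


color = Color()
# Giving gnorm its own color (and two-space column separation) keeps the
# loss easy to pick out when scanning the terminal:
print(f"{color.green}loss: {2.3154:7.4f}  {color.magenta}gnorm: {1.87:.2f}{color.reset}")
```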

f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
f"({device_mem_stats.max_reserved_pct:.2f}%) "
f"{color.blue}tps: {round(tps):,} "
Expand Down
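
Since metrics.py already imports SummaryWriter, the new `optimizer/grad_norm` entry presumably reaches TensorBoard through the same scalar-logging path as the existing loss metrics. A self-contained sketch of that path (the log directory and metric values here are illustrative, not from the PR):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./outputs/tb")  # illustrative location

step = 10
metrics = {
    "loss_metrics/global_avg_loss": 2.3154,
    "loss_metrics/global_max_loss": 2.4021,
    "optimizer/grad_norm": 1.87,
}
for tag, value in metrics.items():
    # add_scalar(tag, scalar_value, global_step) writes one point per tag;
    # grouping tags with a "/" prefix nests them in the TensorBoard UI.
    writer.add_scalar(tag, value, step)
writer.close()
```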
17 changes: 11 additions & 6 deletions torchtitan/train.py
@@ -15,20 +15,20 @@
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module

 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.metrics import (
     build_metrics_processor,
-    ensure_pp_loss_visible,
+    ensure_pp_loss_visible
 )
 from torchtitan.config_manager import ConfigManager, JobConfig
-from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.distributed import ParallelDims
+from torchtitan.distributed import utils as dist_utils
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.tools.profiling import (
     maybe_enable_memory_snapshot,
-    maybe_enable_profiling,
+    maybe_enable_profiling
 )


@@ -357,7 +357,7 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         del pred
         loss.backward()

-        dist_utils.clip_grad_norm_(
+        grad_norm = dist_utils.clip_grad_norm_(
             [p for m in model_parts for p in m.parameters()],
             self.job_config.training.max_norm,
             foreach=True,
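
The substantive change here is capturing the return value: like torch.nn.utils.clip_grad_norm_, the distributed wrapper is expected to return the total gradient norm computed before clipping. A minimal single-device sketch of that contract, using the stock PyTorch utility on a toy model (dist_utils.clip_grad_norm_ adds cross-rank handling on top of the same semantics):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 4)
loss = model(torch.randn(2, 8)).sum()
loss.backward()

# clip_grad_norm_ scales gradients in place so their total norm is at most
# max_norm, and returns the total norm measured *before* clipping.
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
print(f"gnorm: {grad_norm.item():.2f}")  # zero-dim tensor -> Python float
```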
@@ -386,7 +386,12 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         else:
             global_avg_loss = global_max_loss = loss.detach().item()

-        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
+        self.metrics_processor.log(
+            self.step,
+            global_avg_loss,
+            global_max_loss,
+            grad_norm.item(),
+        )

     @record
     def train(self):
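
A side note on grad_norm.item() in the new call site: the clipping utility returns a zero-dim tensor, and .item() copies that scalar to the host as a Python float, which is what log() expects. A tiny sketch of just that conversion (the tensor value is a stand-in, not from the PR):

```python
import torch

grad_norm = torch.tensor(1.8743)  # stand-in for the tensor clip_grad_norm_ returns

# .item() blocks until the device has produced the value and returns a
# Python float -- convenient for logging, but it is a synchronization
# point, so it is best done only where metrics are actually recorded.
gnorm = grad_norm.item()
assert isinstance(gnorm, float)
print(f"gnorm: {gnorm:.2f}")
```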