diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py
index 72ab5b8f1..5fc2a9f65 100644
--- a/torchtitan/components/metrics.py
+++ b/torchtitan/components/metrics.py
@@ -12,6 +12,7 @@
 
 import torch
 from torch.utils.tensorboard import SummaryWriter
+
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.config_manager import JobConfig
@@ -347,6 +348,7 @@ def log(
         step: int,
         global_avg_loss: float,
         global_max_loss: float,
+        grad_norm: float,
         extra_metrics: dict[str, Any] | None = None,
     ):
         assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
@@ -372,6 +374,7 @@ def log(
         metrics = {
             "loss_metrics/global_avg_loss": global_avg_loss,
             "loss_metrics/global_max_loss": global_max_loss,
+            "optimizer/grad_norm": grad_norm,
             "throughput(tps)": tps,
             "tflops": tflops,
             "mfu(%)": mfu,
@@ -394,7 +397,7 @@ def log(
         color = self.color
         logger.info(
             f"{color.red}step: {step:2} "
-            f"{color.green}loss: {global_avg_loss:7.4f} "
+            f"{color.green}loss: {global_avg_loss:7.4f} gnorm: {grad_norm:.2f} "
             f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
             f"({device_mem_stats.max_reserved_pct:.2f}%) "
             f"{color.blue}tps: {round(tps):,} "
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 375a41260..6f8cdc79e 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -15,20 +15,20 @@
 
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
-
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.metrics import (
     build_metrics_processor,
-    ensure_pp_loss_visible,
+    ensure_pp_loss_visible
 )
 from torchtitan.config_manager import ConfigManager, JobConfig
-from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.distributed import ParallelDims
+from torchtitan.distributed import utils as dist_utils
 from torchtitan.protocols.model_converter import build_model_converters
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.tools.profiling import (
     maybe_enable_memory_snapshot,
-    maybe_enable_profiling,
+    maybe_enable_profiling
 )
 
 
@@ -357,7 +357,7 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
                 del pred
                 loss.backward()
 
-        dist_utils.clip_grad_norm_(
+        grad_norm = dist_utils.clip_grad_norm_(
             [p for m in model_parts for p in m.parameters()],
             self.job_config.training.max_norm,
             foreach=True,
@@ -386,7 +386,12 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
 
-        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
+        self.metrics_processor.log(
+            self.step,
+            global_avg_loss,
+            global_max_loss,
+            grad_norm.item(),
+        )
 
     @record
     def train(self):
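
Note: the `grad_norm.item()` call in train_step assumes that `dist_utils.clip_grad_norm_`
follows the contract of `torch.nn.utils.clip_grad_norm_`: clip the gradients in place and
return the total gradient norm as a 0-dim tensor. A minimal sketch of that contract in
plain PyTorch (toy model for illustration only, not torchtitan code):

    import torch

    model = torch.nn.Linear(8, 8)
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()

    # clip_grad_norm_ clips gradients in place and returns the total
    # (pre-clipping) gradient norm across all parameters as a 0-dim tensor
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(f"gnorm: {total_norm.item():.2f}")  # plain float, ready for logging

Like the adjacent `loss.detach().item()`, the `.item()` call forces a host-device sync,
so passing the norm to the metrics logger adds no new synchronization point.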