
Commit c0a3bf8

Fix issue #5242 grad_norm and loss is nan (#7171)
This PR addresses a regression introduced in commit [61daaa1](61daaa1) that affects gradient clipping when handling infinite values. The modified NaN/Inf handling logic in the `total_norm` calculation leads to unexpected behavior:

- Original logic ([v0.10.3](https://github.com/deepspeedai/DeepSpeed/blob/v0.10.3/deepspeed/runtime/zero/stage_1_and_2.py#L1233)): converted both NaN and Inf to -1 before entering `unscale_and_clip_grads`.
- Post-commit behavior: when `total_norm` is Inf, `inf_or_nan.logical_not() * total_norm` produces NaN instead of 0, causing gradient clipping to fail.

Here is a minimal reproducible example comparing gradient clipping behavior across implementations:

```python
import copy

import torch


def test(total_norm):
    test_old_deepspeed(total_norm)
    test_deepspeed(total_norm)
    test_torch(total_norm)
    test_deepspeed_fix(total_norm)


def test_old_deepspeed(total_norm_tensor):
    total_norm = copy.deepcopy(total_norm_tensor)
    # https://github.com/deepspeedai/DeepSpeed/blob/v0.10.3/deepspeed/runtime/zero/stage_1_and_2.py#L1233
    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = torch.tensor(float(-1))
    # https://github.com/deepspeedai/DeepSpeed/blob/v0.10.3/deepspeed/runtime/zero/stage_1_and_2.py#L1848
    clip_grad = float(1.0)
    loss_scale = float(1.0)
    combined_scale = loss_scale
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    if clip > 1:
        combined_scale = clip * loss_scale
    print(f"old_deepspeed: {1. / combined_scale}")


def test_deepspeed(total_norm_tensor):
    total_norm = copy.deepcopy(total_norm_tensor)
    # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/runtime/zero/stage_1_and_2.py#L1710
    norm_is_inf = total_norm.isinf()
    norm_is_nan = total_norm.isnan()
    inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
    err = torch.tensor(-1.0, dtype=torch.float)
    total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
    # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/runtime/zero/stage_1_and_2.py#L1970
    clip_grad = float(1.0)
    loss_scale = float(1.0)
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    clip = torch.clamp(clip, min=1.0)
    combined_scale = clip * loss_scale
    print(f"test_deepspeed: {1. / combined_scale}")


def test_torch(total_norm_tensor):
    # https://github.com/pytorch/pytorch/blob/v2.6.0/torch/nn/utils/clip_grad.py#L155
    total_norm = copy.deepcopy(total_norm_tensor)
    max_norm = float(1.0)
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    print(f"torch: {clip_coef_clamped}")


def test_deepspeed_fix(total_norm_tensor):
    total_norm = copy.deepcopy(total_norm_tensor)
    if total_norm.isinf() or total_norm.isnan():
        total_norm = torch.tensor(-1.0, dtype=torch.float)
    # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/runtime/zero/stage_1_and_2.py#L1970
    clip_grad = float(1.0)
    loss_scale = float(1.0)
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    clip = torch.clamp(clip, min=1.0)
    combined_scale = clip * loss_scale
    print(f"test_deepspeed_fix: {1. / combined_scale}")


if __name__ == '__main__':
    print("*****NAN*****")
    test(torch.tensor(float('nan')))
    print("*****INF*****")
    test(torch.tensor(float('inf')))
    print("*****positive*****")
    test(torch.tensor(float(2.0)))
```

Result:

![20250325165135](https://github.com/user-attachments/assets/bd32209d-14f6-4c21-8b57-f8bd94786fe2)

---------

Signed-off-by: yueyang.hyy <[email protected]>
Co-authored-by: Hongwei Chen <[email protected]>
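For reference, the failure mode described above comes down to IEEE-754 arithmetic: the `False` mask multiplies to 0, and `0 * inf` is NaN. A minimal sketch (separate from the PR's own reproducer) showing why the masked update cannot restore the -1 sentinel when the norm is infinite:

```python
import torch

# The pre-fix masking trick: replace an Inf/NaN norm with -1 via arithmetic.
total_norm = torch.tensor(float('inf'))
inf_or_nan = total_norm.isnan().logical_or(total_norm.isinf())
err = torch.tensor(-1.0, dtype=torch.float)

# inf_or_nan is True, so logical_not() is False (0), and 0 * inf is NaN,
# which then propagates through the sum: -1.0 + nan == nan.
masked = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
print(masked)  # tensor(nan) -- the intended result was tensor(-1.)
```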
1 parent f355b9e commit c0a3bf8

1 file changed: +7 additions, −5 deletions

deepspeed/runtime/zero/stage_1_and_2.py

@@ -1727,10 +1727,10 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):

         norm_is_inf = total_norm.isinf()
         norm_is_nan = total_norm.isnan()
-        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)

-        err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
-        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
+        if norm_is_inf or norm_is_nan:
+            total_norm = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+
         return total_norm

     # creates a flat fused tensor from the tensor list starting at the first_offset
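Pulled out of the class for illustration, the sanitization above behaves like the following sketch (the `device` argument is dropped here; the names otherwise follow the diff):

```python
import torch

def sanitize_total_norm(total_norm: torch.Tensor) -> torch.Tensor:
    # Both Inf and NaN norms collapse to the -1 sentinel, as in the original
    # v0.10.3 logic, instead of going through the masked arithmetic.
    if total_norm.isinf() or total_norm.isnan():
        total_norm = torch.tensor(-1.0, dtype=torch.float)
    return total_norm

print(sanitize_total_norm(torch.tensor(float('inf'))))  # tensor(-1.)
print(sanitize_total_norm(torch.tensor(float('nan'))))  # tensor(-1.)
print(sanitize_total_norm(torch.tensor(2.0)))           # tensor(2.) unchanged
```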
@@ -1987,8 +1987,10 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-            clip = torch.clamp(clip, min=1.0)
-            combined_scale = clip * self.loss_scale
+
+            # handle total_norm invalid value -1
+            if clip > 1:
+                combined_scale = clip * self.loss_scale

         for grad in grad_groups_flat:
             if isinstance(grad, list):
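And a standalone sketch of the patched clipping path (the function wrapper, the default `loss_scale` and `clip_grad` values, and the `combined_scale = loss_scale` initialization are assumptions mirroring the surrounding DeepSpeed code, not part of this diff):

```python
import torch

def unscale_and_clip(total_norm, loss_scale=1.0, clip_grad=1.0):
    # combined_scale starts at loss_scale and only grows when the norm
    # genuinely exceeds clip_grad; the -1 sentinel gives clip < 1 and
    # therefore skips clipping instead of producing a NaN scale.
    combined_scale = loss_scale
    if clip_grad > 0.:
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
        if clip > 1:
            combined_scale = clip * loss_scale
    return 1. / combined_scale

print(unscale_and_clip(torch.tensor(2.0)))   # tensor(0.5000): clipped
print(unscale_and_clip(torch.tensor(-1.0)))  # 1.0: sentinel norm, no clipping
```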
