Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,10 +1727,10 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):

norm_is_inf = total_norm.isinf()
norm_is_nan = total_norm.isnan()
inf_or_nan = norm_is_nan.logical_or(norm_is_inf)

err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
if norm_is_inf or norm_is_nan:
total_norm = torch.tensor(-1.0, device=self.device, dtype=torch.float)

return total_norm

# creates a flat fused tensor from the tensor list starting at the first_offset
Expand Down Expand Up @@ -1987,8 +1987,10 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
clip = torch.clamp(clip, min=1.0)
combined_scale = clip * self.loss_scale

# handle total_norm invalid value -1
if clip > 1:
combined_scale = clip * self.loss_scale

for grad in grad_groups_flat:
if isinstance(grad, list):
Expand Down