Skip to content

Commit 7a9b890

Browse files
committed
handle total_norm invalid value
1 parent 1768d55 commit 7a9b890

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

deepspeed/runtime/zero/stage_1_and_2.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1987,8 +1987,10 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
19871987
if self.clip_grad > 0.:
19881988
# norm is in fact norm*scale
19891989
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
1990-
clip = torch.clamp(clip, min=1.0)
1991-
combined_scale = clip * self.loss_scale
1990+
1991+
# handle total_norm invalid value -1
1992+
if clip > 1:
1993+
combined_scale = clip * self.loss_scale
19921994

19931995
for grad in grad_groups_flat:
19941996
if isinstance(grad, list):

0 commit comments

Comments
 (0)