File tree 1 file changed +10
-0
lines changed
1 file changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -198,6 +198,16 @@ def _parse_losses(self, losses):
198
198
loss = sum (_value for _key , _value in log_vars .items ()
199
199
if 'loss' in _key )
200
200
201
+ # If the loss_vars has different length, GPUs will wait infinitely
202
+ if dist .is_available () and dist .is_initialized ():
203
+ log_var_length = torch .tensor (len (log_vars ), device = loss .device )
204
+ dist .all_reduce (log_var_length )
205
+ message = (f'rank { dist .get_rank ()} ' +
206
+ f' len(log_vars): { len (log_vars )} ' + ' keys: ' +
207
+ ',' .join (log_vars .keys ()))
208
+ assert log_var_length == len (log_vars ) * dist .get_world_size (), \
209
+ 'loss log variables are different across GPUs!\n ' + message
210
+
201
211
log_vars ['loss' ] = loss
202
212
for loss_name , loss_value in log_vars .items ():
203
213
# reduce loss when distributed training
You can’t perform that action at this time.
0 commit comments