@@ -272,7 +272,65 @@ def train(self, **kwargs) -> None:
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
 
-                    return
+                    # Update the number of steps when the weights are updated
+                    self.global_step += 1
+
+                    loss_to_log = running_loss.detach().item() / num_tokens
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
+                    if self.global_step % self._log_every_n_steps == 0:
+                        time_per_step = time.perf_counter() - t0
+                        log_dict = {
+                            "loss": loss_to_log,
+                            "lr": get_lr(
+                                (
+                                    self._optimizer
+                                    if not self._optimizer_in_bwd
+                                    else self._optim_ckpt_wrapper
+                                ),
+                            ),
+                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                        }
+                        if self._device.type != "cpu" and self._log_peak_memory_stats:
+                            log_dict.update(
+                                training.get_memory_stats(device=self._device)
+                            )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
+                        self._metric_logger.log_dict(
+                            log_dict,
+                            step=self.global_step,
+                        )
+
+                    # Reset running stats for the next step
+                    running_loss = 0
+                    num_tokens = 0
+                    t0 = time.perf_counter()
+
+                # Stop tracking CUDA memory now that active steps are complete
+                if (
+                    curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx
+                    == self.profiler_wait_steps
+                    + self.profiler_warmup_steps
+                    + self.profiler_active_steps
+                    and self._device.type == "cuda"
+                ):
+                    torch.cuda.memory._record_memory_history(enabled=None)
+
+                # Step profiler
+                # Note that this is called within gradient accumulation block, hence
+                # will include multiple forward / backward passes if gradient accumulation > 1
+                self._profiler.step()
+
+            self.epochs_run += 1
+            self.save_checkpoint(epoch=curr_epoch)
+
+        self._profiler.stop()
 
     def save_checkpoint(self, **kwargs) -> None:
         """