Commit b3370bb

finished core training loop
1 parent 14af406 commit b3370bb

1 file changed: recipes/qat_single_device.py (+59 −1 lines)
@@ -272,7 +272,65 @@ def train(self, **kwargs) -> None:
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
 
-        return
+                    # Update the number of steps when the weights are updated
+                    self.global_step += 1
+
+                    loss_to_log = running_loss.detach().item() / num_tokens
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
+                    if self.global_step % self._log_every_n_steps == 0:
+                        time_per_step = time.perf_counter() - t0
+                        log_dict = {
+                            "loss": loss_to_log,
+                            "lr": get_lr(
+                                (
+                                    self._optimizer
+                                    if not self._optimizer_in_bwd
+                                    else self._optim_ckpt_wrapper
+                                ),
+                            ),
+                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                        }
+                        if self._device.type != "cpu" and self._log_peak_memory_stats:
+                            log_dict.update(
+                                training.get_memory_stats(device=self._device)
+                            )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
+                        self._metric_logger.log_dict(
+                            log_dict,
+                            step=self.global_step,
+                        )
+
+                    # Reset running stats for the next step
+                    running_loss = 0
+                    num_tokens = 0
+                    t0 = time.perf_counter()
+
+                # Stop tracking CUDA memory now that active steps are complete
+                if (
+                    curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx
+                    == self.profiler_wait_steps
+                    + self.profiler_warmup_steps
+                    + self.profiler_active_steps
+                    and self._device.type == "cuda"
+                ):
+                    torch.cuda.memory._record_memory_history(enabled=None)
+
+                # Step profiler
+                # Note that this is called within gradient accumulation block, hence
+                # will include multiple forward / backward passes if gradient accumulation > 1
+                self._profiler.step()
+
+            self.epochs_run += 1
+            self.save_checkpoint(epoch=curr_epoch)
+
+        self._profiler.stop()
 
     def save_checkpoint(self, **kwargs) -> None:
         """
