Commit 09388f9

finished core training loop
1 parent 14af406 commit 09388f9

File tree

1 file changed, +84 -4 lines changed


recipes/qat_single_device.py

+84 -4
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
+import sys
 import time
 
 from typing import Any, Dict  # List, Optional, Union
@@ -12,10 +12,12 @@
 import torch
 from omegaconf import DictConfig  # ListConfig
 
-from torchtune import training, utils  # config, modules,
+from torchtune import config, training, utils  # modules
 from torchtune.modules.loss import SFTLoss
 from torchtune.recipe_interfaces import FTRecipeInterface
 
+from torchtune.training.lr_schedulers import get_lr
+
 from tqdm import tqdm
 
 
@@ -268,11 +270,69 @@ def train(self, **kwargs) -> None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
                             max_norm=float(self._clip_grad_norm),
-                        ).full_tensor()
+                        )
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
 
-                    return
+                    # Update the number of steps when the weights are updated
+                    self.global_step += 1
+
+                    loss_to_log = running_loss.detach().item() / num_tokens
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
+                    if self.global_step % self._log_every_n_steps == 0:
+                        time_per_step = time.perf_counter() - t0
+                        log_dict = {
+                            "loss": loss_to_log,
+                            "lr": get_lr(
+                                (
+                                    self._optimizer
+                                    if not self._optimizer_in_bwd
+                                    else self._optim_ckpt_wrapper
+                                ),
+                            ),
+                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                        }
+                        if self._device.type != "cpu" and self._log_peak_memory_stats:
+                            log_dict.update(
+                                training.get_memory_stats(device=self._device)
+                            )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
+                        self._metric_logger.log_dict(
+                            log_dict,
+                            step=self.global_step,
+                        )
+
+                    # Reset running stats for the next step
+                    running_loss = 0
+                    num_tokens = 0
+                    t0 = time.perf_counter()
+
+                # Stop tracking CUDA memory now that active steps are complete
+                if (
+                    curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx
+                    == self.profiler_wait_steps
+                    + self.profiler_warmup_steps
+                    + self.profiler_active_steps
+                    and self._device.type == "cuda"
+                ):
+                    torch.cuda.memory._record_memory_history(enabled=None)
+
+                # Step profiler
+                # Note that this is called within gradient accumulation block, hence
+                # will include multiple forward / backward passes if gradient accumulation > 1
+                self._profiler.step()
+
+            self.epochs_run += 1
+            self.save_checkpoint(epoch=curr_epoch)
+
+        self._profiler.stop()
 
     def save_checkpoint(self, **kwargs) -> None:
         """
@@ -287,3 +347,23 @@ def cleanup(self, **kwargs) -> None:
        Any cleaning up needed for the recipe.
        """
        return
+
+
+@config.parse
+def recipe_main(cfg: DictConfig) -> None:
+    """
+    Entry point for the recipe.
+
+    Configurable parameters are read in the following order:
+        - Parameters specified in config (see available configs through ``tune ls``)
+        - Overwritten by arguments from the command-line
+    """
+    config.log_config(recipe_name="QATRecipeSingleDevice", cfg=cfg)
+    recipe = QATRecipeSingleDevice(cfg=cfg)
+    recipe.setup(cfg=cfg)
+    recipe.train()
+    recipe.cleanup()
+
+
+if __name__ == "__main__":
+    sys.exit(recipe_main())
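A note on the clipping change in the train hunk: `.full_tensor()` is a DTensor method used when gradients are sharded across devices, so on a single device the norm returned by `clip_grad_norm_` is already a plain tensor, which is presumably why the call is dropped here. A minimal single-device sketch with a toy model (not the recipe's model or config):

import torch

# Toy stand-in for self._model; the recipe's model and max_norm come from its config.
model = torch.nn.Linear(8, 8)
loss = model(torch.randn(4, 8)).sum()
loss.backward()

# On a single device the returned norm is an ordinary tensor, so it can be
# logged directly; no .full_tensor() (a DTensor method) is needed.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(grad_norm))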

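The logging block normalizes the running loss by the tokens seen since the last weight update and reports throughput as tokens per wall-clock second. A standalone sketch of that arithmetic, using made-up values for the recipe's `running_loss`, `num_tokens`, and `t0` counters, and a plain param-group read as a simplified stand-in for `get_lr`:

import time

import torch

# Made-up per-step counters; in the recipe these accumulate between optimizer steps.
running_loss = torch.tensor(2.345)
num_tokens = 4096
t0 = time.perf_counter() - 1.5  # pretend the step took roughly 1.5 seconds

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)

time_per_step = time.perf_counter() - t0
log_dict = {
    # Loss is averaged over the tokens processed since the last weight update.
    "loss": running_loss.detach().item() / num_tokens,
    # Simplified stand-in for get_lr(): read the LR off the (single) param group.
    "lr": optimizer.param_groups[0]["lr"],
    # Throughput since the last update, matching tokens_per_second_per_gpu above.
    "tokens_per_second_per_gpu": num_tokens / time_per_step,
}
print(log_dict)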
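`torch.cuda.memory._record_memory_history(enabled=None)` only stops a recording; the recipe presumably starts it in its profiler setup, which is outside this diff. A sketch of the full start/dump/stop pattern around this private PyTorch API, independent of the recipe:

import torch

if torch.cuda.is_available():
    # Start recording allocator events (private API, subject to change across versions).
    torch.cuda.memory._record_memory_history(max_entries=100_000)

    # Some CUDA work whose allocations should be captured.
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x

    # Dump a snapshot that can be inspected with the pytorch.org/memory_viz tool.
    torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

    # Stop recording; this is the same enabled=None call used in the diff above.
    torch.cuda.memory._record_memory_history(enabled=None)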
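`@config.parse` builds the `DictConfig` handed to `recipe_main` from a YAML config plus command-line overrides, in the order the docstring describes. For exercising config handling without the `tune` CLI, an equivalent object can be assembled directly with OmegaConf; the field names below are hypothetical placeholders, not the recipe's actual schema:

from omegaconf import OmegaConf

# Hypothetical, minimal config; the real recipe expects the full QAT schema
# (model, tokenizer, dataset, optimizer, ...), so this is illustrative only.
cfg = OmegaConf.create(
    {
        "seed": 0,
        "epochs": 1,
        "gradient_accumulation_steps": 1,
    }
)

# Dotlist overrides merge on top of the base config, mirroring
# "overwritten by arguments from the command-line" in the docstring above.
overrides = OmegaConf.from_dotlist(["epochs=3", "seed=1234"])
cfg = OmegaConf.merge(cfg, overrides)
print(OmegaConf.to_yaml(cfg))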