@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
+import sys
 import time

 from typing import Any, Dict  # List, Optional, Union
@@ -12,10 +12,12 @@
 import torch
 from omegaconf import DictConfig  # ListConfig

-from torchtune import training, utils  # config, modules,
+from torchtune import config, training, utils  # modules
 from torchtune.modules.loss import SFTLoss
 from torchtune.recipe_interfaces import FTRecipeInterface

+from torchtune.training.lr_schedulers import get_lr
+
 from tqdm import tqdm


@@ -268,11 +270,69 @@ def train(self, **kwargs) -> None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
                             max_norm=float(self._clip_grad_norm),
-                        ).full_tensor()
+                        )
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)

-        return
+                    # Update the number of steps when the weights are updated
+                    self.global_step += 1
+
+                    loss_to_log = running_loss.detach().item() / num_tokens
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
+                    if self.global_step % self._log_every_n_steps == 0:
+                        time_per_step = time.perf_counter() - t0
+                        log_dict = {
+                            "loss": loss_to_log,
+                            "lr": get_lr(
+                                (
+                                    self._optimizer
+                                    if not self._optimizer_in_bwd
+                                    else self._optim_ckpt_wrapper
+                                ),
+                            ),
+                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                        }
+                        if self._device.type != "cpu" and self._log_peak_memory_stats:
+                            log_dict.update(
+                                training.get_memory_stats(device=self._device)
+                            )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
+                        self._metric_logger.log_dict(
+                            log_dict,
+                            step=self.global_step,
+                        )
+
+                    # Reset running stats for the next step
+                    running_loss = 0
+                    num_tokens = 0
+                    t0 = time.perf_counter()
+
+                    # Stop tracking CUDA memory now that active steps are complete
+                    if (
+                        curr_epoch == 0
+                        and self.profiler_profile_memory
+                        and idx
+                        == self.profiler_wait_steps
+                        + self.profiler_warmup_steps
+                        + self.profiler_active_steps
+                        and self._device.type == "cuda"
+                    ):
+                        torch.cuda.memory._record_memory_history(enabled=None)
+
+                    # Step profiler
+                    # Note that this is called within gradient accumulation block, hence
+                    # will include multiple forward / backward passes if gradient accumulation > 1
+                    self._profiler.step()
+
+            self.epochs_run += 1
+            self.save_checkpoint(epoch=curr_epoch)
+
+        self._profiler.stop()

     def save_checkpoint(self, **kwargs) -> None:
         """
@@ -287,3 +347,23 @@ def cleanup(self, **kwargs) -> None:
         Any cleaning up needed for the recipe.
         """
         return
+
+
+@config.parse
+def recipe_main(cfg: DictConfig) -> None:
+    """
+    Entry point for the recipe.
+
+    Configurable parameters are read in the following order:
+        - Parameters specified in config (see available configs through ``tune ls``)
+        - Overwritten by arguments from the command-line
+    """
+    config.log_config(recipe_name="QATRecipeSingleDevice", cfg=cfg)
+    recipe = QATRecipeSingleDevice(cfg=cfg)
+    recipe.setup(cfg=cfg)
+    recipe.train()
+    recipe.cleanup()
+
+
+if __name__ == "__main__":
+    sys.exit(recipe_main())
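
A minimal usage sketch, assuming the recipe is registered with torchtune's `tune` CLI under the name `qat_single_device` and that a QAT config exists for the target model (both names are assumptions, not confirmed by this diff). Because `recipe_main` is decorated with `@config.parse`, values from the YAML config are read first and any `key=value` arguments on the command line override them, as the docstring above describes:

    tune run qat_single_device --config <qat-config-name-or-path> epochs=1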