@@ -13,7 +13,7 @@

 import torch
 import torchtune.modules.common_utils as common_utils
-from omegaconf import DictConfig, ListConfig
+from omegaconf import DictConfig, ListConfig, OmegaConf

 from torch import nn
 from torch.optim import Optimizer
@@ -24,14 +24,16 @@
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
-    get_adapter_state_dict,
     get_lora_module_names,
-    get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
 from tqdm import tqdm

@@ -140,36 +142,21 @@ def __init__(self, cfg: DictConfig) -> None:
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._kd_ratio = cfg.get("kd_ratio", 0.5)

-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. This includes the
-        base model weights. If resume_from_checkpoint is True, this also includes
-        the adapter weights and recipe state
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            should_load_recipe_state=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        if self._resume_from_checkpoint:
-            if training.ADAPTER_KEY not in checkpoint_dict:
-                raise ValueError(
-                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
-                )
-            # _update_recipe_state will throw an exception if the recipe state is not correctly loaded,
-            # so there is no need to check here
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
+        self._checkpoint_client = CheckpointClient(cfg)
+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)

     def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         Extract the teacher checkpoint state from file.
         """
-        teacher_checkpointer = config.instantiate(
-            cfg_checkpointer,
+        # Add the checkpointer node to a new config so it works with CheckpointClient
+        checkpointer_dict = {"checkpointer": cfg_checkpointer}
+
+        new_cfg_checkpointer = OmegaConf.create(checkpointer_dict)
+        teacher_checkpoint_client = CheckpointClient(
+            new_cfg_checkpointer,
         )
-        checkpoint_dict = teacher_checkpointer.load_checkpoint()
+        checkpoint_dict = teacher_checkpoint_client.load_base_checkpoint()
         return checkpoint_dict

     def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
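`CheckpointClient` reads its checkpointer settings from a `checkpointer` node of the recipe config, which is why the teacher checkpointer config is re-wrapped before the teacher client is built. Below is a minimal sketch of that wrapping step using plain OmegaConf; the component path, directories, and file names are illustrative, not taken from the PR.

```python
from omegaconf import OmegaConf

# Hypothetical teacher checkpointer node, as it might appear under
# `teacher_checkpointer:` in a recipe YAML. Paths and file names are illustrative.
cfg_checkpointer = OmegaConf.create(
    {
        "_component_": "torchtune.training.FullModelHFCheckpointer",
        "checkpoint_dir": "/tmp/teacher",
        "checkpoint_files": ["model-00001-of-00001.safetensors"],
        "output_dir": "/tmp/teacher-out",
        "model_type": "LLAMA3",
    }
)

# CheckpointClient expects a recipe-level config with a `checkpointer` node,
# so the teacher node is re-wrapped before the client is constructed.
new_cfg_checkpointer = OmegaConf.create({"checkpointer": cfg_checkpointer})
assert "checkpointer" in new_cfg_checkpointer
```

The student path needs no such wrapping because the recipe's own `checkpointer` node already sits at the top level of the `cfg` passed to `CheckpointClient(cfg)` in `__init__`.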
@@ -220,7 +207,7 @@ def setup(self, cfg: DictConfig) -> None:

         self._metric_logger = config.instantiate(cfg.metric_logger)

-        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

         # log config with parameter override
         self._metric_logger.log_config(cfg)
@@ -230,7 +217,7 @@ def setup(self, cfg: DictConfig) -> None:
            raise ValueError(
                "NPU does not support model compilation. Please set `compile: False` in the config."
            )
-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+
        teacher_checkpoint_dict = self.load_teacher_checkpoint(
            cfg_checkpointer=cfg.teacher_checkpointer
        )
@@ -245,7 +232,7 @@ def setup(self, cfg: DictConfig) -> None:
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
                 checkpoint_dict[training.ADAPTER_KEY]
-                if self._resume_from_checkpoint
+                if training.ADAPTER_KEY in checkpoint_dict
                 else None
             ),
         )
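The condition guarding `lora_weights_state_dict` now keys off the checkpoint contents rather than the resume flag, so saved adapter weights are applied whenever they are present. A tiny self-contained illustration of the same pattern follows; the string values of the key constants are assumed stand-ins for torchtune's `training.MODEL_KEY` and `training.ADAPTER_KEY`.

```python
# Assumed stand-ins for torchtune's training.MODEL_KEY / training.ADAPTER_KEY.
MODEL_KEY = "model"
ADAPTER_KEY = "adapter"

checkpoint_dict = {MODEL_KEY: {"weight": ...}, ADAPTER_KEY: {"lora_a": ...}}  # toy checkpoint

# New behaviour: load adapter weights whenever the checkpoint contains them,
# regardless of whether the run is resuming.
lora_weights_state_dict = (
    checkpoint_dict[ADAPTER_KEY] if ADAPTER_KEY in checkpoint_dict else None
)

# Equivalent spelling with dict.get, which returns None when the key is absent.
assert lora_weights_state_dict == checkpoint_dict.get(ADAPTER_KEY)
```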
@@ -255,6 +242,27 @@ def setup(self, cfg: DictConfig) -> None:
             model_state_dict=teacher_checkpoint_dict[training.MODEL_KEY],
         )

+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
+                    self._model,
+                    self._optimizer,
+                    self._adapter_config,
+                    self._save_adapter_weights_only,
+                )
+
+            if training.ADAPTER_KEY not in checkpoint_dict:
+                raise ValueError(
+                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+                )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         self._tokenizer = config.instantiate(cfg.tokenizer)
         self._logger.info("Tokenizer is initialized from file.")

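The resume branch distinguishes two cases: with async checkpointing enabled, the intermediate state was written by the DistributedCheckpointer and has to be pulled back through the checkpoint client; otherwise the adapter weights must already be in the base checkpoint dict. The sketch below mirrors that control flow as a free function so the ordering is easier to see. `restore_training_state` is a hypothetical helper, the argument objects are placeholders, and `ADAPTER_KEY` is an assumed stand-in for `training.ADAPTER_KEY`.

```python
from typing import Any, Callable, Dict

ADAPTER_KEY = "adapter"  # assumed stand-in for torchtune's training.ADAPTER_KEY


def restore_training_state(
    checkpoint_dict: Dict[str, Any],
    *,
    resume_from_checkpoint: bool,
    enable_async_checkpointing: bool,
    checkpoint_client: Any,            # CheckpointClient in the recipe
    model: Any,
    optimizer: Any,
    adapter_config: Dict[str, Any],
    save_adapter_weights_only: bool,
    update_recipe_state: Callable[[Dict[str, Any]], None],
) -> Dict[str, Any]:
    """Mirror the recipe's resume branch as a standalone function."""
    if not resume_from_checkpoint:
        return checkpoint_dict

    if enable_async_checkpointing:
        # Async intermediate checkpoints are written by the DistributedCheckpointer,
        # so the saved state has to come back through the checkpoint client.
        checkpoint_dict = checkpoint_client.load_distributed_checkpoint(
            model,
            optimizer,
            adapter_config,
            save_adapter_weights_only,
        )

    if ADAPTER_KEY not in checkpoint_dict:
        raise ValueError(
            "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
        )

    update_recipe_state(checkpoint_dict)
    return checkpoint_dict
```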
@@ -294,11 +302,6 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_dataset=cfg.dataset,
             batch_size=cfg.batch_size,
             shuffle=cfg.shuffle,
-            dataloader_state_dict=(
-                ckpt_dict[training.DATALOADER_KEY]
-                if self._resume_from_checkpoint
-                else None
-            ),
         )

         # Finally update the recipe state which can only be correctly set after all of the
@@ -378,6 +381,17 @@ def _setup_model(
         self._lora_attn_modules = list(cfg_model.lora_attn_modules)
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
+        self._adapter_config = {
+            "r": self._lora_rank,
+            "lora_alpha": self._lora_alpha,
+            "target_modules": get_lora_module_names(
+                self._lora_attn_modules,
+                self._apply_lora_to_mlp,
+                self._apply_lora_to_output,
+            ),
+            "peft_type": "LORA",
+        }
+
         self.adapter_params = get_adapter_params(model)
         self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
         set_trainable_params(model, self.adapter_params)
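Building `self._adapter_config` during model setup makes the PEFT-style metadata available to both `save_checkpoint` and the distributed-checkpoint resume path, instead of reconstructing it at save time. A standalone sketch of the dict it produces is shown below; the rank, alpha, and module values are illustrative, not defaults from the recipe, while the `get_lora_module_names` call shape follows the diff.

```python
from torchtune.modules.peft import get_lora_module_names

# Illustrative LoRA settings; the recipe reads these from cfg_model.
lora_rank = 8
lora_alpha = 16
lora_attn_modules = ["q_proj", "v_proj"]
apply_lora_to_mlp = True
apply_lora_to_output = False

adapter_config = {
    "r": lora_rank,
    "lora_alpha": lora_alpha,
    "target_modules": get_lora_module_names(
        lora_attn_modules,
        apply_lora_to_mlp,
        apply_lora_to_output,
    ),
    "peft_type": "LORA",
}
# target_modules lists the attention projections plus whatever MLP/output
# modules the two flags imply.
print(adapter_config)
```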
@@ -492,7 +506,6 @@ def _setup_data(
         cfg_dataset: DictConfig,
         shuffle: bool,
         batch_size: int,
-        dataloader_state_dict: Optional[Dict[str, Any]] = None,
     ) -> StatefulDataLoader:
         """
         All data related setup happens here. This recipe currently supports only
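With the `dataloader_state_dict` parameter removed from `_setup_data`, dataloader progress is restored later through the recipe state rather than at construction time, relying on `StatefulDataLoader`'s own `state_dict`/`load_state_dict` pair. A toy round trip of that mechanism, assuming `torchdata` is installed; the list dataset is a stand-in for what the recipe builds from `cfg.dataset`.

```python
from torchdata.stateful_dataloader import StatefulDataLoader

# Toy dataset; the real recipe builds this from cfg.dataset and the tokenizer.
data = list(range(10))

loader = StatefulDataLoader(data, batch_size=2)
batches = iter(loader)
next(batches)                  # consume one batch mid-epoch

state = loader.state_dict()    # what the recipe stores under the dataloader key

# A freshly built loader can pick up from the saved position.
resumed = StatefulDataLoader(data, batch_size=2)
resumed.load_state_dict(state)
```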
@@ -537,63 +550,20 @@ def _setup_data(
         return dataloader

     def save_checkpoint(self, epoch: int) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Merged weights with key MODEL_KEY
-        - Adapter weights with key ADAPTER_KEY
-        - Relevant recipe state if training is not complete
-        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-        To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
-        """
-        ckpt_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-        # if training is in-progress, checkpoint the optimizer state as well
-        if intermediate_checkpoint:
-            ckpt_dict.update(
-                {
-                    training.OPT_KEY: self._optimizer.state_dict(),
-                    training.SEED_KEY: self.seed,
-                    training.EPOCHS_KEY: self.epochs_run,
-                    training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                    training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                    training.DATALOADER_KEY: self._dataloader.state_dict(),
-                }
-            )
-
-        # Move to CPU to avoid a copy on GPU
-        state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
-
-        # Construct the full state dict with LoRA weights merged into base LLM weights
-        merged_state_dict = get_merged_lora_ckpt(
-            state_dict,
-            rank=self._lora_rank,
-            alpha=self._lora_alpha,
-        )
-        ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
-
-        # Construct the adapter weights
-        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
-        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-        adapter_config = {
-            "r": self._lora_rank,
-            "lora_alpha": self._lora_alpha,
-            "target_modules": get_lora_module_names(
-                self._lora_attn_modules,
-                self._apply_lora_to_mlp,
-                self._apply_lora_to_output,
+        self._checkpoint_client.save_checkpoint(
+            model=self._model,
+            optimizer=self._optimizer,
+            training_progress=TrainingProgress(
+                seed=self.seed,
+                epochs_run=self.epochs_run,
+                total_epochs=self.total_epochs,
+                max_steps_per_epoch=self.max_steps_per_epoch,
+                dataloader_state_dict=self._dataloader.state_dict(),
             ),
-            "peft_type": "LORA",
-        }
-        ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
-
-        self._checkpointer.save_checkpoint(
-            ckpt_dict,
             epoch=epoch,
-            intermediate_checkpoint=intermediate_checkpoint,
+            adapter_config=self._adapter_config.copy(),
             adapter_only=self._save_adapter_weights_only,
+            single_device=True,
         )

     def _loss_step(
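All of the merging and adapter bookkeeping that `save_checkpoint` previously did by hand is now delegated to `CheckpointClient.save_checkpoint`, with the resumable counters bundled into a `TrainingProgress` value. The sketch below shows how a recipe-like caller could assemble that call; `save_epoch_checkpoint` and its parameters are hypothetical, while the imports and keyword names follow the diff.

```python
from torchtune.training.checkpointing._checkpoint_client import (
    CheckpointClient,
    TrainingProgress,
)


def save_epoch_checkpoint(
    checkpoint_client: CheckpointClient,
    model,
    optimizer,
    dataloader,
    adapter_config: dict,
    *,
    seed: int,
    epochs_run: int,
    total_epochs: int,
    max_steps_per_epoch,
    epoch: int,
    adapter_only: bool,
) -> None:
    """Bundle recipe state into TrainingProgress and hand everything to the client."""
    progress = TrainingProgress(
        seed=seed,
        epochs_run=epochs_run,
        total_epochs=total_epochs,
        max_steps_per_epoch=max_steps_per_epoch,
        dataloader_state_dict=dataloader.state_dict(),
    )
    checkpoint_client.save_checkpoint(
        model=model,
        optimizer=optimizer,
        training_progress=progress,
        epoch=epoch,
        adapter_config=dict(adapter_config),  # pass a copy, as the recipe does
        adapter_only=adapter_only,
        single_device=True,
    )
```

The client decides internally whether an intermediate or final checkpoint is being written, which is why the old `intermediate_checkpoint` flag disappears from the recipe.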