@@ -12,7 +12,7 @@
 from warnings import warn

 import torch
- from omegaconf import DictConfig, ListConfig
+ from omegaconf import DictConfig, ListConfig, OmegaConf

 from torch import nn
 from torch.distributed import destroy_process_group, init_process_group
@@ -26,14 +26,16 @@
 from torchtune.modules.peft import (
     AdapterModule,
     get_adapter_params,
-     get_adapter_state_dict,
     get_lora_module_names,
-     get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
+ from torchtune.training.checkpointing._checkpoint_client import (
+     CheckpointClient,
+     TrainingProgress,
+ )

 from tqdm import tqdm

@@ -136,6 +138,9 @@ def __init__(self, cfg: DictConfig) -> None:
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

         # training attributes
+         self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
+         self._checkpoint_client = CheckpointClient(cfg)
+
         self._enable_activation_checkpointing = cfg.enable_activation_checkpointing

         # These are public properties which are updated by the checkpoint loader
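These two new attributes are read straight from the recipe config. A minimal sketch of the relevant config fields, built with OmegaConf purely for illustration (a real run sets them in the recipe's YAML config; the values below are placeholders):

```python
from omegaconf import OmegaConf

# Placeholder recipe config fragment -- only the fields the lines above read.
cfg = OmegaConf.create(
    {
        "enable_async_checkpointing": True,   # defaults to False when the key is absent
        "log_peak_memory_stats": False,
        "enable_activation_checkpointing": True,
    }
)
assert cfg.get("enable_async_checkpointing", False) is True
```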
@@ -153,36 +158,18 @@ def __init__(self, cfg: DictConfig) -> None:
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._kd_ratio = cfg.get("kd_ratio", 0.5)

-     def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
-         """
-         Extract the checkpoint state from file and validate. This includes the
-         base model weights. If resume_from_checkpoint is True, this also includes
-         the adapter weights and recipe state
-         """
-         self._checkpointer = config.instantiate(
-             cfg_checkpointer,
-             should_load_recipe_state=self._resume_from_checkpoint,
-         )
-         checkpoint_dict = self._checkpointer.load_checkpoint()
-
-         if self._resume_from_checkpoint:
-             if training.ADAPTER_KEY not in checkpoint_dict:
-                 raise ValueError(
-                     "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
-                 )
-             # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded
-             # no need to check here
-             self._update_recipe_state(checkpoint_dict)
-         return checkpoint_dict
-
     def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         Extract the teacher checkpoint state from file.
         """
-         teacher_checkpointer = config.instantiate(
-             cfg_checkpointer,
+         # add checkpointer class to config to work with checkpoint_client
+         checkpointer_dict = {"checkpointer": cfg_checkpointer}
+
+         new_cfg_checkpointer = OmegaConf.create(checkpointer_dict)
+         teacher_checkpoint_client = CheckpointClient(
+             new_cfg_checkpointer,
         )
-         checkpoint_dict = teacher_checkpointer.load_checkpoint()
+         checkpoint_dict = teacher_checkpoint_client.load_base_checkpoint()
         return checkpoint_dict

     def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
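The hunk above wraps the bare teacher checkpointer config under a top-level `checkpointer` key before handing it to `CheckpointClient`. A minimal sketch of that wrapping step, with a placeholder `_component_` value (the teacher checkpointer fields are not spelled out in this change):

```python
from omegaconf import OmegaConf

# Placeholder teacher checkpointer config; a real recipe defines this under
# `teacher_checkpointer:` with whatever fields its checkpointer needs.
teacher_ckpt_cfg = OmegaConf.create(
    {"_component_": "torchtune.training.FullModelHFCheckpointer"}
)

# As the comment in the hunk implies, CheckpointClient reads its checkpointer
# settings from cfg.checkpointer, so the teacher config is nested under that key.
wrapped_cfg = OmegaConf.create({"checkpointer": teacher_ckpt_cfg})
assert "checkpointer" in wrapped_cfg
# teacher_client = CheckpointClient(wrapped_cfg)
# checkpoint_dict = teacher_client.load_base_checkpoint()
```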
@@ -237,7 +224,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._metric_logger.log_config(cfg)

         self._compile = cfg.get("compile", False)
-         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+         checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
         teacher_checkpoint_dict = self.load_teacher_checkpoint(
             cfg_checkpointer=cfg.teacher_checkpointer
         )
@@ -251,7 +238,7 @@ def setup(self, cfg: DictConfig) -> None:
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
                 checkpoint_dict[training.ADAPTER_KEY]
-                 if self._resume_from_checkpoint
+                 if training.ADAPTER_KEY in checkpoint_dict
                 else None
             ),
         )
@@ -271,11 +258,31 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                 if self._resume_from_checkpoint
+                 if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )

+         if self._resume_from_checkpoint:
+             # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+             # using the DistributedCheckpointer.
+             # Therefore the recipe needs to load the distributed checkpoint to restore the training
+             # progress.
+             if self._enable_async_checkpointing:
+                 checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
+                     self._model,
+                     self._optimizer,
+                     self._adapter_config,
+                 )
+
+             if training.ADAPTER_KEY not in checkpoint_dict:
+                 raise ValueError(
+                     "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+                 )
+
+             # Update the recipe state from the checkpoint state dict.
+             self._update_recipe_state(checkpoint_dict)
+
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
         self._kd_loss_fn = config.instantiate(cfg.kd_loss)
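The `_update_recipe_state` call in this hunk consumes the same recipe-state keys that the old save path wrote (see the removed `save_checkpoint` body further down). An illustrative shape of a resumable checkpoint dict under those keys; the values below are placeholders, not real state:

```python
from torchtune import training

resumable_checkpoint = {
    training.MODEL_KEY: {},        # base model weights
    training.ADAPTER_KEY: {},      # LoRA adapter weights; missing -> ValueError above
    training.OPT_KEY: {},          # optimizer state (intermediate checkpoints only)
    training.SEED_KEY: 1234,
    training.EPOCHS_KEY: 1,
    training.TOTAL_EPOCHS_KEY: 3,
    training.MAX_STEPS_KEY: None,
    training.DATALOADER_KEY: {},
}
```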
@@ -432,6 +439,16 @@ def _setup_model(
         self._lora_attn_modules = list(cfg_model.lora_attn_modules)
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
+         self._adapter_config = {
+             "r": self._lora_rank,
+             "lora_alpha": self._lora_alpha,
+             "target_modules": get_lora_module_names(
+                 self._lora_attn_modules,
+                 self._apply_lora_to_mlp,
+                 self._apply_lora_to_output,
+             ),
+             "peft_type": "LORA",
+         }

         utils.log_rank_zero(
             log,
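`_adapter_config` is a PEFT-style description of the LoRA setup that now travels with every checkpoint instead of being rebuilt at save time. An illustrative example of what it might hold for a typical run; the `target_modules` entries below are common torchtune LoRA module names, not values taken from this change:

```python
# Illustrative only -- actual contents depend on the model config and on what
# get_lora_module_names returns for it.
adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj", "output_proj", "w1", "w2", "w3"],
    "peft_type": "LORA",
}
```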
@@ -697,87 +714,21 @@ def _setup_data(
         return dataloader

     def save_checkpoint(self, epoch: int) -> None:
-         """
-         Checkpoint the state of the recipe. The constructed checkpoint state dict
-         contains the following information:
-         - Merged weights with key MODEL_KEY
-         - Adapter weights with key ADAPTER_KEY
-         - Relevant recipe state if training is not complete
-         - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-         To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
-         """
-         # final dict passed onto the checkpointer
-         checkpoint_dict = {}
-
-         intermediate_checkpoint = epoch + 1 < self.total_epochs
-         # To prevent GPU memory from spiking during checkpoint save,
-         # we consolidate the full model and optim state dicts on CPU for rank 0
-         cpu_state_dict = training.gather_cpu_state_dict(
-             self._model,
-             self._is_rank_zero,
-             device=self._device,
+         self._checkpoint_client.save_checkpoint(
+             model=self._model,
+             optimizer=self._optimizer,
+             training_progress=TrainingProgress(
+                 seed=self.seed,
+                 epochs_run=self.epochs_run,
+                 total_epochs=self.total_epochs,
+                 max_steps_per_epoch=self.max_steps_per_epoch,
+                 dataloader_state_dict=self._dataloader.state_dict(),
+             ),
+             epoch=epoch,
+             adapter_config=self._adapter_config.copy(),
+             adapter_only=self._save_adapter_weights_only,
         )

-         if intermediate_checkpoint:
-             opt_state_dict = training.get_full_optimizer_state_dict(
-                 self._model,
-                 self._optimizer,
-                 self._is_rank_zero,
-                 device=self._device,
-             )
-         else:
-             opt_state_dict = None
-
-         # Now that we have the model and opt state dict, create the actual checkpoint
-         # to be sent to the checkpointer and ultimately written to file
-         if self._is_rank_zero:
-
-             # Filter out the adapter keys and weights from the model state dict. These will
-             # be saved separately
-             adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
-             checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-
-             # merge the adapter weights and base weights to create the model checkpoint
-             merged_state_dict = get_merged_lora_ckpt(
-                 cpu_state_dict,
-                 rank=self._lora_rank,
-                 alpha=self._lora_alpha,
-             )
-             checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
-
-             # if training is in-progress, checkpoint the optimizer state and recipe state
-             # as well
-             if intermediate_checkpoint:
-                 checkpoint_dict.update(
-                     {
-                         training.OPT_KEY: opt_state_dict,
-                         training.SEED_KEY: self.seed,
-                         training.EPOCHS_KEY: self.epochs_run,
-                         training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                         training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                         training.DATALOADER_KEY: self._dataloader.state_dict(),
-                     }
-                 )
-
-             adapter_config = {
-                 "r": self._lora_rank,
-                 "lora_alpha": self._lora_alpha,
-                 "target_modules": get_lora_module_names(
-                     self._lora_attn_modules,
-                     self._apply_lora_to_mlp,
-                     self._apply_lora_to_output,
-                 ),
-                 "peft_type": "LORA",
-             }
-             checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
-             self._checkpointer.save_checkpoint(
-                 checkpoint_dict,
-                 epoch=epoch,
-                 intermediate_checkpoint=intermediate_checkpoint,
-                 adapter_only=self._save_adapter_weights_only,
-             )
-
     def _loss_step(
         self, batch: Dict[str, torch.Tensor]
     ) -> (torch.Tensor, torch.Tensor):
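With this change the recipe no longer gathers, merges, or writes state dicts itself: it bundles the recipe state into a `TrainingProgress` and hands everything to `CheckpointClient.save_checkpoint`, which, judging from the removed code, now covers adapter extraction, LoRA merging, and the async intermediate saves. A minimal sketch of that call surface with placeholder values; only the keyword names are taken from the diff:

```python
from torchtune.training.checkpointing._checkpoint_client import TrainingProgress

# The same five fields the recipe passes above; values here are placeholders.
progress = TrainingProgress(
    seed=1234,
    epochs_run=0,
    total_epochs=1,
    max_steps_per_epoch=None,
    dataloader_state_dict={},
)
# checkpoint_client.save_checkpoint(
#     model=model,
#     optimizer=optimizer,
#     training_progress=progress,
#     epoch=0,
#     adapter_config=adapter_config,
#     adapter_only=False,
# )
```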