
Commit c40ec99

squashed commit
1 parent: 4bc5af2

10 files changed: +688 -603 lines changed

recipes/knowledge_distillation_distributed.py

Lines changed: 61 additions & 110 deletions
@@ -12,7 +12,7 @@
 from warnings import warn
 
 import torch
-from omegaconf import DictConfig, ListConfig
+from omegaconf import DictConfig, ListConfig, OmegaConf
 
 from torch import nn
 from torch.distributed import destroy_process_group, init_process_group
@@ -26,14 +26,16 @@
 from torchtune.modules.peft import (
     AdapterModule,
     get_adapter_params,
-    get_adapter_state_dict,
     get_lora_module_names,
-    get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
 
 from tqdm import tqdm
 
@@ -136,6 +138,9 @@ def __init__(self, cfg: DictConfig) -> None:
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
 
         # training attributes
+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
+        self._checkpoint_client = CheckpointClient(cfg)
+
         self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
 
         # These are public properties which are updated by the checkpoint loader
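The client is built from the whole recipe config, so asynchronous intermediate checkpointing is a config-level switch. Below is a minimal sketch of the relevant config slice and construction; the checkpointer component, paths, and model type are illustrative placeholders rather than part of this commit, and the client may read additional recipe fields not shown here.

# Hedged sketch: where `enable_async_checkpointing` and the checkpointer
# section sit in a recipe config. Values below are placeholders.
from omegaconf import OmegaConf

from torchtune.training.checkpointing._checkpoint_client import CheckpointClient

cfg = OmegaConf.create(
    {
        "checkpointer": {
            "_component_": "torchtune.training.FullModelHFCheckpointer",  # placeholder component
            "checkpoint_dir": "/tmp/student_model",         # placeholder path
            "checkpoint_files": ["model.safetensors"],      # placeholder files
            "model_type": "LLAMA3",                         # placeholder model type
            "output_dir": "/tmp/kd_output",
        },
        "enable_async_checkpointing": True,  # opt in to async intermediate saves
    }
)

# As in __init__ above: one client constructed from the full recipe config.
checkpoint_client = CheckpointClient(cfg)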
@@ -153,36 +158,18 @@ def __init__(self, cfg: DictConfig) -> None:
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._kd_ratio = cfg.get("kd_ratio", 0.5)
 
-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. This includes the
-        base model weights. If resume_from_checkpoint is True, this also includes
-        the adapter weights and recipe state
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            should_load_recipe_state=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        if self._resume_from_checkpoint:
-            if training.ADAPTER_KEY not in checkpoint_dict:
-                raise ValueError(
-                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
-                )
-            # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded
-            # no need to check here
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
-
     def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
         Extract the teacher checkpoint state from file.
         """
-        teacher_checkpointer = config.instantiate(
-            cfg_checkpointer,
+        # add checkpointer class to config to work with checkpoint_client
+        checkpointer_dict = {"checkpointer": cfg_checkpointer}
+
+        new_cfg_checkpointer = OmegaConf.create(checkpointer_dict)
+        teacher_checkpoint_client = CheckpointClient(
+            new_cfg_checkpointer,
         )
-        checkpoint_dict = teacher_checkpointer.load_checkpoint()
+        checkpoint_dict = teacher_checkpoint_client.load_base_checkpoint()
         return checkpoint_dict
 
     def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
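CheckpointClient expects the checkpointer entry nested under a `checkpointer` key in the config it receives, which is why the bare `teacher_checkpointer` section is re-wrapped with OmegaConf before a second client is built for the teacher. The same pattern in isolation, as a hedged sketch (the function name is illustrative):

from typing import Any, Dict

from omegaconf import DictConfig, OmegaConf

from torchtune.training.checkpointing._checkpoint_client import CheckpointClient


def load_weights_via_client(cfg_checkpointer: DictConfig) -> Dict[str, Any]:
    # Re-nest a standalone checkpointer config under the `checkpointer` key
    # that CheckpointClient looks for, then load only the base model weights.
    wrapped = OmegaConf.create({"checkpointer": cfg_checkpointer})
    client = CheckpointClient(wrapped)
    return client.load_base_checkpoint()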
@@ -237,7 +224,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._metric_logger.log_config(cfg)
 
         self._compile = cfg.get("compile", False)
-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
         teacher_checkpoint_dict = self.load_teacher_checkpoint(
             cfg_checkpointer=cfg.teacher_checkpointer
         )
@@ -251,7 +238,7 @@ def setup(self, cfg: DictConfig) -> None:
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
                 checkpoint_dict[training.ADAPTER_KEY]
-                if self._resume_from_checkpoint
+                if training.ADAPTER_KEY in checkpoint_dict
                 else None
             ),
         )
@@ -271,11 +258,31 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                if self._resume_from_checkpoint
+                if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )
 
+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
+                    self._model,
+                    self._optimizer,
+                    self._adapter_config,
+                )
+
+            if training.ADAPTER_KEY not in checkpoint_dict:
+                raise ValueError(
+                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+                )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
         self._kd_loss_fn = config.instantiate(cfg.kd_loss)
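Resume now has two paths: with async checkpointing enabled, the recipe state lives in a distributed checkpoint that is restored into the already-instantiated model and optimizer; otherwise the state dict loaded earlier by load_base_checkpoint() already carries the adapter weights and recipe state. A condensed, hedged sketch of that branch (attribute and method names mirror the diff; `recipe` stands in for the recipe instance):

from typing import Any, Dict

from torchtune import training


def restore_recipe_state(recipe, checkpoint_dict: Dict[str, Any]) -> None:
    if recipe._enable_async_checkpointing:
        # Intermediate checkpoints were written asynchronously via the
        # DistributedCheckpointer, so training progress is restored from the
        # distributed checkpoint into the live model and optimizer.
        checkpoint_dict = recipe._checkpoint_client.load_distributed_checkpoint(
            recipe._model,
            recipe._optimizer,
            recipe._adapter_config,
        )

    if training.ADAPTER_KEY not in checkpoint_dict:
        raise ValueError(
            "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
        )

    # Seed, epoch counters, and dataloader state come back through
    # _update_recipe_state, same as before this change.
    recipe._update_recipe_state(checkpoint_dict)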
@@ -432,6 +439,16 @@ def _setup_model(
         self._lora_attn_modules = list(cfg_model.lora_attn_modules)
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
+        self._adapter_config = {
+            "r": self._lora_rank,
+            "lora_alpha": self._lora_alpha,
+            "target_modules": get_lora_module_names(
+                self._lora_attn_modules,
+                self._apply_lora_to_mlp,
+                self._apply_lora_to_output,
+            ),
+            "peft_type": "LORA",
+        }
 
         utils.log_rank_zero(
             log,
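For a sense of what this metadata holds: assuming get_lora_module_names expands the attention-module list with the MLP projections when apply_lora_to_mlp is set (its behavior on current torchtune main), a typical student config yields a dict like the hedged illustration below. All values are hypothetical.

# Hedged illustration of self._adapter_config for a hypothetical student with
# lora_attn_modules=["q_proj", "v_proj"], apply_lora_to_mlp=True,
# apply_lora_to_output=False, rank 8, alpha 16. Assumes get_lora_module_names
# adds "w1"/"w2"/"w3" for the MLP, as on torchtune main.
adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj", "w1", "w2", "w3"],
    "peft_type": "LORA",
}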
@@ -697,87 +714,21 @@ def _setup_data(
         return dataloader
 
     def save_checkpoint(self, epoch: int) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Merged weights with key MODEL_KEY
-        - Adapter weights with key ADAPTER_KEY
-        - Relevant recipe state if training is not complete
-        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-        To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
-        """
-        # final dict passed onto the checkpointer
-        checkpoint_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-        # To prevent GPU memory from spiking during checkpoint save,
-        # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.gather_cpu_state_dict(
-            self._model,
-            self._is_rank_zero,
-            device=self._device,
+        self._checkpoint_client.save_checkpoint(
+            model=self._model,
+            optimizer=self._optimizer,
+            training_progress=TrainingProgress(
+                seed=self.seed,
+                epochs_run=self.epochs_run,
+                total_epochs=self.total_epochs,
+                max_steps_per_epoch=self.max_steps_per_epoch,
+                dataloader_state_dict=self._dataloader.state_dict(),
+            ),
+            epoch=epoch,
+            adapter_config=self._adapter_config.copy(),
+            adapter_only=self._save_adapter_weights_only,
         )
 
-        if intermediate_checkpoint:
-            opt_state_dict = training.get_full_optimizer_state_dict(
-                self._model,
-                self._optimizer,
-                self._is_rank_zero,
-                device=self._device,
-            )
-        else:
-            opt_state_dict = None
-
-        # Now that we have the model and opt state dict, create the actual checkpoint
-        # to be sent to the checkpointer and ultimately written to file
-        if self._is_rank_zero:
-
-            # Filter out the adapter keys and weights from the model state dict. These will
-            # be saved separately
-            adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-
-            # merge the adapter weights and base weights to create the model checkpoint
-            merged_state_dict = get_merged_lora_ckpt(
-                cpu_state_dict,
-                rank=self._lora_rank,
-                alpha=self._lora_alpha,
-            )
-            checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
-
-            # if training is in-progress, checkpoint the optimizer state and recipe state
-            # as well
-            if intermediate_checkpoint:
-                checkpoint_dict.update(
-                    {
-                        training.OPT_KEY: opt_state_dict,
-                        training.SEED_KEY: self.seed,
-                        training.EPOCHS_KEY: self.epochs_run,
-                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                        training.DATALOADER_KEY: self._dataloader.state_dict(),
-                    }
-                )
-
-            adapter_config = {
-                "r": self._lora_rank,
-                "lora_alpha": self._lora_alpha,
-                "target_modules": get_lora_module_names(
-                    self._lora_attn_modules,
-                    self._apply_lora_to_mlp,
-                    self._apply_lora_to_output,
-                ),
-                "peft_type": "LORA",
-            }
-            checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
-            self._checkpointer.save_checkpoint(
-                checkpoint_dict,
-                epoch=epoch,
-                intermediate_checkpoint=intermediate_checkpoint,
-                adapter_only=self._save_adapter_weights_only,
-            )
-
     def _loss_step(
         self, batch: Dict[str, torch.Tensor]
     ) -> (torch.Tensor, torch.Tensor):
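The net effect is that the recipe no longer gathers, merges, and writes state dicts itself; everything is delegated to CheckpointClient, with the recipe state that used to be stored under individual training.* keys now packaged in a TrainingProgress object. A hedged sketch of the new call shape in isolation (`recipe` stands in for the recipe instance):

from torchtune.training.checkpointing._checkpoint_client import TrainingProgress


def save(recipe, epoch: int) -> None:
    # Hedged sketch: the recipe-state keys formerly written by hand
    # (SEED_KEY, EPOCHS_KEY, TOTAL_EPOCHS_KEY, MAX_STEPS_KEY, DATALOADER_KEY)
    # now travel inside TrainingProgress; gathering and LoRA merging are
    # handled by the client.
    recipe._checkpoint_client.save_checkpoint(
        model=recipe._model,
        optimizer=recipe._optimizer,
        training_progress=TrainingProgress(
            seed=recipe.seed,                                # was training.SEED_KEY
            epochs_run=recipe.epochs_run,                    # was training.EPOCHS_KEY
            total_epochs=recipe.total_epochs,                # was training.TOTAL_EPOCHS_KEY
            max_steps_per_epoch=recipe.max_steps_per_epoch,  # was training.MAX_STEPS_KEY
            dataloader_state_dict=recipe._dataloader.state_dict(),  # was training.DATALOADER_KEY
        ),
        epoch=epoch,
        adapter_config=recipe._adapter_config.copy(),  # LoRA metadata built in _setup_model
        adapter_only=recipe._save_adapter_weights_only,
    )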
