Skip to content

Commit 6b68abb

Browse files
authored
Add async checkpointing for kd distributed and single device recipes (#2726)
1 parent c8e670b commit 6b68abb

File tree

4 files changed

+259
-91
lines changed

4 files changed

+259
-91
lines changed

recipes/knowledge_distillation_single_device.py

+60-90
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515
import torchtune.modules.common_utils as common_utils
16-
from omegaconf import DictConfig, ListConfig
16+
from omegaconf import DictConfig, ListConfig, OmegaConf
1717

1818
from torch import nn
1919
from torch.optim import Optimizer
@@ -24,14 +24,16 @@
2424
from torchtune.datasets import ConcatDataset
2525
from torchtune.modules.peft import (
2626
get_adapter_params,
27-
get_adapter_state_dict,
2827
get_lora_module_names,
29-
get_merged_lora_ckpt,
3028
set_trainable_params,
3129
validate_missing_and_unexpected_for_lora,
3230
)
3331
from torchtune.recipe_interfaces import FTRecipeInterface
3432
from torchtune.training import DummyProfiler, PROFILER_KEY
33+
from torchtune.training.checkpointing._checkpoint_client import (
34+
CheckpointClient,
35+
TrainingProgress,
36+
)
3537
from tqdm import tqdm
3638

3739

@@ -140,36 +142,21 @@ def __init__(self, cfg: DictConfig) -> None:
140142
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
141143
self._kd_ratio = cfg.get("kd_ratio", 0.5)
142144

143-
def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
144-
"""
145-
Extract the checkpoint state from file and validate. This includes the
146-
base model weights. If resume_from_checkpoint is True, this also includes
147-
the adapter weights and recipe state
148-
"""
149-
self._checkpointer = config.instantiate(
150-
cfg_checkpointer,
151-
should_load_recipe_state=self._resume_from_checkpoint,
152-
)
153-
checkpoint_dict = self._checkpointer.load_checkpoint()
154-
155-
if self._resume_from_checkpoint:
156-
if training.ADAPTER_KEY not in checkpoint_dict:
157-
raise ValueError(
158-
"Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
159-
)
160-
# _update_recipe_state will throw an exception if the recipe state is not corrctly loaded
161-
# no need to check here
162-
self._update_recipe_state(checkpoint_dict)
163-
return checkpoint_dict
145+
self._checkpoint_client = CheckpointClient(cfg)
146+
self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
164147

165148
def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
166149
"""
167150
Extract the teacher checkpoint state from file.
168151
"""
169-
teacher_checkpointer = config.instantiate(
170-
cfg_checkpointer,
152+
# add checkpointer class to config to work with checkpoint_client
153+
checkpointer_dict = {"checkpointer": cfg_checkpointer}
154+
155+
new_cfg_checkpointer = OmegaConf.create(checkpointer_dict)
156+
teacher_checkpoint_client = CheckpointClient(
157+
new_cfg_checkpointer,
171158
)
172-
checkpoint_dict = teacher_checkpointer.load_checkpoint()
159+
checkpoint_dict = teacher_checkpoint_client.load_base_checkpoint()
173160
return checkpoint_dict
174161

175162
def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
@@ -220,7 +207,7 @@ def setup(self, cfg: DictConfig) -> None:
220207

221208
self._metric_logger = config.instantiate(cfg.metric_logger)
222209

223-
ckpt_dict = self.load_checkpoint(cfg.checkpointer)
210+
checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
224211

225212
# log config with parameter override
226213
self._metric_logger.log_config(cfg)
@@ -230,7 +217,7 @@ def setup(self, cfg: DictConfig) -> None:
230217
raise ValueError(
231218
"NPU does not support model compilation. Please set `compile: False` in the config."
232219
)
233-
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
220+
234221
teacher_checkpoint_dict = self.load_teacher_checkpoint(
235222
cfg_checkpointer=cfg.teacher_checkpointer
236223
)
@@ -245,7 +232,7 @@ def setup(self, cfg: DictConfig) -> None:
245232
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
246233
lora_weights_state_dict=(
247234
checkpoint_dict[training.ADAPTER_KEY]
248-
if self._resume_from_checkpoint
235+
if training.ADAPTER_KEY in checkpoint_dict
249236
else None
250237
),
251238
)
@@ -255,6 +242,27 @@ def setup(self, cfg: DictConfig) -> None:
255242
model_state_dict=teacher_checkpoint_dict[training.MODEL_KEY],
256243
)
257244

245+
if self._resume_from_checkpoint:
246+
# If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
247+
# using the DistributedCheckpointer.
248+
# Therefore the recipe needs to load the distributed checkpoint to restore the training
249+
# progress.
250+
if self._enable_async_checkpointing:
251+
checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
252+
self._model,
253+
self._optimizer,
254+
self._adapter_config,
255+
self._save_adapter_weights_only,
256+
)
257+
258+
if training.ADAPTER_KEY not in checkpoint_dict:
259+
raise ValueError(
260+
"Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
261+
)
262+
263+
# Update the recipe state from the checkpoint state dict.
264+
self._update_recipe_state(checkpoint_dict)
265+
258266
self._tokenizer = config.instantiate(cfg.tokenizer)
259267
self._logger.info("Tokenizer is initialized from file.")
260268

@@ -294,11 +302,6 @@ def setup(self, cfg: DictConfig) -> None:
294302
cfg_dataset=cfg.dataset,
295303
batch_size=cfg.batch_size,
296304
shuffle=cfg.shuffle,
297-
dataloader_state_dict=(
298-
ckpt_dict[training.DATALOADER_KEY]
299-
if self._resume_from_checkpoint
300-
else None
301-
),
302305
)
303306

304307
# Finally update the recipe state which can only be correctly set after all of the
@@ -378,6 +381,17 @@ def _setup_model(
378381
self._lora_attn_modules = list(cfg_model.lora_attn_modules)
379382
self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
380383
self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
384+
self._adapter_config = {
385+
"r": self._lora_rank,
386+
"lora_alpha": self._lora_alpha,
387+
"target_modules": get_lora_module_names(
388+
self._lora_attn_modules,
389+
self._apply_lora_to_mlp,
390+
self._apply_lora_to_output,
391+
),
392+
"peft_type": "LORA",
393+
}
394+
381395
self.adapter_params = get_adapter_params(model)
382396
self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
383397
set_trainable_params(model, self.adapter_params)
@@ -492,7 +506,6 @@ def _setup_data(
492506
cfg_dataset: DictConfig,
493507
shuffle: bool,
494508
batch_size: int,
495-
dataloader_state_dict: Optional[Dict[str, Any]] = None,
496509
) -> StatefulDataLoader:
497510
"""
498511
All data related setup happens here. This recipe currently supports only
@@ -537,63 +550,20 @@ def _setup_data(
537550
return dataloader
538551

539552
def save_checkpoint(self, epoch: int) -> None:
540-
"""
541-
Checkpoint the state of the recipe. The constructed checkpoint state dict
542-
contains the following information:
543-
- Merged weights with key MODEL_KEY
544-
- Adapter weights with key ADAPTER_KEY
545-
- Relevant recipe state if training is not complete
546-
- If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
547-
548-
To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
549-
"""
550-
ckpt_dict = {}
551-
552-
intermediate_checkpoint = epoch + 1 < self.total_epochs
553-
# if training is in-progress, checkpoint the optimizer state as well
554-
if intermediate_checkpoint:
555-
ckpt_dict.update(
556-
{
557-
training.OPT_KEY: self._optimizer.state_dict(),
558-
training.SEED_KEY: self.seed,
559-
training.EPOCHS_KEY: self.epochs_run,
560-
training.TOTAL_EPOCHS_KEY: self.total_epochs,
561-
training.MAX_STEPS_KEY: self.max_steps_per_epoch,
562-
training.DATALOADER_KEY: self._dataloader.state_dict(),
563-
}
564-
)
565-
566-
# Move to CPU to avoid a copy on GPU
567-
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
568-
569-
# Construct the full state dict with LoRA weights merged into base LLM weights
570-
merged_state_dict = get_merged_lora_ckpt(
571-
state_dict,
572-
rank=self._lora_rank,
573-
alpha=self._lora_alpha,
574-
)
575-
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
576-
577-
# Construct the adapter weights
578-
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
579-
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
580-
adapter_config = {
581-
"r": self._lora_rank,
582-
"lora_alpha": self._lora_alpha,
583-
"target_modules": get_lora_module_names(
584-
self._lora_attn_modules,
585-
self._apply_lora_to_mlp,
586-
self._apply_lora_to_output,
553+
self._checkpoint_client.save_checkpoint(
554+
model=self._model,
555+
optimizer=self._optimizer,
556+
training_progress=TrainingProgress(
557+
seed=self.seed,
558+
epochs_run=self.epochs_run,
559+
total_epochs=self.total_epochs,
560+
max_steps_per_epoch=self.max_steps_per_epoch,
561+
dataloader_state_dict=self._dataloader.state_dict(),
587562
),
588-
"peft_type": "LORA",
589-
}
590-
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
591-
592-
self._checkpointer.save_checkpoint(
593-
ckpt_dict,
594563
epoch=epoch,
595-
intermediate_checkpoint=intermediate_checkpoint,
564+
adapter_config=self._adapter_config.copy(),
596565
adapter_only=self._save_adapter_weights_only,
566+
single_device=True,
597567
)
598568

599569
def _loss_step(

tests/recipes/test_knowledge_distillation_distributed.py

+92
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,98 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
201201
loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
202202
)
203203

204+
@pytest.mark.integration_test
205+
@gpu_test(gpu_count=4)
206+
def test_training_state_on_resume_with_async_checkpointing(
207+
self, tmpdir, monkeypatch
208+
):
209+
"""Test whether the recipe state is correctly updated on resume. Since this
210+
is model agnostic, we should run this on the small model only. The test
211+
consists of three stages:
212+
- Train a model for 2 epochs
213+
- Resume training after epoch 1
214+
- Make sure final loss matches the expected value of a model successfully resumed from a ckpt
215+
"""
216+
217+
ckpt = "llama3_tune"
218+
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
219+
ckpt_dir = ckpt_path.parent
220+
log_file = gen_log_file_name(tmpdir)
221+
tokenizer_path = Path(TOKENIZER_PATHS["llama3"])
222+
223+
# Config file needed for model conversion.
224+
# Create a second copy for training resume
225+
write_hf_ckpt_config(ckpt_dir)
226+
write_hf_ckpt_config(tmpdir)
227+
228+
# Train for two epochs
229+
cmd_1 = f"""
230+
tune run --nnodes 1 --nproc_per_node 4 knowledge_distillation_distributed \
231+
--config llama3_2/8B_to_1B_KD_lora_distributed \
232+
output_dir={tmpdir} \
233+
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
234+
checkpointer.checkpoint_dir='{ckpt_dir}' \
235+
checkpointer.checkpoint_files=[{ckpt_path}]\
236+
checkpointer.output_dir={tmpdir} \
237+
teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
238+
teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \
239+
teacher_checkpointer.checkpoint_files=[{ckpt_path}] \
240+
teacher_checkpointer.output_dir={tmpdir} \
241+
enable_async_checkpointing=True \
242+
tokenizer.path={tokenizer_path} \
243+
tokenizer.prompt_template=null \
244+
""".split()
245+
246+
model_config = MODEL_TEST_CONFIGS["llama3_lora"]
247+
teacher_config = [
248+
"teacher_" + config for config in MODEL_TEST_CONFIGS["llama3"]
249+
]
250+
251+
cmd_1 = (
252+
cmd_1 + self._get_test_config_overrides() + model_config + teacher_config
253+
)
254+
monkeypatch.setattr(sys, "argv", cmd_1)
255+
runpy.run_path(TUNE_PATH, run_name="__main__")
256+
257+
# Resume training
258+
cmd_2 = f"""
259+
tune run --nnodes 1 --nproc_per_node 4 knowledge_distillation_distributed \
260+
--config llama3_2/8B_to_1B_KD_lora_distributed \
261+
output_dir={tmpdir} \
262+
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
263+
checkpointer.checkpoint_dir={ckpt_dir} \
264+
checkpointer.checkpoint_files=[{ckpt_path}]\
265+
checkpointer.output_dir={tmpdir} \
266+
teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
267+
teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \
268+
teacher_checkpointer.checkpoint_files=[{ckpt_path}] \
269+
teacher_checkpointer.output_dir={tmpdir} \
270+
resume_from_checkpoint=True \
271+
enable_async_checkpointing=True \
272+
metric_logger.filename={log_file} \
273+
tokenizer.path={tokenizer_path} \
274+
tokenizer.prompt_template=null \
275+
""".split()
276+
cmd_2 = (
277+
cmd_2
278+
+ self._get_test_config_overrides(epochs=3)
279+
+ model_config
280+
+ teacher_config
281+
)
282+
monkeypatch.setattr(sys, "argv", cmd_2)
283+
runpy.run_path(TUNE_PATH, run_name="__main__")
284+
285+
# Second epoch only
286+
expected_loss_values = self._fetch_expected_loss_values("llama3")[2:]
287+
loss_values = get_loss_values_from_metric_logger(log_file)
288+
# only take the first loss
289+
num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs
290+
loss_values = loss_values[0::num_losses][:2]
291+
292+
torch.testing.assert_close(
293+
loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
294+
)
295+
204296
@pytest.mark.integration_test
205297
@gpu_test(gpu_count=4)
206298
def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):

0 commit comments

Comments
 (0)