
Commit c260cd3

Implement the feature to save unsharded weights at the last step
Summary: Several users have been asking for this feature: #1177

TODO: Remove fp8 subclass tensor
TODO: Support HF format

Test Plan:
```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --parallelism.tensor_parallel_degree 4 --parallelism.enable_async_tensor_parallel --checkpoint.model_weights_only --checkpoint.unshard_weights --checkpoint.export_dtype="bfloat16" --training.steps=10 --checkpoint.enable_checkpoint
```
1 parent c1e796b commit c260cd3

File tree: 2 files changed (+75, -24 lines)


torchtitan/components/checkpoint.py

Lines changed: 64 additions & 24 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import enum
 import functools
 import os
@@ -12,7 +13,7 @@
 import shutil
 import threading
 import time
-from typing import Any
+from typing import Any, Generator
 
 import torch
 import torch.distributed as dist
@@ -55,9 +56,24 @@ def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.cache_state_dict = {
             k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
         }
+        self._unshard_state_dict = False
+
+    @contextlib.contextmanager
+    def unshard_state_dict(self) -> Generator[None, None, None]:
+        self._unshard_state_dict = True
+        try:
+            yield
+        finally:
+            self._unshard_state_dict = False
 
     def state_dict(self) -> dict[str, Any]:
-        return self.cache_state_dict
+        if self._unshard_state_dict:
+            func = functools.partial(
+                get_model_state_dict, options=StateDictOptions(full_state_dict=True)
+            )
+            return {k: v for sd in map(func, self.model) for k, v in sd.items()}
+        else:
+            return self.cache_state_dict
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         func = functools.partial(
@@ -288,6 +304,11 @@ def load_state_dict(state_dict):
         self.purge_thread = None
 
         self.model_weights_only = ckpt_config.model_weights_only
+        self.unshard_weights = ckpt_config.unshard_weights
+        if self.unshard_weights and not self.model_weights_only:
+            raise ValueError(
+                "unshard_weights is only supported for model_weights_only=True"
+            )
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
         self.exclude_from_loading = ckpt_config.exclude_from_loading
 
@@ -405,6 +426,7 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        # TODO: add support for loading the checkpoint from the HF checkpoint.
         if self.ft_manager:
             self._ft_load()
 
@@ -545,37 +567,55 @@ def _states_to_load(self, step: int) -> dict[str, Any]:
             states_to_load.pop(DATALOADER)
         return states_to_load
 
+    def _export_weights(self, curr_step: int) -> None:
+        # We update self.states to keep the model only.
+        # After this update, self.states = {
+        #     'tok_embeddings.weight':...,
+        #     'layers.0.attention.wq.weight': ...
+        # }.
+        context = (
+            self.states[MODEL].unshard_state_dict()
+            if self.unshard_weights
+            else contextlib.nullcontext()
+        )
+        with context:
+            self.states = self.states[MODEL].state_dict()
+
+        # For now, we will manually pop the freqs_cis buffer, as we made this permanent
+        # temporarily and we don't want to include it in the exported state_dict.
+        # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+        self.states.pop("freqs_cis", None)
+
+        if self.export_dtype != torch.float32:
+            # TODO: Ensure FP8 tensor is converted back to the plain torch.Tensor.
+            self.states = {k: v.to(self.export_dtype) for k, v in self.states.items()}
+        logger.info(
+            f"Saving a model weights only checkpoint in {self.export_dtype} "
+            f"at last step, step {curr_step}."
+        )
+        checkpoint_id = self._create_checkpoint_id(curr_step)
+        if self.unshard_weights:
+            # TODO: support HF format
+            os.makedirs(checkpoint_id, exist_ok=True)
+            torch.save(self.states, os.path.join(checkpoint_id, "model_weights.pt"))
+        else:
+            save_with_gc(
+                self.states, checkpoint_id=self._create_checkpoint_id(curr_step)
+            )
+
     def _save_last_step(self, curr_step: int) -> None:
         # We only consider saving weights only at the end of the training. So
         # this won't affect preemption and training resume. We also only allow
        # dtype conversion when we are checkpoint model weights only and the
         # current dtype is not the same as the export dtype at the end of the training.
 
         if self.model_weights_only:
-            # We update self.states to keep the model only.
-            # After this update, self.states = {
-            #     'tok_embeddings.weight':...,
-            #     'layers.0.attention.wq.weight': ...
-            # }.
-            self.states = self.states[MODEL].state_dict()
-
-            # For now, we will manually pop the freqs_cis buffer, as we made this permanent
-            # temporarily and we don't want to include it in the exported state_dict.
-            # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-            self.states.pop("freqs_cis", None)
-
-            if self.export_dtype != torch.float32:
-                self.states = {
-                    k: v.to(self.export_dtype) for k, v in self.states.items()
-                }
-            logger.info(
-                f"Saving a model weights only checkpoint in {self.export_dtype} "
-                f"at last step, step {curr_step}."
-            )
+            self._export_weights(curr_step)
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
-
-        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+            save_with_gc(
+                self.states, checkpoint_id=self._create_checkpoint_id(curr_step)
+            )
 
     def _should_save(self, curr_step: int, force: bool = False) -> bool:
         if not self.enable_checkpoint:
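
For reference, a minimal loading sketch (not part of this commit): it assumes the ``model_weights.pt`` file written by ``_export_weights`` above; the helper name and example path are purely illustrative.

```python
import torch
import torch.nn as nn


def load_exported_weights(model: nn.Module, path: str) -> nn.Module:
    """Load the plain-tensor state dict written by the unsharded last-step export.

    ``path`` points at the ``model_weights.pt`` file inside the checkpoint folder,
    e.g. "<dump_folder>/checkpoint/step-10/model_weights.pt" (placeholder path).
    """
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    # freqs_cis was popped before export, so tolerate missing buffers here;
    # load_state_dict copies values in place, so a bfloat16 export can be
    # loaded into a float32 model.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    assert not unexpected, f"unexpected keys in exported checkpoint: {unexpected}"
    return model
```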

torchtitan/config_manager.py

Lines changed: 11 additions & 0 deletions
@@ -365,6 +365,17 @@ class Checkpoint:
     The default value is false.
     """
 
+    unshard_weights: bool = False
+    """
+    Whether to unshard the weights before saving the checkpoint. If the option is set to True,
+    the weights will be unsharded as plain torch.Tensor before saving the checkpoint. Moreover,
+    since the weights are unsharded (full weights), ``torch.save()`` will be used instead of
+    ``DCP.save()``. Note that only rank0 will save the checkpoint, so the saving can be slow.
+    This option can only be set when ``model_weights_only`` is set to True.
+    If ``model_weights_only`` is set to False, a ``ValueError`` will be raised.
+    The default value is False.
+    """
+
     export_dtype: Literal["float16", "bfloat16", "float32"] = "float32"
     """
     Converts to the specified precision when training completes and model_weights_only=true.
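
To make the note above about ``torch.save()`` versus ``DCP.save()`` concrete, here is a standalone sketch of the unshard-then-save-on-rank-0 pattern this option enables (assumptions: an initialized process group and a sharded ``model``; the function and file names are illustrative, not torchtitan APIs):

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)


def save_unsharded_weights(model: nn.Module, checkpoint_dir: str) -> None:
    # Gather a full (unsharded) state dict: sharded/DTensor parameters are
    # collected into plain torch.Tensor copies, offloaded to CPU.
    full_sd = get_model_state_dict(
        model,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True),
    )
    # Every rank participates in the gather, but only rank 0 writes the file,
    # which is why this path can be slower than the sharded DCP.save() path.
    if dist.get_rank() == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(full_sd, os.path.join(checkpoint_dir, "model_weights.pt"))
    dist.barrier()
```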
