
[WIP] Implement the feature to save unsharded weights at the last step #1219


Open. Wants to merge 1 commit into base: main.
88 changes: 64 additions & 24 deletions torchtitan/components/checkpoint.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import enum
import functools
import os
@@ -12,7 +13,7 @@
import shutil
import threading
import time
from typing import Any
from typing import Any, Generator

import torch
import torch.distributed as dist
@@ -55,9 +56,24 @@ def __init__(self, model: nn.Module | list[nn.Module]) -> None:
self.cache_state_dict = {
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
}
self._unshard_state_dict = False

@contextlib.contextmanager
def unshard_state_dict(self) -> Generator[None, None, None]:
self._unshard_state_dict = True
try:
yield
finally:
self._unshard_state_dict = False

def state_dict(self) -> dict[str, Any]:
return self.cache_state_dict
if self._unshard_state_dict:
func = functools.partial(
get_model_state_dict, options=StateDictOptions(full_state_dict=True)
)
return {k: v for sd in map(func, self.model) for k, v in sd.items()}
else:
return self.cache_state_dict

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
func = functools.partial(
@@ -288,6 +304,11 @@ def load_state_dict(state_dict):
self.purge_thread = None

self.model_weights_only = ckpt_config.model_weights_only
self.unshard_weights = ckpt_config.unshard_weights
if self.unshard_weights and not self.model_weights_only:
raise ValueError(
"unshard_weights is only supported for model_weights_only=True"
)
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude_from_loading = ckpt_config.exclude_from_loading

@@ -405,6 +426,7 @@ def load(self, step: int = -1) -> bool:
bool: Whether the checkpoint was loaded successfully.
"""

# TODO: add support for loading the checkpoint from the HF checkpoint.
if self.ft_manager:
self._ft_load()

@@ -545,37 +567,55 @@ def _states_to_load(self, step: int) -> dict[str, Any]:
states_to_load.pop(DATALOADER)
return states_to_load

def _export_weights(self, curr_step: int) -> None:
# We update self.states to keep the model only.
# After this update, self.states = {
# 'tok_embeddings.weight':...,
# 'layers.0.attention.wq.weight': ...
# }.
context = (
self.states[MODEL].unshard_state_dict()
if self.unshard_weights
else contextlib.nullcontext()
)
with context:
self.states = self.states[MODEL].state_dict()

# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
self.states.pop("freqs_cis", None)

if self.export_dtype != torch.float32:
# TODO: Ensure FP8 tensor is converted back to the plain torch.Tensor.
self.states = {k: v.to(self.export_dtype) for k, v in self.states.items()}
logger.info(
f"Saving a model weights only checkpoint in {self.export_dtype} "
f"at last step, step {curr_step}."
)
checkpoint_id = self._create_checkpoint_id(curr_step)
if self.unshard_weights:
# TODO: support HF format
os.makedirs(checkpoint_id, exist_ok=True)
torch.save(self.states, os.path.join(checkpoint_id, "model_weights.pt"))
else:
save_with_gc(
self.states, checkpoint_id=self._create_checkpoint_id(curr_step)
)

def _save_last_step(self, curr_step: int) -> None:
# We only consider saving weights only at the end of the training. So
# this won't affect preemption and training resume. We also only allow
# dtype conversion when we are checkpoint model weights only and the
# current dtype is not the same as the export dtype at the end of the training.

if self.model_weights_only:
# We update self.states to keep the model only.
# After this update, self.states = {
# 'tok_embeddings.weight':...,
# 'layers.0.attention.wq.weight': ...
# }.
self.states = self.states[MODEL].state_dict()

# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
self.states.pop("freqs_cis", None)

if self.export_dtype != torch.float32:
self.states = {
k: v.to(self.export_dtype) for k, v in self.states.items()
}
logger.info(
f"Saving a model weights only checkpoint in {self.export_dtype} "
f"at last step, step {curr_step}."
)
self._export_weights(curr_step)
else:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
save_with_gc(
self.states, checkpoint_id=self._create_checkpoint_id(curr_step)
)

def _should_save(self, curr_step: int, force: bool = False) -> bool:
if not self.enable_checkpoint:
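Since the unsharded path writes a single model_weights.pt of plain tensors with ``torch.save()``, the result can be consumed without DCP. A minimal consumer-side sketch, with a hypothetical checkpoint directory standing in for whatever ``_create_checkpoint_id(curr_step)`` produces:

import os

import torch

# Hypothetical path; replace with the directory created at the last step.
checkpoint_id = "outputs/checkpoint/step-1000"

# The file holds plain torch.Tensors, so the safe weights-only loader is enough.
state_dict = torch.load(
    os.path.join(checkpoint_id, "model_weights.pt"),
    map_location="cpu",
    weights_only=True,
)

# Inspect what was exported: "freqs_cis" is absent because _export_weights pops it,
# and dtypes reflect export_dtype if it was set to something other than float32.
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape), tensor.dtype)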
11 changes: 11 additions & 0 deletions torchtitan/config_manager.py
@@ -365,6 +365,17 @@ class Checkpoint:
The default value is false.
"""

unshard_weights: bool = False
Contributor:

I know this is WIP, just random thoughts. unshard_weights doesn't seem to reflect a format change -- e.g., I would have thought that unshard_weights means a replicated DTensor but still in DCP format.

Contributor (Author):

The format is independent of the tensors being saved. We can call it save_with_pytorch_tensor or save_with_plain_tensor if that makes more sense to users. Both DCP and HF support saving plain PyTorch tensors.
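For context on the naming discussion, a minimal sketch of the two state-dict flavors, assuming ``model`` is an FSDP2/DTensor-sharded nn.Module (only the StateDictOptions usage mirrors the diff above; everything else is illustrative):

from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)

# Default: values are sharded DTensors; each rank holds only its shard.
# This is the form DCP.save() works with.
sharded_sd = get_model_state_dict(model)

# full_state_dict=True gathers the shards into plain torch.Tensors
# (cpu_offload=True is also available to keep them off the GPU), which is
# what the torch.save() path in this PR expects.
full_sd = get_model_state_dict(
    model,
    options=StateDictOptions(full_state_dict=True),
)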

"""
Whether to unshard the weights before saving the checkpoint. If the option is set to True,
the weights will be unsharded as plain torch.Tensor before saving the checkpoint. Moreover,
since the weights are unsharded (full weights), ``torch.save()`` will be used instead of
``DCP.save()``. Note that only rank0 will save the checkpoint so the saving can be slow.
This option can only be set when ``model_weights_only`` is set to True.
If ``model_weights_only`` is set to False, an assertion error will be raised.
The default value is False.
"""

export_dtype: Literal["float16", "bfloat16", "float32"] = "float32"
"""
Converts to the specified precision when training completes and model_weights_only=true.