diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 3c424bf24..42497d807 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import enum
 import functools
 import os
@@ -12,7 +13,7 @@ import shutil
 import threading
 import time
 
-from typing import Any
+from typing import Any, Generator
 
 import torch
 import torch.distributed as dist
@@ -55,9 +56,24 @@ def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.cache_state_dict = {
             k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
         }
+        self._unshard_state_dict = False
+
+    @contextlib.contextmanager
+    def unshard_state_dict(self) -> Generator[None, None, None]:
+        self._unshard_state_dict = True
+        try:
+            yield
+        finally:
+            self._unshard_state_dict = False
 
     def state_dict(self) -> dict[str, Any]:
-        return self.cache_state_dict
+        if self._unshard_state_dict:
+            func = functools.partial(
+                get_model_state_dict, options=StateDictOptions(full_state_dict=True)
+            )
+            return {k: v for sd in map(func, self.model) for k, v in sd.items()}
+        else:
+            return self.cache_state_dict
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         func = functools.partial(
@@ -288,6 +304,11 @@ def load_state_dict(state_dict):
             self.purge_thread = None
 
         self.model_weights_only = ckpt_config.model_weights_only
+        self.unshard_weights = ckpt_config.unshard_weights
+        if self.unshard_weights and not self.model_weights_only:
+            raise ValueError(
+                "unshard_weights is only supported for model_weights_only=True"
+            )
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
         self.exclude_from_loading = ckpt_config.exclude_from_loading
 
@@ -405,6 +426,7 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        # TODO: add support for loading the checkpoint from the HF checkpoint.
         if self.ft_manager:
             self._ft_load()
 
@@ -545,6 +567,42 @@ def _states_to_load(self, step: int) -> dict[str, Any]:
             states_to_load.pop(DATALOADER)
         return states_to_load
 
+    def _export_weights(self, curr_step: int) -> None:
+        # We update self.states to keep the model only.
+        # After this update, self.states = {
+        #      'tok_embeddings.weight':...,
+        #      'layers.0.attention.wq.weight': ...
+        # }.
+        context = (
+            self.states[MODEL].unshard_state_dict()
+            if self.unshard_weights
+            else contextlib.nullcontext()
+        )
+        with context:
+            self.states = self.states[MODEL].state_dict()
+
+        # For now, we will manually pop the freqs_cis buffer, as we made this permanent
+        # temporarily and we don't want to include it in the exported state_dict.
+        # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+        self.states.pop("freqs_cis", None)
+
+        if self.export_dtype != torch.float32:
+            # TODO: Ensure FP8 tensor is converted back to the plain torch.Tensor.
+            self.states = {k: v.to(self.export_dtype) for k, v in self.states.items()}
+        logger.info(
+            f"Saving a model weights only checkpoint in {self.export_dtype} "
+            f"at last step, step {curr_step}."
+        )
+        checkpoint_id = self._create_checkpoint_id(curr_step)
+        if self.unshard_weights:
+            # TODO: support HF format
+            os.makedirs(checkpoint_id, exist_ok=True)
+            torch.save(self.states, os.path.join(checkpoint_id, "model_weights.pt"))
+        else:
+            save_with_gc(
+                self.states, checkpoint_id=self._create_checkpoint_id(curr_step)
+            )
+
     def _save_last_step(self, curr_step: int) -> None:
         # We only consider saving weights only at the end of the training. So
         # this won't affect preemption and training resume. We also only allow
@@ -552,30 +610,12 @@ def _save_last_step(self, curr_step: int) -> None:
         # current dtype is not the same as the export dtype at the end of the training.
 
         if self.model_weights_only:
-            # We update self.states to keep the model only.
-            # After this update, self.states = {
-            #      'tok_embeddings.weight':...,
-            #      'layers.0.attention.wq.weight': ...
-            # }.
-            self.states = self.states[MODEL].state_dict()
-
-            # For now, we will manually pop the freqs_cis buffer, as we made this permanent
-            # temporarily and we don't want to include it in the exported state_dict.
-            # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-            self.states.pop("freqs_cis", None)
-
-            if self.export_dtype != torch.float32:
-                self.states = {
-                    k: v.to(self.export_dtype) for k, v in self.states.items()
-                }
-            logger.info(
-                f"Saving a model weights only checkpoint in {self.export_dtype} "
-                f"at last step, step {curr_step}."
-            )
+            self._export_weights(curr_step)
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
-
-        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        save_with_gc(
+            self.states, checkpoint_id=self._create_checkpoint_id(curr_step)
+        )
 
     def _should_save(self, curr_step: int, force: bool = False) -> bool:
         if not self.enable_checkpoint:
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 5a5fbb488..8fe78ff6d 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -365,6 +365,17 @@ class Checkpoint:
     The default value is false.
     """
 
+    unshard_weights: bool = False
+    """
+    Whether to unshard the weights before saving the checkpoint. If the option is set to True,
+    the weights will be unsharded as plain ``torch.Tensor`` before saving the checkpoint. Moreover,
+    since the weights are unsharded (full weights), ``torch.save()`` will be used instead of
+    ``DCP.save()``. Note that only rank0 will save the checkpoint, so the saving can be slow.
+    This option can only be set when ``model_weights_only`` is set to True.
+    If ``model_weights_only`` is set to False, a ``ValueError`` will be raised.
+    The default value is False.
+    """
+
     export_dtype: Literal["float16", "bfloat16", "float32"] = "float32"
     """
     Converts to the specified precision when training completes and model_weights_only=true.
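
For context, a minimal sketch of how the exported file could be consumed downstream, assuming a run with `model_weights_only = true` and `unshard_weights = true`; the `step-1000` folder name and the `build_model()` constructor are illustrative assumptions, not part of this diff:

```python
import torch

# Hypothetical path: _create_checkpoint_id() appends a per-step folder to the
# configured checkpoint directory; "step-1000" is only an example.
ckpt_path = "outputs/checkpoint/step-1000/model_weights.pt"

# The file is a plain dict[str, torch.Tensor] written by torch.save(), so
# torch.load() is sufficient; no DCP or process-group setup is needed.
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

# If export_dtype was bfloat16/float16, tensors keep that dtype; cast if needed.
model = build_model()  # hypothetical: construct the same architecture, unsharded
# freqs_cis was popped before export, so tolerate missing buffers.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing={missing}, unexpected={unexpected}")
```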