@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import enum
 import functools
 import os
@@ -12,7 +13,7 @@
 import shutil
 import threading
 import time
 
-from typing import Any
+from typing import Any, Generator
 
 import torch
 import torch.distributed as dist
@@ -55,9 +56,24 @@ def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.cache_state_dict = {
             k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
         }
+        self._unshard_state_dict = False
+
+    @contextlib.contextmanager
+    def unshard_state_dict(self) -> Generator[None, None, None]:
+        self._unshard_state_dict = True
+        try:
+            yield
+        finally:
+            self._unshard_state_dict = False
 
     def state_dict(self) -> dict[str, Any]:
-        return self.cache_state_dict
+        if self._unshard_state_dict:
+            func = functools.partial(
+                get_model_state_dict, options=StateDictOptions(full_state_dict=True)
+            )
+            return {k: v for sd in map(func, self.model) for k, v in sd.items()}
+        else:
+            return self.cache_state_dict
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         func = functools.partial(
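The toggle exists because `state_dict()` is also called by machinery that expects the cached, possibly sharded view; only the export path should pay the cost of gathering. A minimal sketch of the intended call pattern, assuming `wrapper` is an instance of the model-wrapper class edited above:

```python
# `wrapper` is assumed to be an instance of the wrapper class above,
# constructed from the training model(s).
sharded_sd = wrapper.state_dict()  # default: cached (possibly sharded) view

with wrapper.unshard_state_dict():
    # Inside the context, state_dict() regathers parameters via
    # get_model_state_dict(..., StateDictOptions(full_state_dict=True)),
    # yielding plain, unsharded tensors suitable for torch.save().
    full_sd = wrapper.state_dict()
```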
@@ -288,6 +304,11 @@ def load_state_dict(state_dict):
         self.purge_thread = None
 
         self.model_weights_only = ckpt_config.model_weights_only
+        self.unshard_weights = ckpt_config.unshard_weights
+        if self.unshard_weights and not self.model_weights_only:
+            raise ValueError(
+                "unshard_weights is only supported for model_weights_only=True"
+            )
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
         self.exclude_from_loading = ckpt_config.exclude_from_loading
 
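The guard above encodes a simple contract: unsharded export is only defined for the final weights-only checkpoint, since a full training checkpoint (optimizer, dataloader, etc.) must stay in sharded form. A small, self-contained sketch of that contract, using `SimpleNamespace` as a stand-in for the real checkpoint config object:

```python
from types import SimpleNamespace

def validate(ckpt_config) -> None:
    # Mirrors the constructor check added above.
    if ckpt_config.unshard_weights and not ckpt_config.model_weights_only:
        raise ValueError(
            "unshard_weights is only supported for model_weights_only=True"
        )

# Valid: unsharded export implies a weights-only checkpoint.
validate(SimpleNamespace(model_weights_only=True, unshard_weights=True))

# Invalid: rejected when the checkpoint manager is constructed.
try:
    validate(SimpleNamespace(model_weights_only=False, unshard_weights=True))
except ValueError as e:
    print(e)
```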
@@ -405,6 +426,7 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        # TODO: add support for loading the checkpoint from the HF checkpoint.
         if self.ft_manager:
             self._ft_load()
 
@@ -545,37 +567,55 @@ def _states_to_load(self, step: int) -> dict[str, Any]:
         states_to_load.pop(DATALOADER)
         return states_to_load
 
+    def _export_weights(self, curr_step: int) -> None:
+        # We update self.states to keep the model only.
+        # After this update, self.states = {
+        #      'tok_embeddings.weight':...,
+        #      'layers.0.attention.wq.weight': ...
+        # }.
+        context = (
+            self.states[MODEL].unshard_state_dict()
+            if self.unshard_weights
+            else contextlib.nullcontext()
+        )
+        with context:
+            self.states = self.states[MODEL].state_dict()
+
+        # For now, we will manually pop the freqs_cis buffer, as we made this permanent
+        # temporarily and we don't want to include it in the exported state_dict.
+        # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+        self.states.pop("freqs_cis", None)
+
+        if self.export_dtype != torch.float32:
+            # TODO: Ensure FP8 tensor is converted back to the plain torch.Tensor.
+            self.states = {k: v.to(self.export_dtype) for k, v in self.states.items()}
+        logger.info(
+            f"Saving a model weights only checkpoint in {self.export_dtype} "
+            f"at last step, step {curr_step}."
+        )
+        checkpoint_id = self._create_checkpoint_id(curr_step)
+        if self.unshard_weights:
+            # TODO: support HF format
+            os.makedirs(checkpoint_id, exist_ok=True)
+            torch.save(self.states, os.path.join(checkpoint_id, "model_weights.pt"))
+        else:
+            save_with_gc(
+                self.states, checkpoint_id=self._create_checkpoint_id(curr_step)
+            )
+
     def _save_last_step(self, curr_step: int) -> None:
         # We only consider saving weights only at the end of the training. So
         # this won't affect preemption and training resume. We also only allow
         # dtype conversion when we are checkpoint model weights only and the
         # current dtype is not the same as the export dtype at the end of the training.
 
         if self.model_weights_only:
-            # We update self.states to keep the model only.
-            # After this update, self.states = {
-            #      'tok_embeddings.weight':...,
-            #      'layers.0.attention.wq.weight': ...
-            # }.
-            self.states = self.states[MODEL].state_dict()
-
-            # For now, we will manually pop the freqs_cis buffer, as we made this permanent
-            # temporarily and we don't want to include it in the exported state_dict.
-            # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-            self.states.pop("freqs_cis", None)
-
-            if self.export_dtype != torch.float32:
-                self.states = {
-                    k: v.to(self.export_dtype) for k, v in self.states.items()
-                }
-            logger.info(
-                f"Saving a model weights only checkpoint in {self.export_dtype} "
-                f"at last step, step {curr_step}."
-            )
+            self._export_weights(curr_step)
         else:
             logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
-
-        save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
+            save_with_gc(
+                self.states, checkpoint_id=self._create_checkpoint_id(curr_step)
+            )
 
     def _should_save(self, curr_step: int, force: bool = False) -> bool:
         if not self.enable_checkpoint:
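Since the unsharded path writes a single `model_weights.pt` via plain `torch.save` instead of a sharded DCP directory, the artifact can be reloaded without any distributed setup. A minimal consumption sketch; the checkpoint path and `build_model` are illustrative stand-ins, not part of this change:

```python
import os

import torch
import torch.nn as nn

checkpoint_dir = "outputs/checkpoint/step-1000"  # illustrative path
state_dict = torch.load(
    os.path.join(checkpoint_dir, "model_weights.pt"),
    map_location="cpu",
    weights_only=True,
)

model: nn.Module = build_model()  # stand-in for constructing the trained model
# `freqs_cis` was popped before export, so tolerate the missing buffer
# if the model still registers it.
model.load_state_dict(state_dict, strict=False)
```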