Can we support outputting checkpoints directly in .pt format? #1177

Open
andrewor14 opened this issue May 9, 2025 · 8 comments
Labels: enhancement, module: checkpoint

Comments

@andrewor14

andrewor14 commented May 9, 2025

Today we need to do an extra conversion step according to this README: https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md

python -m torch.distributed.checkpoint.format_utils dcp_to_torch outputs/checkpoint/step-100 /tmp/checkpoint.pt

I think we should instead let users choose the format their checkpoints are written in, and have torchtitan call this conversion for them as part of outputting the checkpoint.
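
A minimal sketch of what such an option could look like. The function name `export_checkpoint` and the `fmt` values are hypothetical and not part of torchtitan; the only library call is `dcp_to_torch_save` from `torch.distributed.checkpoint.format_utils`, which, as far as I can tell, is what the `dcp_to_torch` command-line mode above invokes.

```python
from pathlib import Path

# Hypothetical option values; torchtitan does not define these names.
SUPPORTED_FORMATS = ("dcp", "pt")


def export_checkpoint(dcp_dir: str, out_path: str, fmt: str = "dcp") -> Path:
    """Return the location of the checkpoint in the requested format."""
    if fmt not in SUPPORTED_FORMATS:
        raise ValueError(f"unsupported checkpoint format: {fmt!r}")
    if fmt == "dcp":
        # Nothing to convert: the DCP directory is already the output.
        return Path(dcp_dir)
    # Imported lazily so the sketch can be imported without torch installed.
    from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

    # Gathers the sharded DCP checkpoint into a single torch.save file.
    dcp_to_torch_save(dcp_dir, out_path)
    return Path(out_path)
```

With this in place, `export_checkpoint("outputs/checkpoint/step-100", "/tmp/checkpoint.pt", fmt="pt")` would stand in for the manual command.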


Bonus: This conversion step currently fails if FP8 training was used. I had to manually add the following line to the dcp_to_torch function as a hack to get it to work:

torch.serialization.add_safe_globals([torchao.float8.fsdp_utils.WeightWithDynamicFloat8CastTensor])

It would be great if we could either implicitly add the safe globals when torchtitan outputs the checkpoint, or simply remove WeightWithDynamicFloat8CastTensor from the BC surface.
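
A sketch of the first option, under the assumption that the registration would live in a small helper. The name `register_fp8_safe_globals` is hypothetical; `torch.serialization.add_safe_globals` and the class path are taken from the line above, and the helper needs a PyTorch release recent enough to provide `add_safe_globals`.

```python
def register_fp8_safe_globals() -> bool:
    """Allowlist the float8 weight wrapper for weights_only loading.

    Returns True if the class was registered, and False if torch, torchao,
    or the required attributes are not available in this environment.
    """
    try:
        import torch
        from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

        # Same call as the manual hack, applied before any conversion runs.
        torch.serialization.add_safe_globals([WeightWithDynamicFloat8CastTensor])
    except (ImportError, AttributeError):
        return False
    return True
```

Calling this once before the conversion would replace the manual edit described above.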

@andrewor14
Author

cc @vkuzo @fegin @wz337

@fegin
Contributor

fegin commented May 9, 2025

It is reasonable to remove the FP8 subclass from checkpointing. I'll submit a PR for this. I may need some help from the AO team to discuss how to remove the FP8 subclass. cc @vkuzo @danielvegamyhre

fegin self-assigned this May 9, 2025
fegin added the enhancement and module: checkpoint labels May 9, 2025
@tianyu-l
Contributor

Mind if I ask why you would like .pt as the output format? E.g., is it because some downstream workload has to consume the .pt format and cannot read DCP?

@andrewor14
Author

Hi @tianyu-l, yes exactly. Our use case is interop with torchtune, which accepts .pt or .safetensors. Of the two, .safetensors would be more useful, but .pt is also fine.

@fegin
Contributor

fegin commented May 13, 2025

I believe we can support both formats. The open issue is: how do we remove the FP8Tensor?

@vkuzo
Contributor

vkuzo commented May 23, 2025

Sorry for the late reply; I'm catching up after my leave.

> how do we remove the FP8Tensor.

Is there a hook of some sort that torchtitan calls when saving a state dict? The logic could go there.

@fegin
Contributor

fegin commented May 23, 2025

@vkuzo I can draft a PR to enable saving the state_dict to .pt. We don't need a hook for that. We can simply always convert the FP8 tensors to the dtype the user prefers.

fegin added a commit that referenced this issue May 23, 2025
Summary:
Several users have been asking for this feature:
#1177

TODO: Remove fp8 subclass tensor
TODO: Support HF format

Test Plan:
```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh  --training.compile --parallelism.tensor_parallel_degree 4 --parallelism.enable_async_tensor_parallel --checkpoint.model_weights_only --checkpoint.unshard_weights --checkpoint.export_dtype="bfloat16" --training.steps=10 --checkpoint.enable_checkpoint
```
@fegin
Contributor

fegin commented May 23, 2025

@vkuzo, @danielvegamyhre, @andrewor14 Please see the TODO in the code of #1219. We just need to convert the FP8 tensor to a regular tensor in _export_weights().
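
For illustration only, a rough sketch of what that conversion step might involve. This is not the code from #1219: the function name `export_weights` is hypothetical, and it assumes the FP8 wrapper subclass keeps the original high-precision weight in a `_tensor` attribute, which should be checked against torchao. It also ignores DTensor sharding.

```python
from typing import Any, Dict, Mapping


def export_weights(state_dict: Mapping[str, Any], export_dtype: Any) -> Dict[str, Any]:
    """Return a copy of state_dict with unwrapped, dtype-cast tensors."""
    exported: Dict[str, Any] = {}
    for name, value in state_dict.items():
        # ASSUMPTION: the wrapper subclass stores the wrapped tensor in
        # `_tensor`; objects without that attribute pass through unchanged.
        inner = getattr(value, "_tensor", value)
        # Cast only floating-point tensors; leave ints and non-tensors alone.
        if hasattr(inner, "is_floating_point") and inner.is_floating_point():
            inner = inner.to(export_dtype)
        exported[name] = inner
    return exported
```

Here `export_dtype` would be a `torch.dtype` such as `torch.bfloat16`, matching the `--checkpoint.export_dtype="bfloat16"` flag in the test plan above.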

4 participants