Can we support outputting checkpoints directly in .pt format? #1177
Comments
It is reasonable to remove the FP8 subclass from checkpointing. I'll submit a PR for this. I may need some help from the AO team to discuss how to remove the FP8 subclass. cc @vkuzo @danielvegamyhre
Mind me asking why you would like .pt as the output format? E.g., is it because some downstream workload has to consume the .pt format but not DCP?
Hi @tianyu-l, yes exactly. Our use case is interop with torchtune, which accepts .pt or .safetensors. Actually, between these two, .safetensors would be more useful, but .pt is also fine.
I believe we can support both formats. The issue is how we remove the FP8 tensor subclass.
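For reference, once the state_dict contains only plain tensors, both target formats are straightforward to write; a minimal sketch using `torch.save` and the `safetensors` library (the state_dict contents and file names here are placeholders):

```python
import torch
from safetensors.torch import save_file

# Assume the FP8 subclass wrappers have already been stripped,
# so this is a flat dict of plain tensors.
state_dict = {"weight": torch.randn(8, 8)}

torch.save(state_dict, "model.pt")            # .pt (pickle-based)
save_file(state_dict, "model.safetensors")    # .safetensors
```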
Sorry for the late reply, catching up after my leave. Is there a hook of some sort that torchtitan calls when saving a state dict? The logic could go there.
@vkuzo I can draft a PR to enable saving the state_dict to `.pt` format.
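A minimal sketch of what such a save path could look like, not the actual torchtitan implementation (`save_pt` is a hypothetical helper): gather any DTensor shards into full tensors, move them to CPU, and `torch.save` on rank 0.

```python
import torch
import torch.distributed as dist

def save_pt(model: torch.nn.Module, path: str) -> None:
    # Build a plain, unsharded state_dict suitable for torch.save.
    full = {}
    for name, tensor in model.state_dict().items():
        if hasattr(tensor, "full_tensor"):
            # DTensor: gather the shards into a single full tensor.
            # This is a collective, so every rank must reach it.
            tensor = tensor.full_tensor()
        full[name] = tensor.detach().cpu()
    # Only one rank needs to write the file.
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(full, path)
```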
Summary: Several users have been asking for this feature: #1177

TODO: Remove fp8 subclass tensor
TODO: Support HF format

Test Plan:
```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --parallelism.tensor_parallel_degree 4 --parallelism.enable_async_tensor_parallel --checkpoint.model_weights_only --checkpoint.unshard_weights --checkpoint.export_dtype="bfloat16" --training.steps=10 --checkpoint.enable_checkpoint
```
@vkuzo @danielvegamyhre @andrewor14 Please see the TODO in the code of #1219. We just need to convert the FP8 tensor to a regular tensor there.
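A minimal sketch of what that conversion could look like, assuming the FP8 wrapper exposes its inner plain tensor as `_tensor` (an assumption about torchao internals, not confirmed here):

```python
import torch

def to_plain(value):
    """Best-effort: unwrap a tensor subclass into a plain torch.Tensor."""
    if isinstance(value, torch.Tensor) and type(value) is not torch.Tensor:
        inner = getattr(value, "_tensor", None)  # assumed wrapper attribute
        if isinstance(inner, torch.Tensor):
            return inner.detach().clone()
    return value

def strip_subclasses(state_dict: dict) -> dict:
    return {key: to_plain(value) for key, value in state_dict.items()}
```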
Today we need to do an extra conversion step, per this README: https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md

I think we should instead provide an option for users to specify which format to output their checkpoints in, and have torchtitan call this conversion function for users as part of outputting the checkpoint.
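For concreteness, the extra step referred to is the DCP-to-torch.save conversion, which looks roughly like this (the checkpoint paths are placeholders):

```python
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Convert a DCP checkpoint directory into a single torch.save .pt file.
dcp_to_torch_save("outputs/checkpoint/step-1000", "checkpoint.pt")
```

The same conversion is also exposed on the command line via `python -m torch.distributed.checkpoint.format_utils dcp_to_torch <input_dir> <output_file>`.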
Bonus: this conversion step actually fails today if we used FP8 training. I had to manually add a line to the `dcp_to_torch` function as a hack to get it to work (see the sketch below). It would be great if we could either implicitly add the safe globals when we output the checkpoint in torchtitan, or simply remove `WeightWithDynamicFloat8CastTensor` from the BC surface.
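Given the mention of safe globals, the hack was presumably an allowlist registration along these lines; a sketch, assuming the class lives at this torchao path:

```python
import torch
# Assumed import path for the FP8 wrapper subclass in torchao.
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

# Allowlist the subclass so torch.load(..., weights_only=True) can
# deserialize checkpoints that reference it.
torch.serialization.add_safe_globals([WeightWithDynamicFloat8CastTensor])
```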