Skip to content

Feat: save_pretrained for tensor parallel (and other parallelisms) models #37919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 19, 2025

Conversation

S1ro1
Copy link
Member

@S1ro1 S1ro1 commented May 1, 2025

Save_pretrained that works on models that tensor parallelism was applied to. Works both on local tensors and Dtensors.

image

Memory snapshot dump timeline, meta-llama/Meta-Llama-3-8B-Instruct in fp32 on 2 GPUs. The spikes on top represent the saving, probably can't get any better than that. We can maybe warn users that they can specify a smaller shard size to avoid memory spikes.

Relies on a small fix in Huggingface_hub: huggingface/huggingface_hub#3042

EDIT: now also supports local_* tp plans. Was tested on saving full llama4 model and comparing correctness (not added into tests as the model is huge, will probably create tests for this when we support user defined tp_plans)

FIXES: #36436

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@S1ro1 S1ro1 requested a review from ArthurZucker May 2, 2025 11:19
@S1ro1 S1ro1 marked this pull request as ready for review May 2, 2025 11:49
@S1ro1 S1ro1 changed the title tmp: initial save pretrained with dtensors Feat: save_pretrained for model sharded with DTensors May 2, 2025
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the initial work!

@@ -3491,7 +3491,7 @@ def save_pretrained(
for name, tensor in state_dict.items():
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor):
if isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.distributed.tensor.DTensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not all versions of torch have DTensor we need to protect this a tad bit

@@ -3601,7 +3601,10 @@ def save_pretrained(
for shard_file, tensors in filename_to_tensors:
shard = {}
for tensor in tensors:
shard[tensor] = state_dict[tensor].contiguous()
if isinstance(state_dict[tensor], torch.distributed.tensor.DTensor):
shard[tensor] = state_dict[tensor].full_tensor().contiguous()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine!
Wondering if we cannot also delete the tensor from the model once saved? TLDR we want the most efficient way to make sure we clean the model while saving.
Also we should check if the tensor is replicated or not, if so we don't need to get the full_tensor!

Moreover, for local plans, we need to manually gather, as the tensors are not DTensors

Copy link
Member Author

@S1ro1 S1ro1 May 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re 1: we probably can do that, let me see if we get any meaningful savings from it.

Re 2: full_tensor on placements=Replicate() is a no-op (returns itself), except of some if/else checks in torch source, so I'm pretty sure there's no need to do the checks ourselves for the sake of readability. Relevant src here

Re 3: Sounds good, let me take a look at that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re 2 this is recent it seems! 3 weeks ago! wanna make sure all versions have it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That commit seems like a minor change that triggered it, commit moving DTensor to public API (which is probably the oldest one we support anyway) already has it: here

@S1ro1 S1ro1 changed the title Feat: save_pretrained for model sharded with DTensors Feat: save_pretrained for tensor parallel (and other parallelisms) models May 9, 2025
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
full_tensor = repack_weights(full_tensor, -1, 4, 2)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self: replace the hardcoded 4 with world_size

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, let's make sure the TP api is in a single file!

Comment on lines 4742 to 4745
# TODO: optimize this to avoid iterating over all
for key, value in state_dict.items():
if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
state_dict[key] = convert_local_tensor_to_dtensor(value, key, self._device_mesh, self._tp_plan)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah we are gonna iterate over all the weights when saving anyways not sure we need it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually we need to do this before split_torch_state_dict_into_shards is called, as that needs to have local tensors as dtensors to properly work. We iterate to save later.

num_blocks: int = 2,
) -> torch.Tensor:
"""
Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice this is mega useful!

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks 🤗 let's keep separating TP logic from core logic

Comment on lines 4744 to 4751
def _replace_state_dict_local_with_dtensor(self, state_dict):
"""
Replaces all tensors that were sharded with `local_*` strategy with DTensor to make saving possible.
"""
if self._tp_size is None:
return state_dict
# TODO: optimize this to avoid iterating over all
for key, value in state_dict.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as much as possible should be hidden from this file and in TP

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed and moved to tensor_parallel.py

@S1ro1 S1ro1 enabled auto-merge (squash) May 19, 2025 17:49
@S1ro1 S1ro1 merged commit 46a4b7c into main May 19, 2025
21 checks passed
@S1ro1 S1ro1 deleted the save-pretrained-dtensor branch May 19, 2025 18:16
faaany pushed a commit to faaany/transformers that referenced this pull request May 21, 2025
…dels (huggingface#37919)

* tmp: initial save pretrained with dtensors

* Feat: add correctness tests

* Refactor: version checks

* Temp: 1:1 checkpoint llama4

* refactor

* Tests

* Feat: works

* Style

* Feat: version checks + minor fixes

* Style

* Fix: version checks in tests

* Feat: move more stuff into tensor_parallel.py
xvyv99 pushed a commit to xvyv99/transformers that referenced this pull request May 21, 2025
…dels (huggingface#37919)

* tmp: initial save pretrained with dtensors

* Feat: add correctness tests

* Refactor: version checks

* Temp: 1:1 checkpoint llama4

* refactor

* Tests

* Feat: works

* Style

* Feat: version checks + minor fixes

* Style

* Fix: version checks in tests

* Feat: move more stuff into tensor_parallel.py
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request May 30, 2025
Changes in transformers introduced 2 errors in the MacOS CI, which are
handled in this PR.

Context

For context on why we use torch 2.2 for MacOS, check huggingface#2431.
Unfortunately, as of today, the available GH workers for MacOS still
haven't improved.

Description

The 1st error was introduced by
huggingface/transformers#37785, which results in
torch.load failing when using torch < 2.6.

The 2nd error was introduced by
huggingface/transformers#37919, which results in
a DTensor import being triggered when calling save_pretrained, which
fails with MacOS and torch 2.2 (possibly also later MacOS versions, I
haven't checked).

The proposed solution is to plug into pytest, intercept the test report,
check for these specific errors, and turn them into skips.

Alternative solutions

The proposed solution is obviously an ugly hack. However, these are
errors we cannot fix directly, as they're caused by a dependency and are
caused by the old torch version we're forced to use (thus fixing them in
transformers is probably not an option).

Instead of altering the test report, the individual tests that fail
could get an explicit skip marker when MacOS is detected. However, since
the amount of affected tests are several hundreds, this is very
impractical and leads to a lot of noise in the tests.

Alternatively, we could move forward with the proposal in huggingface#2431 and
remove MacOS completely from the CI. I do, however, still have the faint
hope that GH will provide arm64 workers with more RAM in the future,
allowing us to switch.
@BenjaminBossan
Copy link
Member

BenjaminBossan commented May 30, 2025

Hey, I just wanted to mention that on MacOS with torch 2.2, this causes a failure because DTensor cannot be imported. I don't have a Mac to test this, so presumably this will also fail with other torch versions (I discovered this through the PEFT CI). Since the DTensor import is in the "happy path" of save_pretrained, this means that this method could potentially break for many MacOS users.

Edit: Fixed by #38496.

@ydshieh
Copy link
Collaborator

ydshieh commented Jun 3, 2025

@S1ro1 Would you like to take a look on the failing tests

https://github.com/huggingface/transformers/actions/runs/15407385951/job/43352806613

https://github.com/huggingface/transformers/actions/runs/15407385951/job/43352806705

including tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel::test_model_save

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Unable to save model after training with tensor parallel
5 participants