
Feat: save_pretrained for tensor parallel (and other parallelisms) models #37919


Open · wants to merge 8 commits into base: main

Conversation

@S1ro1 (Member) commented May 1, 2025

This adds a save_pretrained that works on models that tensor parallelism has been applied to. It works on both local tensors and DTensors.

[Screenshot: memory snapshot timeline]

Memory snapshot dump timeline for meta-llama/Meta-Llama-3-8B-Instruct in fp32 on 2 GPUs. The spikes on top represent the saving; that probably can't get much better. We could maybe warn users that they can specify a smaller shard size to avoid the memory spikes.
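For context, a rough usage sketch of what this enables (the output path and shard size are illustrative, and tp_plan="auto" loading is assumed to be available in the installed transformers):

# Rough usage sketch (shard size and output path are illustrative).
# Run under a distributed launcher, e.g.: torchrun --nproc-per-node 2 save_tp.py
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    tp_plan="auto",              # shard the weights across the launched ranks
    torch_dtype=torch.float32,
)

# ... training / inference ...

# A smaller max_shard_size caps how much of the gathered state dict is held
# at once, which should flatten the saving spikes visible in the snapshot.
model.save_pretrained("./llama3-8b-tp-saved", max_shard_size="2GB")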

Relies on a small fix in huggingface_hub: huggingface/huggingface_hub#3042

EDIT: now also supports local_* tp plans. Tested by saving the full llama4 model and comparing for correctness (not added to the tests as the model is huge; will probably create tests for this when we support user-defined tp_plans).

FIXES: #36436

TODO: this relies on a PR in huggingface_hub, so it needs a version check for that too?
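If a gate ends up being needed, it could be a simple runtime check along these lines (a sketch only; the actual minimum is whichever huggingface_hub release ships #3042, so the version below is a placeholder):

# Sketch of a possible huggingface_hub version gate; "0.0.0" is a hypothetical
# placeholder for the first release containing huggingface/huggingface_hub#3042.
import huggingface_hub
from packaging import version

_MIN_HF_HUB_VERSION = "0.0.0"  # placeholder, not the real number

def hub_supports_dtensor_saving() -> bool:
    return version.parse(huggingface_hub.__version__) >= version.parse(_MIN_HF_HUB_VERSION)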

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@S1ro1 S1ro1 requested a review from ArthurZucker May 2, 2025 11:19
@S1ro1 S1ro1 marked this pull request as ready for review May 2, 2025 11:49
@S1ro1 S1ro1 changed the title tmp: initial save pretrained with dtensors Feat: save_pretrained for model sharded with DTensors May 2, 2025
@ArthurZucker (Collaborator) left a comment

Thanks for the initial work!

@@ -3491,7 +3491,7 @@ def save_pretrained(
         for name, tensor in state_dict.items():
             # Sometimes in the state_dict we have non-tensor objects.
             # e.g. in bitsandbytes we have some `str` objects in the state_dict
-            if isinstance(tensor, torch.Tensor):
+            if isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.distributed.tensor.DTensor):
Collaborator:

Not all versions of torch have DTensor; we need to protect this a tad bit.
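A minimal sketch of the kind of guard meant here, assuming a plain try/except around the public import (not necessarily how the PR ends up doing it):

# Guarded DTensor import so the isinstance check degrades gracefully on older torch.
import torch

try:
    from torch.distributed.tensor import DTensor  # public API in recent torch releases
except ImportError:  # older torch without (public) DTensor
    DTensor = None

def _is_saveable_tensor(obj) -> bool:
    # Plain tensors always pass; DTensors only when the running torch exposes them.
    return isinstance(obj, torch.Tensor) or (DTensor is not None and isinstance(obj, DTensor))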

@@ -3601,7 +3601,10 @@ def save_pretrained(
             for shard_file, tensors in filename_to_tensors:
                 shard = {}
                 for tensor in tensors:
-                    shard[tensor] = state_dict[tensor].contiguous()
+                    if isinstance(state_dict[tensor], torch.distributed.tensor.DTensor):
+                        shard[tensor] = state_dict[tensor].full_tensor().contiguous()
Collaborator:

This is fine!

1. Wondering if we can't also delete the tensor from the model once it's saved? TLDR: we want the most efficient way to make sure we clean the model while saving.
2. Also, we should check whether the tensor is replicated or not; if so, we don't need to get the full_tensor!
3. Moreover, for local plans, we need to manually gather, as the tensors are not DTensors.

@S1ro1 (Member, Author) commented May 5, 2025

Re 1: we probably can do that, let me see if we get any meaningful savings from it.

Re 2: full_tensor on placements=Replicate() is a no-op (it returns itself), apart from some if/else checks in the torch source, so I'm pretty sure there's no need to do the checks ourselves, for the sake of readability. Relevant src here

Re 3: Sounds good, let me take a look at that.

Collaborator:

Re 2: this is recent, it seems (3 weeks ago!); wanna make sure all the versions we support have it.

@S1ro1 (Member, Author):

That commit seems like just a minor change that touched it; the commit moving DTensor to the public API (which is probably the oldest one we support anyway) already has it: here
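For reference, the explicit check discussed in this thread (Re 2) would look roughly like the following. It is illustrative only; per the conclusion above, full_tensor() already short-circuits for fully replicated placements, so this isn't strictly needed.

from torch.distributed.tensor import DTensor, Replicate

def _materialize(value):
    if isinstance(value, DTensor):
        if all(isinstance(p, Replicate) for p in value.placements):
            # Fully replicated: the local shard already is the full tensor.
            return value.to_local()
        return value.full_tensor()
    return value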

@S1ro1 S1ro1 changed the title Feat: save_pretrained for model sharded with DTensors Feat: save_pretrained for tensor parallel (and other parallelisms) models May 9, 2025
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
    full_tensor = repack_weights(full_tensor, -1, 4, 2)
@S1ro1 (Member, Author):

Note to self: replace the hardcoded 4 with world_size.
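Roughly what that note means (a sketch of the same hunk; self._device_mesh is the attribute used elsewhere in this PR, and DeviceMesh.size() giving the total number of ranks in the 1D TP mesh is assumed to be the right quantity):

# Derive the world size from the device mesh instead of hardcoding 4.
world_size = self._device_mesh.size()  # total ranks in the (1D) TP mesh

full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
    full_tensor = repack_weights(full_tensor, -1, world_size, 2)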

@ArthurZucker (Collaborator) left a comment

Thanks, let's make sure the TP API is in a single file!

Comment on lines +4742 to +4745

# TODO: optimize this to avoid iterating over all
for key, value in state_dict.items():
    if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
        state_dict[key] = convert_local_tensor_to_dtensor(value, key, self._device_mesh, self._tp_plan)
Collaborator:

Yeah, we are gonna iterate over all the weights when saving anyway, so not sure we need it.
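For readers unfamiliar with the helper: turning a local shard back into a DTensor boils down to DTensor.from_local with the right mesh and placement. A minimal sketch (the PR's actual convert_local_tensor_to_dtensor additionally derives the placement from the parameter's tp-plan entry):

from torch.distributed.tensor import DTensor, Replicate, Shard

def local_to_dtensor(local_tensor, device_mesh, shard_dim=None):
    # Which dim (if any) was sharded depends on the tp-plan entry for this parameter.
    placements = [Shard(shard_dim)] if shard_dim is not None else [Replicate()]
    return DTensor.from_local(local_tensor, device_mesh, placements)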

    num_blocks: int = 2,
) -> torch.Tensor:
    """
    Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
Collaborator:

Nice, this is mega useful!
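To illustrate why this helper exists: when a packed projection (e.g. gate and up weights fused along one dim) is sharded per rank and then gathered, the per-rank blocks come back interleaved and have to be reordered into the canonical layout. A standalone sketch of that reordering (not the PR's exact implementation):

import torch

def repack_sketch(gathered: torch.Tensor, dim: int, world_size: int, num_blocks: int = 2) -> torch.Tensor:
    # Gathered layout:  [block0_rank0 | block1_rank0 | block0_rank1 | block1_rank1 | ...]
    # Canonical layout: [block0_rank0 | block0_rank1 | ... | block1_rank0 | block1_rank1 | ...]
    chunks = gathered.chunk(world_size * num_blocks, dim=dim)
    reordered = [chunks[r * num_blocks + b] for b in range(num_blocks) for r in range(world_size)]
    return torch.cat(reordered, dim=dim)

# e.g. with world_size=2 and num_blocks=2, slices [g0 | u0 | g1 | u1] become [g0 | g1 | u0 | u1]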


Successfully merging this pull request may close these issues.

Unable to save model after training with tensor parallel