Feat: save_pretrained for tensor parallel (and other parallelisms) models #37919
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the initial work!
src/transformers/modeling_utils.py (outdated)

```diff
@@ -3491,7 +3491,7 @@ def save_pretrained(
 for name, tensor in state_dict.items():
     # Sometimes in the state_dict we have non-tensor objects.
     # e.g. in bitsandbytes we have some `str` objects in the state_dict
-    if isinstance(tensor, torch.Tensor):
+    if isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.distributed.tensor.DTensor):
```
Not all versions of torch have DTensor; we need to protect this a tad bit.
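One way to do that guard is a hedged import with a fallback, along these lines (a sketch only; the `is_dtensor` helper name is made up here and is not the actual transformers utility):

```python
import torch

# DTensor lives in the public `torch.distributed.tensor` namespace in
# recent torch releases; older versions expose it under the private
# `torch.distributed._tensor`, and some builds lack it entirely.
try:
    from torch.distributed.tensor import DTensor
except ImportError:
    try:
        from torch.distributed._tensor import DTensor
    except ImportError:
        DTensor = None  # torch too old: no DTensor support at all


def is_dtensor(tensor) -> bool:
    """True only when this torch build has DTensor AND the object is one."""
    return DTensor is not None and isinstance(tensor, DTensor)
```

With such a guard, `isinstance(tensor, torch.Tensor) or is_dtensor(tensor)` works on every torch version instead of raising `AttributeError` on old ones.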
src/transformers/modeling_utils.py (outdated)

```diff
@@ -3601,7 +3601,10 @@ def save_pretrained(
 for shard_file, tensors in filename_to_tensors:
     shard = {}
     for tensor in tensors:
-        shard[tensor] = state_dict[tensor].contiguous()
+        if isinstance(state_dict[tensor], torch.distributed.tensor.DTensor):
+            shard[tensor] = state_dict[tensor].full_tensor().contiguous()
```
This is fine!
Wondering if we cannot also delete the tensor from the model once saved? TL;DR: we want the most efficient way to make sure we clean the model while saving.
Also, we should check whether the tensor is replicated or not; if so, we don't need to get the full_tensor! Moreover, for local plans we need to manually gather, as the tensors are not DTensors.
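A sketch of the replicated check suggested above, assuming the DTensor placement API (`needs_gather` is a hypothetical helper, not code from this PR):

```python
# Placement types: public namespace on recent torch, private on older.
try:
    from torch.distributed.tensor import Replicate, Shard
except ImportError:
    from torch.distributed._tensor import Replicate, Shard


def needs_gather(placements) -> bool:
    """A DTensor whose placements are all Replicate() already holds the
    full tensor on every rank, so the full_tensor() collective (needed
    for sharded placements) could be skipped for it."""
    return any(not isinstance(p, Replicate) for p in placements)
```

In practice the thread below notes that `full_tensor()` already short-circuits on fully replicated placements, so an explicit check like this may be redundant.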
Re 1: we probably can do that, let me see if we get any meaningful savings from it.
Re 2: full_tensor on placements=Replicate() is a no-op (it returns itself), apart from some if/else checks in the torch source, so I'm pretty sure there's no need to do the checks ourselves, for the sake of readability. Relevant src here.
Re 3: Sounds good, let me take a look at that.
Re 2: this is recent, it seems! 3 weeks ago! Want to make sure all the versions we support have it.
That commit seems like a minor change that triggered it; the commit moving DTensor to the public API (which is probably the oldest one we support anyway) already has it: here.
```python
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
    full_tensor = repack_weights(full_tensor, -1, 4, 2)
```
note to self: replace the hardcoded 4 with world_size
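The chunk reordering that a repack step like this performs can be illustrated with a pure-Python sketch of the index permutation (a hypothetical helper for illustration only; the PR's `repack_weights` moves actual tensor slices):

```python
def repack_order(world_size, num_blocks):
    """Permutation mapping gathered chunk order back to canonical order.

    A packed weight (e.g. gate and up projections fused along one dim)
    sharded row-wise comes back from a gather interleaved per rank:
        [block0_rank0, block1_rank0, block0_rank1, block1_rank1, ...]
    whereas the canonical packed layout is all chunks of block 0 first,
    then all chunks of block 1, and so on.
    """
    return [rank * num_blocks + block
            for block in range(num_blocks)
            for rank in range(world_size)]


# With 2 ranks and 2 packed blocks (gate, up):
chunks = ["gate_0", "up_0", "gate_1", "up_1"]  # gathered, interleaved
canonical = [chunks[i] for i in repack_order(2, 2)]
# canonical == ["gate_0", "gate_1", "up_0", "up_1"]
```

This is also why the hardcoded 4 matters: the permutation depends directly on `world_size`, so a value baked in for one GPU count silently corrupts the repacked weight on any other.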
Thanks, let's make sure the TP API is in a single file!
```python
# TODO: optimize this to avoid iterating over all
for key, value in state_dict.items():
    if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
        state_dict[key] = convert_local_tensor_to_dtensor(value, key, self._device_mesh, self._tp_plan)
```
Yeah, we are going to iterate over all the weights when saving anyway, so I'm not sure we need it.
```python
    num_blocks: int = 2,
) -> torch.Tensor:
    """
    Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
```
nice this is mega useful!
save_pretrained that works on models that tensor parallelism was applied to. Works on both local tensors and DTensors.

Memory snapshot dump timeline for meta-llama/Meta-Llama-3-8B-Instruct in fp32 on 2 GPUs. The spikes on top represent the saving; we probably can't get any better than that. We could maybe warn users that they can specify a smaller shard size to avoid memory spikes.

Relies on a small fix in huggingface_hub: huggingface/huggingface_hub#3042

EDIT: now also supports local_* tp plans. Tested by saving the full Llama 4 model and comparing correctness (not added to the tests, as the model is huge; will probably create tests for this when we support user-defined tp plans).

FIXES: #36436
TODO: this relies on a PR in huggingface_hub, so it needs a version check for that too?