Feat: save_pretrained for tensor parallel (and other parallelisms) models #37919
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the initial work!
src/transformers/modeling_utils.py
Outdated
```diff
@@ -3491,7 +3491,7 @@ def save_pretrained(
         for name, tensor in state_dict.items():
             # Sometimes in the state_dict we have non-tensor objects.
             # e.g. in bitsandbytes we have some `str` objects in the state_dict
-            if isinstance(tensor, torch.Tensor):
+            if isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.distributed.tensor.DTensor):
```
Not all versions of torch have DTensor; we need to protect this a tad bit.
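A minimal sketch of the kind of guard being asked for (illustrative only; `is_dtensor` is a hypothetical helper name, not the one transformers ended up using):

```python
import torch

try:
    # Public DTensor API; not present in older torch builds (or builds without distributed).
    from torch.distributed.tensor import DTensor
except ImportError:
    DTensor = None

def is_dtensor(obj) -> bool:
    # Falls back to False whenever DTensor isn't available in this torch build.
    return DTensor is not None and isinstance(obj, DTensor)
```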
src/transformers/modeling_utils.py
Outdated
```diff
@@ -3601,7 +3601,10 @@ def save_pretrained(
         for shard_file, tensors in filename_to_tensors:
             shard = {}
             for tensor in tensors:
-                shard[tensor] = state_dict[tensor].contiguous()
+                if isinstance(state_dict[tensor], torch.distributed.tensor.DTensor):
+                    shard[tensor] = state_dict[tensor].full_tensor().contiguous()
```
This is fine!
Wondering if we cannot also delete the tensor from the model once saved? TLDR: we want the most efficient way to make sure we clean the model while saving.
Also, we should check whether the tensor is replicated; if so, we don't need to get the full_tensor!
Moreover, for local plans we need to manually gather, as the tensors are not DTensors.
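A rough sketch of the "clean while saving" idea (hypothetical helper that mirrors the shape of the diff above, not the PR's final code; `save_dir` is an assumed argument):

```python
from safetensors.torch import save_file

def save_shards(filename_to_tensors, state_dict, save_dir):
    for shard_file, tensor_names in filename_to_tensors:
        shard = {}
        for name in tensor_names:
            value = state_dict[name]
            # full_tensor() gathers a sharded DTensor; the gathered copy is the memory spike.
            if hasattr(value, "full_tensor"):
                value = value.full_tensor()
            shard[name] = value.contiguous()
        save_file(shard, f"{save_dir}/{shard_file}")
        # Drop the gathered copies as soon as the shard file is on disk, so at most
        # one shard's worth of full tensors is alive at a time.
        shard.clear()
```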
Re 1: we probably can do that, let me see if we get any meaningful savings from it.
Re 2: full_tensor on placements=Replicate() is a no-op (returns itself), except for some if/else checks in the torch source, so I'm pretty sure there's no need to do the checks ourselves, for the sake of readability. Relevant src here.
Re 3: Sounds good, let me take a look at that.
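A tiny illustration of the Re 2 point (sketch only; assumes a torch version with the public DTensor API, 2 GPUs, and a `torchrun --nproc-per-node 2` launch):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate

mesh = init_device_mesh("cuda", (2,))  # 2-way mesh built from the launched ranks
weight = distribute_tensor(torch.randn(16, 16), mesh, [Replicate()])

# For a fully replicated placement, full_tensor() has nothing to gather and simply
# returns the local data, so an explicit "is it replicated?" check buys nothing.
full = weight.full_tensor()
assert torch.equal(full, weight.to_local())
```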
re 2 this is recent it seems! 3 weeks ago! wanna make sure all versions have it
That commit seems like a minor change that triggered it; the commit moving DTensor to the public API (which is probably the oldest version we support anyway) already has it: here
src/transformers/modeling_utils.py
Outdated
```python
full_tensor = state_dict[tensor].full_tensor()
# to get the correctly ordered tensor we need to repack if packed
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
    full_tensor = repack_weights(full_tensor, -1, 4, 2)
```
note to self: replace the hardcoded 4 with world_size
Thanks, let's make sure the TP API is in a single file!
src/transformers/modeling_utils.py
Outdated
```python
# TODO: optimize this to avoid iterating over all
for key, value in state_dict.items():
    if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
        state_dict[key] = convert_local_tensor_to_dtensor(value, key, self._device_mesh, self._tp_plan)
```
Yeah, we are going to iterate over all the weights when saving anyway, so not sure we need it.
Actually, we need to do this before split_torch_state_dict_into_shards is called, as it needs the local tensors to be DTensors to work properly. We iterate again when saving later.
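A hedged sketch of that conversion (illustrative helper, not the exact `convert_local_tensor_to_dtensor` from the PR): wrap the raw local shard back into a DTensor so the sharding and saving code can treat every state_dict entry uniformly.

```python
from typing import Optional

import torch
from torch.distributed.tensor import DTensor, Replicate, Shard

def local_to_dtensor(local_tensor: torch.Tensor, device_mesh, shard_dim: Optional[int]):
    # Parameters handled by a `local_*` plan are plain tensors on each rank; re-attach
    # the sharding metadata so full_tensor() can gather them correctly at save time.
    placements = [Shard(shard_dim)] if shard_dim is not None else [Replicate()]
    return DTensor.from_local(local_tensor, device_mesh, placements)
```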
```python
    num_blocks: int = 2,
) -> torch.Tensor:
    """
    Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
```
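A hypothetical standalone illustration of the repacking idea (not the actual `repack_weights` implementation): a weight packed out of `num_blocks` sub-matrices and sharded across ranks comes back from the gather in rank-major order, so the chunks have to be reordered block-major.

```python
import torch

def repack_gathered(tensor: torch.Tensor, dim: int, world_size: int, num_blocks: int) -> torch.Tensor:
    # After gathering, chunks are ordered (rank 0, block 0), (rank 0, block 1), (rank 1, block 0), ...
    chunks = torch.chunk(tensor, world_size * num_blocks, dim=dim)
    # Canonical packed order wants all of block 0 first, then all of block 1, etc.
    reordered = [chunks[rank * num_blocks + block] for block in range(num_blocks) for rank in range(world_size)]
    return torch.cat(reordered, dim=dim)

# 2 ranks, 2 packed blocks, packed along the last dim:
gathered = torch.arange(8).reshape(1, 8)
print(repack_gathered(gathered, dim=-1, world_size=2, num_blocks=2))
# tensor([[0, 1, 4, 5, 2, 3, 6, 7]])
```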
nice this is mega useful!
Thanks 🤗 let's keep separating TP logic from core logic
src/transformers/modeling_utils.py
Outdated
```python
def _replace_state_dict_local_with_dtensor(self, state_dict):
    """
    Replaces all tensors that were sharded with `local_*` strategy with DTensor to make saving possible.
    """
    if self._tp_size is None:
        return state_dict
    # TODO: optimize this to avoid iterating over all
    for key, value in state_dict.items():
```
As much as possible should be hidden from this file and live in the TP module.
Fixed and moved to tensor_parallel.py
…dels (huggingface#37919)
* tmp: initial save pretrained with dtensors
* Feat: add correctness tests
* Refactor: version checks
* Temp: 1:1 checkpoint llama4
* refactor
* Tests
* Feat: works
* Style
* Feat: version checks + minor fixes
* Style
* Fix: version checks in tests
* Feat: move more stuff into tensor_parallel.py
Changes in transformers introduced 2 errors in the MacOS CI, which are handled in this PR.

Context
For context on why we use torch 2.2 for MacOS, check huggingface#2431. Unfortunately, as of today, the available GH workers for MacOS still haven't improved.

Description
The 1st error was introduced by huggingface/transformers#37785, which results in torch.load failing when using torch < 2.6. The 2nd error was introduced by huggingface/transformers#37919, which results in a DTensor import being triggered when calling save_pretrained; this fails with MacOS and torch 2.2 (possibly also later MacOS versions, I haven't checked). The proposed solution is to plug into pytest, intercept the test report, check for these specific errors, and turn them into skips.

Alternative solutions
The proposed solution is obviously an ugly hack. However, these are errors we cannot fix directly, as they're caused by a dependency and by the old torch version we're forced to use (thus fixing them in transformers is probably not an option). Instead of altering the test report, the individual tests that fail could get an explicit skip marker when MacOS is detected. However, since the number of affected tests is in the several hundreds, this is very impractical and leads to a lot of noise in the tests. Alternatively, we could move forward with the proposal in huggingface#2431 and remove MacOS completely from the CI. I do, however, still have the faint hope that GH will provide arm64 workers with more RAM in the future, allowing us to switch.
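A hedged conftest.py sketch of that report-interception approach (the hook is standard pytest API, but the matched error strings below are placeholders, not the exact ones from the PR):

```python
import pytest

# Substrings of the known, un-fixable dependency errors (placeholder values).
KNOWN_DEPENDENCY_ERRORS = (
    "Weights only load failed",    # torch.load on torch < 2.6 (assumed message)
    "torch.distributed.tensor",    # DTensor import on MacOS + torch 2.2 (assumed message)
)

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
    outcome = yield
    report = outcome.get_result()
    if report.failed and call.excinfo is not None:
        message = str(call.excinfo.value)
        if any(marker in message for marker in KNOWN_DEPENDENCY_ERRORS):
            # Re-label the failure so CI treats it as an expected, skipped failure.
            report.outcome = "skipped"
            report.wasxfail = "known dependency error on this platform"
```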
Hey, I just wanted to mention that on MacOS with torch 2.2, this causes a failure because the DTensor import triggered by save_pretrained fails there. Edit: Fixed by #38496.
@S1ro1 Would you like to take a look at the failing tests https://github.com/huggingface/transformers/actions/runs/15407385951/job/43352806613 and https://github.com/huggingface/transformers/actions/runs/15407385951/job/43352806705, including …
Adds a `save_pretrained` that works on models that tensor parallelism was applied to. Works on both local tensors and DTensors.
Memory snapshot dump timeline for meta-llama/Meta-Llama-3-8B-Instruct in fp32 on 2 GPUs. The spikes on top represent the saving; probably can't get any better than that. We can maybe warn users that they can specify a smaller shard size to avoid memory spikes. Relies on a small fix in huggingface_hub: huggingface/huggingface_hub#3042
EDIT: now also supports `local_*` tp plans. This was tested by saving the full llama4 model and comparing correctness (not added to the tests as the model is huge; will probably create tests for this when we support user-defined tp_plans). FIXES: #36436
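For reference, a sketch of the usage this enables (the `tp_plan` and `max_shard_size` parameters are the existing from_pretrained / save_pretrained ones; the shard size value is just an example):

```python
# launch with: torchrun --nproc-per-node 2 save_tp.py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    tp_plan="auto",  # shard the model across the launched ranks (tensor parallelism)
)

# A smaller max_shard_size keeps the gather-and-write memory spikes smaller.
model.save_pretrained("llama3-8b-tp-checkpoint", max_shard_size="2GB")
```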