Support TP for save_pretrained() #38111

Draft — wants to merge 1 commit into base: main
17 changes: 16 additions & 1 deletion src/transformers/modeling_utils.py
Contributor
It's likely a problem that this changes the state of the model and leaves it invalid after calling `save_pretrained`.

The in-memory model should ideally be identical before and after calling `save_pretrained`.
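One way to address this concern (a minimal sketch, not the PR's implementation) is to build a separate serialization dict on rank 0 instead of overwriting entries of the live `state_dict` in place. Here `gather_full` is a hypothetical stand-in for `DTensor.full_tensor().to("cpu")`: it takes a (possibly sharded) value and returns the full CPU copy.

```python
import os

def gather_for_save(state_dict, gather_full):
    """Build a rank-0-only copy for serialization without mutating `state_dict`.

    `gather_full` is a stand-in for `DTensor.full_tensor().to("cpu")` (an
    assumption for illustration): it maps a possibly-sharded value to its
    full CPU copy.
    """
    rank = int(os.getenv("RANK", "-1"))
    if rank > 0:
        # Non-zero ranks still participate in the gather collectives,
        # but discard the results and do not save.
        for value in state_dict.values():
            gather_full(value)
        return None
    # Rank 0 (or single-process, RANK unset): assemble a fresh dict,
    # leaving the model's own state_dict untouched.
    return {name: gather_full(value) for name, value in state_dict.items()}
```

Because the gathered copies live in a new dict, the model remains valid after saving, at the cost of briefly holding a second CPU copy of the weights on rank 0.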

@@ -3536,15 +3536,30 @@ def save_pretrained(
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
rank = int(os.getenv("RANK", "-1"))
for name, tensor in state_dict.items():
    # Sometimes in the state_dict we have non-tensor objects,
    # e.g. in bitsandbytes we have some `str` objects in the state_dict.
    if isinstance(tensor, DTensor):
        # Under tensor parallelism, the DTensor must be gathered back into a full tensor.
        # Move the full tensor to CPU, since rank 0's GPU memory might not be large enough
        # for a large model. Smaller models might be fine with a full copy on GPU; this
        # will be optimized in a next step.
        tensor = tensor.full_tensor().to("cpu")
        if rank <= 0:
            state_dict[name] = tensor
            ptrs[id_tensor_storage(tensor)].append(name)
        else:
            # Ranks > 0 do not save, so drop the gathered copy to free memory.
            del tensor
    elif isinstance(tensor, torch.Tensor):
        ptrs[id_tensor_storage(tensor)].append(name)
    else:
        # In the non-tensor case, fall back to the pointer of the object itself.
        ptrs[id(tensor)].append(name)

if rank > 0:
    del state_dict
    return

# These are all the pointers of shared tensors
if hasattr(self, "hf_device_map"):
# if the model has offloaded parameters, we must check using find_tied_parameters()
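The hunk above feeds into the aliasing check that follows it: safetensors refuses to serialize tensors that share storage, so `save_pretrained` groups parameter names by storage pointer and drops all but one name per group. A minimal pure-Python sketch of that grouping idea, using `id()` in place of transformers' `id_tensor_storage` (the real helper also catches distinct views of one storage):

```python
import collections

def group_aliases(state_dict):
    """Return groups of state-dict keys that refer to the same underlying object.

    Sketch of the storage-pointer grouping in save_pretrained; uses id() as a
    simplification of id_tensor_storage().
    """
    ptrs = collections.defaultdict(list)
    for name, value in state_dict.items():
        ptrs[id(value)].append(name)
    # Any group with more than one name is a set of aliases; safetensors
    # requires keeping only one of them before saving.
    return [names for names in ptrs.values() if len(names) > 1]
```

For tied weights (e.g. input embeddings shared with the LM head), such a group would contain both parameter names, and only one copy is written to disk.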