Description
Feature request
This request proposes one of three changes (see Motivation for background, and Your contribution for more thoughts on possible solutions) in order to allow saving a certain class of models, including but not limited to Phi 3.5 Vision.
- Accept a `state_dict` argument in the `Trainer` class's `save_model()` method (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3719-L3768). This `state_dict` parameter should then be passed down to the call to the private `_save()` method (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3842), which does accept a `state_dict` argument.
- Rather than adding `state_dict` as an argument to `save_model()`, determine the appropriate heuristic such that we can successfully save Phi 3.5 Vision and other architecturally similar models.
- Some change to the way `transformers` handles shared tensors...?
Motivation
I encountered an issue while trying to fine-tune Phi 3.5 Vision using the `Trainer` class from `transformers`. In particular, when calling `save_model()` or `save_pretrained()`, `transformers` raises the following error:
```
RuntimeError: The weights trying to be saved contained shared tensors [{'model.vision_embed_tokens.wte.weight', 'model.embed_tokens.weight'}] that are mismatching the transformers base configuration.
Try saving using `safe_serialization=False` or remove this tensor sharing.
```
Below are two minimal reproducible examples:
Example #1
```python
from transformers import AutoModelForCausalLM

model_id = "microsoft/Phi-3.5-vision-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto"
)
model.save_pretrained("out")
```
Example #2
```python
from transformers import (
    Trainer,
    TrainingArguments,
)

# `model` is the Phi 3.5 Vision model loaded in Example #1.
training_args = TrainingArguments(
    save_only_model=True,
    output_dir='./out/',
    save_strategy='no',
)
trainer = Trainer(
    model=model,
    args=training_args
)
trainer.save_model()
```
It looks like others have also encountered this issue. See the list of reference issues below in "Issues".
A contributor to the Phi 3 Vision cookbook suggested the following solution, stating "You need to remove the wte weight. It's okay because when the model is loaded from the checkpoint, it will automatically copy the weight from the embedding weight."
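```python
# `model`, `processor`, and `args` come from the surrounding fine-tuning script.
state_dict = model.state_dict()
# Drop the tied vision embedding weights (keys containing "wte").
state_dict = {k: v for k, v in state_dict.items() if "wte" not in k}
model.save_pretrained(args.save_model_path, state_dict=state_dict, safe_serialization=True)
processor.save_pretrained(args.save_model_path)
```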
This does indeed seem to work. However, it doesn't exactly fit into a use case that relies on the `Trainer` abstraction. The call to the `Trainer` class's `save_model()` method doesn't accommodate a `state_dict` argument (see https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3719-L3768).
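To make the first proposed option concrete, here is a toy, framework-free sketch of a public `save_model()` that forwards an optional `state_dict` to a private `_save()`, falling back to the model's full state dict when none is given. `ToyModel` and `ToyTrainer` are hypothetical stand-ins, not the real `Trainer` (which also deals with distributed setups, tokenizers/processors, and so on); only the forwarding pattern is meant to carry over.

```python
from typing import Any, Dict, Optional


class ToyModel:
    """Hypothetical stand-in for a model that exposes state_dict()."""

    def __init__(self, weights: Dict[str, Any]) -> None:
        self._weights = weights

    def state_dict(self) -> Dict[str, Any]:
        # Build a fresh dict on every call.
        return dict(self._weights)


class ToyTrainer:
    """Hypothetical stand-in showing the proposed save_model() signature."""

    def __init__(self, model: ToyModel, output_dir: str) -> None:
        self.model = model
        self.output_dir = output_dir
        # Records what each call would have written, keyed by output directory.
        self.saved: Dict[str, Dict[str, Any]] = {}

    def save_model(
        self,
        output_dir: Optional[str] = None,
        state_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Proposed change: accept state_dict and forward it unchanged.
        self._save(output_dir or self.output_dir, state_dict=state_dict)

    def _save(
        self, output_dir: str, state_dict: Optional[Dict[str, Any]] = None
    ) -> None:
        # Fall back to the full state dict only when none was supplied.
        if state_dict is None:
            state_dict = self.model.state_dict()
        self.saved[output_dir] = state_dict


# Usage: apply the cookbook's "wte" filter through the Trainer-style API.
model = ToyModel(
    {
        "model.embed_tokens.weight": "shared",
        "model.vision_embed_tokens.wte.weight": "shared",
    }
)
trainer = ToyTrainer(model, output_dir="./out/")
filtered = {k: v for k, v in model.state_dict().items() if "wte" not in k}
trainer.save_model(state_dict=filtered)
print(trainer.saved["./out/"])  # {'model.embed_tokens.weight': 'shared'}
```

Keeping `None` as the default preserves today's behaviour for every existing caller; only users who need a filtered dict would pass one.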
Issues
- RuntimeError: The weights trying to be saved contained shared tensors [{'model.vision_embed_tokens.wte.weight', 'model.embed_tokens.weight'}] kazuar/Phi3-Vision-ft#2
- https://discuss.huggingface.co/t/runtimeerror-when-saving-phi-3-5-vision-due-to-shared-tensors/116457
- Saving Phi 3 vision fails due to tensor sharing #32354
- https://discuss.huggingface.co/t/using-trainer-to-save-a-bartforsequenceclassification-model/81606
Your contribution
I'd be glad to submit a PR, but I think some discussion is needed from the appropriate `transformers` stakeholders.
It's not clear to me whether the most appropriate change here is to modify the function signature.
Alternatively, maybe there's a heuristic by which we could determine whether the architecture is such that one needs to save everything but the `wte` weights. I don't know the answer to that off-hand. It may require a deep dive from Phi 3/3.5 Vision SMEs.
Or more broadly, perhaps there's some change to the way `transformers` handles shared tensors in the base configuration that would be most appropriate.
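One possible starting point for such a heuristic is to detect which state-dict names alias the same storage and keep a single name per group, which generalizes the hard-coded `"wte"` filter from the workaround. Below is a minimal, framework-free sketch; `find_shared_groups`, `drop_aliases`, and the "shortest name wins" ranking are hypothetical, and deciding which name to keep is exactly the open question, so the ranking is only a placeholder. It assumes aliasing can be detected through a `storage_key` function: plain `id` works for the stand-in objects used here, but `module.state_dict()` typically returns distinct (detached) tensor objects for tied weights, so with PyTorch one would presumably key on the underlying storage instead, e.g. `lambda t: t.untyped_storage().data_ptr()`.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, Hashable, List, Set


def find_shared_groups(
    state_dict: Dict[str, Any],
    storage_key: Callable[[Any], Hashable] = id,
) -> List[Set[str]]:
    """Return groups of names whose values share storage according to storage_key."""
    groups: Dict[Hashable, Set[str]] = defaultdict(set)
    for name, value in state_dict.items():
        groups[storage_key(value)].add(name)
    # A group is "shared" only if more than one name points at the same storage.
    return [names for names in groups.values() if len(names) > 1]


def drop_aliases(
    state_dict: Dict[str, Any],
    storage_key: Callable[[Any], Hashable] = id,
    rank: Callable[[str], Any] = lambda name: (len(name), name),
) -> Dict[str, Any]:
    """Keep the lowest-ranked name in each shared group and drop the rest."""
    to_drop: Set[str] = set()
    for names in find_shared_groups(state_dict, storage_key):
        keep = min(names, key=rank)  # placeholder rule: shortest name wins
        to_drop.update(names - {keep})
    return {k: v for k, v in state_dict.items() if k not in to_drop}


# Usage with plain objects standing in for tensors: one object appears under two names.
embedding = object()
sd = {
    "model.embed_tokens.weight": embedding,
    "model.vision_embed_tokens.wte.weight": embedding,
    "lm_head.weight": object(),
}
print(sorted(drop_aliases(sd)))  # ['lm_head.weight', 'model.embed_tokens.weight']
```

Dropping a name is only safe if the model re-creates the tie when it is loaded, as the cookbook contributor describes for Phi 3.5 Vision; a general solution would still need to know that per architecture.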