Changes required to save_model for certain models (e.g., Phi 3.5 Vision) #34690

Opened by @jjbuck

Feature request

This request proposes one of three changes (see Motivation for background, and Your contribution for more thoughts on possible solutions) to allow saving a certain class of models, including but not limited to Phi 3.5 Vision.

  1. Accept a state_dict argument in the Trainer class's save_model() method (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3719-L3768). This state_dict parameter should then be passed down to the call to the private _save() method (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3842), which does accept a state_dict argument. See the sketch after this list.
  2. Rather than adding state_dict as an argument to save_model(), determine an appropriate heuristic by which Phi 3.5 Vision and other architecturally similar models can be saved successfully.
  3. Some broader change to the way transformers handles shared tensors (see "Your contribution" for discussion).
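
For option 1, a minimal sketch of what the change might look like (hypothetical; the elided body and exact guard would have to match the current implementation):

# Hypothetical sketch of option 1: save_model() gains an optional
# state_dict that is forwarded to the private _save(), which already
# accepts one.
def save_model(self, output_dir=None, _internal_call=False, state_dict=None):
    ...
    if self.args.should_save:
        self._save(output_dir, state_dict=state_dict)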

Motivation

I encountered an issue while trying to fine-tune Phi 3.5 Vision using the Trainer class from transformers. In particular, when calling save_pretrained() or Trainer.save_model(), transformers throws the following error:

RuntimeError: The weights trying to be saved contained shared tensors [{'model.vision_embed_tokens.wte.weight', 
'model.embed_tokens.weight'}] that are mismatching the transformers base configuration.
Try saving using `safe_serialization=False` or remove this tensor sharing.
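
The error is raised because two state-dict keys point at the same underlying tensor. A generic diagnostic (not Phi-specific, and assuming fully tied weights rather than partial slices) is to group state-dict entries by data pointer:

from collections import defaultdict

# Group state-dict entries by underlying data pointer; any group with
# more than one name is a set of shared tensors that safetensors will
# refuse to serialize.
def find_shared_tensors(model):
    groups = defaultdict(list)
    for name, tensor in model.state_dict().items():
        groups[tensor.data_ptr()].append(name)
    return [names for names in groups.values() if len(names) > 1]

On Phi 3.5 Vision this should report the pair named in the error above.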

Below are two minimal reproducible examples:
Example #1

from transformers import AutoModelForCausalLM

model_id = "microsoft/Phi-3.5-vision-instruct"
# Loading works fine; the failure only shows up at save time.
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto"
)
model.save_pretrained("out")  # raises the RuntimeError above

Example #2

from transformers import (
    Trainer,
    TrainingArguments,
)

# Reuses the `model` loaded in Example #1.
training_args = TrainingArguments(
    save_only_model=True,
    output_dir='./out/',
    save_strategy='no',
)
trainer = Trainer(
    model=model,
    args=training_args,
)
trainer.save_model()  # raises the same RuntimeError

It looks like others have also encountered this issue. See the list of reference issues below in "Issues".

A contributor to the Phi 3 Vision cookbook suggested the following solution, stating "You need to remove the wte weight. It's okay because when the model is loaded from the checkpoint, it will automatically copy the weight from the embedding weight."

# Drop the tied wte entry before saving; per the quote above, it is
# re-created from the embedding weight when the checkpoint is loaded.
state_dict = model.state_dict()
state_dict = {k: v for k, v in state_dict.items() if "wte" not in k}
model.save_pretrained(args.save_model_path, state_dict=state_dict, safe_serialization=True)
processor.save_pretrained(args.save_model_path)
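
If that claim holds, the tie should be re-established on load. A hedged sanity check (reusing args.save_model_path from the snippet above):

import torch
from transformers import AutoModelForCausalLM

# After saving without the wte entry, reloading should re-tie the vision
# embedding to the language embedding (per the quote above).
reloaded = AutoModelForCausalLM.from_pretrained(
    args.save_model_path, trust_remote_code=True, torch_dtype="auto"
)
sd = reloaded.state_dict()
assert torch.equal(sd['model.embed_tokens.weight'],
                   sd['model.vision_embed_tokens.wte.weight'])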

This does indeed seem to work. However, it doesn't fit cleanly into a workflow that relies on the Trainer abstraction: the Trainer class's save_model() method doesn't accept a state_dict argument (see https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3719-L3768). A possible interim workaround is sketched below.
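
In the meantime, one workaround (a sketch relying on the private _save() method, so it may break across transformers versions; the subclass name is made up) is to inject the filtered state dict in a Trainer subclass:

from transformers import Trainer

class Phi35VisionTrainer(Trainer):  # hypothetical subclass name
    def _save(self, output_dir=None, state_dict=None):
        # Mirror the cookbook workaround: filter the tied wte entry,
        # then defer to the stock implementation.
        if state_dict is None:
            state_dict = {
                k: v for k, v in self.model.state_dict().items()
                if 'wte' not in k
            }
        super()._save(output_dir, state_dict=state_dict)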

Issues

  1. RuntimeError: The weights trying to be saved contained shared tensors [{'model.vision_embed_tokens.wte.weight', 'model.embed_tokens.weight'}] (kazuar/Phi3-Vision-ft#2)
  2. https://discuss.huggingface.co/t/runtimeerror-when-saving-phi-3-5-vision-due-to-shared-tensors/116457
  3. Saving Phi 3 vision fails due to tensor sharing (#32354)
  4. https://discuss.huggingface.co/t/using-trainer-to-save-a-bartforsequenceclassification-model/81606

Your contribution

I'd be glad to submit a PR, but I think some discussion is needed from the appropriate transformers stakeholders.

It's not clear to me whether the most appropriate change here is to modify the save_model() signature (option 1 above).

Alternatively, maybe there's a heuristic by which we could determine whether the architecture is such that one needs to save everything but the wte weights; one rough sketch follows below. I don't know the answer to that off-hand, and it may require a deep dive from Phi 3/3.5 Vision SMEs.
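
One possible direction (a sketch, not a vetted heuristic): treat any state-dict entry that aliases an already-kept tensor as a droppable duplicate, on the assumption that loading re-ties it. Which alias survives depends on module registration order, so this would need review from the SMEs mentioned above.

# Sketch only: drop state-dict entries whose storage was already kept,
# assuming from_pretrained() re-ties them on load.
def dedupe_state_dict(model):
    seen_ptrs = set()
    filtered = {}
    for name, tensor in model.state_dict().items():
        ptr = tensor.data_ptr()
        if ptr in seen_ptrs:
            continue  # e.g. model.vision_embed_tokens.wte.weight on Phi 3.5 Vision
        seen_ptrs.add(ptr)
        filtered[name] = tensor
    return filtered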

Or more broadly, perhaps there's some change to the way transformers handles shared tensors in the base configuration that would be most appropriate.
