Size Mismatch Error During LoRA Adapter Merge in Supervised Fine-Tuning of Llama 3.2 1B on AWS Trainium Instance #733

Kelv1nYu commented Nov 4, 2024

System Info

Platform:

- Platform: Linux-5.15.0-1031-aws-x86_64-with-glibc2.35
- Python version: 3.10.12
- AWS Instance Type: trn1.2xlarge
- AWS AMI: huggingface-neuron-2024-10-01T10-10-31Z-692efe1a-8d5c-4033-bcbc-5d99f2d4ae6a(ami-0271953de6aa28bdb)


Python packages:

- `optimum-neuron` version: 0.0.25
- `neuron-sdk` version: 2.20.0
- `optimum` version: 1.22.0
- `transformers` version: 4.43.2
- `huggingface_hub` version: 0.25.1
- `torch` version: 2.1.2+cu121
- `aws-neuronx-runtime-discovery` version: 2.9
- `libneuronxla` version: 2.0.4115.0
- `neuronx-cc` version: 2.15.128.0+56dc5a86
- `neuronx-distributed` version: 0.9.0
- `neuronx-hwm` version: NA
- `torch-neuronx` version: 2.1.2.2.3.0
- `torch-xla` version: 2.1.4
- `transformers-neuronx` version: 0.12.313


Neuron Driver:


aws-neuronx-collectives/unknown,now 2.22.26.0-17a033bc8 amd64 [installed]
aws-neuronx-dkms/unknown,now 2.18.12.0 amd64 [installed]
aws-neuronx-oci-hook/unknown,now 2.5.3.0 amd64 [installed]
aws-neuronx-runtime-lib/unknown,now 2.22.14.0-6e27b8d5b amd64 [installed]
aws-neuronx-tools/unknown,now 2.19.0.0 amd64 [installed]

Who can help?

@michaelbenayoun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

I was following the tutorial for Supervised Fine-Tuning of Llama 3 8B on one AWS Trainium instance. Since I used an ml.trn1.2xlarge AWS instance, which only has two Neuron cores, I adjusted the parameters accordingly and switched to the smaller meta-llama/Llama-3.2-1B model. Below is the training script I used with my dataset:

Training Code:

from dataclasses import dataclass, field

from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainingArguments
from optimum.neuron.distributed import lazy_load_for_parallelism


def format_medchat(examples):
    output_text = []
    for i in range(len(examples["question"])):
        question = f"### Question\n{examples['question'][i]}"
        answer = f"### Answer\n{examples['answer'][i]}"
        prompt = "\n\n".join([question, answer])
        output_text.append(prompt)
    return output_text

def training_function(script_args, training_args):
    dataset = load_dataset("ngram/medchat-qa", split="train")

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
    tokenizer.pad_token = tokenizer.eos_token

    with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
        model = AutoModelForCausalLM.from_pretrained(script_args.model_id)

    config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=[
            "q_proj",
            "gate_proj",
            "v_proj",
            "o_proj",
            "k_proj",
            "up_proj",
            "down_proj"
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    args = training_args.to_dict()
    sft_config = NeuronSFTConfig(
        max_seq_length=1024,
        packing=False,
        **args,
    )

    trainer = NeuronSFTTrainer(
        args=sft_config,
        model=model,
        peft_config=config,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=format_medchat,
    )

    # Start training
    trainer.train()

    trainer.save_model()  # Saves the tokenizer too for easy upload


@dataclass
class ScriptArguments:
    model_id: str = field(
        default="meta-llama/Meta-Llama-3-8B",
        metadata={"help": "The model that you want to train from the Hugging Face hub."},
    )


def main():
    parser = HfArgumentParser([ScriptArguments, NeuronTrainingArguments])
    script_args, training_args = parser.parse_args_into_dataclasses()

    # set seed
    set_seed(training_args.seed)

    # run training function
    training_function(script_args, training_args)


if __name__ == "__main__":
    main()

Training Shell Script:

#!/bin/bash

set -ex

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"

PROCESSES_PER_NODE=2

NUM_EPOCHS=1
TP_DEGREE=2
PP_DEGREE=1
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=1
MODEL_NAME="meta-llama/Llama-3.2-1B"
OUTPUT_DIR=output

if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
	    MAX_STEPS=$((LOGGING_STEPS + 5))
    else
	        MAX_STEPS=-1
fi


XLA_USE_BF16=1 torchrun --nproc_per_node $PROCESSES_PER_NODE sft_lora_finetune_llm.py \
	  --model_id $MODEL_NAME \
	  --num_train_epochs $NUM_EPOCHS \
	  --do_train \
	  --learning_rate 5e-5 \
	  --warmup_ratio 0.03 \
	  --max_steps $MAX_STEPS \
	  --per_device_train_batch_size $BS \
	  --per_device_eval_batch_size $BS \
	  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
	  --gradient_checkpointing true \
	  --bf16 \
	  --zero_1 false \
	  --tensor_parallel_size $TP_DEGREE \
	  --pipeline_parallel_size $PP_DEGREE \
	  --logging_steps $LOGGING_STEPS \
	  --save_total_limit 1 \
	  --output_dir $OUTPUT_DIR \
	  --lr_scheduler_type "constant" \
	  --overwrite_output_dir

After training, the tutorial instructs consolidating the sharded checkpoint with optimum-cli, so I consolidated the output folder with this command:

optimum-cli neuron consolidate output output
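
For reference, the tensor names and shapes that the consolidation step wrote can be inspected directly from the resulting safetensors file. This is only a sketch: the checkpoint file name below is an assumption and may differ from what the consolidate command actually produces.

import os

from safetensors import safe_open

# NOTE: assumed file name; adjust to whatever `optimum-cli neuron consolidate` wrote.
checkpoint_path = os.path.join("output", "adapter_model.safetensors")

with safe_open(checkpoint_path, framework="pt") as f:
    for name in f.keys():
        # Print each tensor name and shape to spot dimensions that are still TP-sharded.
        print(name, f.get_slice(name).get_shape())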

After the consolidation, when trying to compile the adapter model for inference, I encountered the following error:

ValueError: The library name could not be automatically inferred. If using the command-line, please provide the argument --library {transformers,diffusers,timm,sentence_transformers}. Example: `--library diffusers`.
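
The error suggests passing the library explicitly. A retry would presumably look roughly like the command below, assuming the neuron export command accepts the --library flag as the message implies; the other arguments are placeholders rather than my exact invocation, and I have not confirmed that this resolves the problem:

optimum-cli export neuron --model output --library transformers \
  --batch_size 1 --sequence_length 1024 compiled_model/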

So I attempted to merge the LoRA adapter into the base model instead, but received this error:

Traceback (most recent call last):
  File "/opt/ml/code/load_model.py", line 25, in <module>
    model = PeftModel.from_pretrained(model, adapter_model_path)
  File "/usr/local/lib/python3.10/site-packages/peft/peft_model.py", line 586, in from_pretrained
    model.load_adapter(
  File "/usr/local/lib/python3.10/site-packages/peft/peft_model.py", line 1181, in load_adapter
    load_result = set_peft_model_state_dict(
  File "/usr/local/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 460, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
        size mismatch for base_model.model.lm_head.weight: copying a param with shape torch.Size([64128, 2048]) from checkpoint, the shape in current model is torch.Size([128256, 2048]).
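
The mismatched dimension lines up exactly with the tensor parallel degree used for training, which suggests the saved lm_head weight is still a single TP shard rather than the full-vocabulary tensor:

# Numbers taken from the traceback above and TP_DEGREE from the training shell script.
full_vocab_size = 128256   # lm_head rows expected by the base Llama-3.2-1B model
shard_vocab_size = 64128   # lm_head rows found in the consolidated checkpoint
tp_degree = 2

assert shard_vocab_size * tp_degree == full_vocab_size  # the checkpoint holds one shard of two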

Merge Code:

from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from peft import PeftModel
import torch

hf_token = "xxx"
login(token=hf_token)

base_model_path = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)

adapter_model_path = "output"
tokenizer = AutoTokenizer.from_pretrained(adapter_model_path)
model = PeftModel.from_pretrained(model, adapter_model_path)

# merge_and_unload() returns the merged model, so the return value must be kept
model = model.merge_and_unload()

output_path = "/opt/ml/model"
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
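
In case it helps with debugging: one possible stop-gap (not verified) would be to drop the mismatched base_model.model.lm_head.weight entry from the consolidated adapter checkpoint before loading it, since lm_head is not among my LoRA target_modules anyway. This is only a sketch, and the checkpoint file name is again an assumption:

import os

from safetensors.torch import load_file, save_file

adapter_dir = "output"
# NOTE: assumed file name; adjust to whatever the consolidation step actually produced.
ckpt_path = os.path.join(adapter_dir, "adapter_model.safetensors")

state_dict = load_file(ckpt_path)
# Drop the TP-sharded lm_head entry; lm_head is not a LoRA target module here,
# so the adapter should not need it when applied on top of the base model.
state_dict.pop("base_model.model.lm_head.weight", None)
save_file(state_dict, ckpt_path)

Would something along these lines be a reasonable workaround, or does this need a fix on the consolidation side?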

I noted issue #368 and PR #378, but these appear to address saving during distributed training of plain (non-PEFT) models only. In the `_save_xla` method:

        if (
            not isinstance(self.model, NeuronPeftModel)
            and self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
        ):
            if is_main_worker():
                logger.info(
                    "Model parallelism is enabled, saving the model sharded state dict instead of the full state dict."
                )
            # TODO: how to handle pp?
            if isinstance(self.model, PreTrainedModel):
                from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size

                config = copy.deepcopy(self.model.config)
                if self.args.mp_plugin.parallelize_embeddings:
                    config.vocab_size = config.vocab_size * get_tensor_model_parallel_size()
                config.save_pretrained(output_dir)

            # This mark_step is needed to avoid hang issues.
            xm.mark_step()
            Parallelizer.save_model_sharded_checkpoint(
                self.model,
                output_dir,
                optimizer=self.optimizer if not self.args.save_only_model else None,
                use_xser=self.accelerator.state.mp_plugin.use_xser,
                async_save=self.accelerator.state.mp_plugin.async_save,
                num_local_ranks_per_step=self.accelerator.state.mp_plugin.num_local_ranks_per_step,
            )
        else:
            supported_classes = (PreTrainedModel, NeuronPeftModel)
            if not isinstance(self.model, supported_classes):
                if isinstance(unwrap_model(self.model), supported_classes):
                    kwargs = (
                        {}
                        if isinstance(unwrap_model(self.model), PreTrainedModel)
                        else {"async_save": self.args.async_save}
                    )
                    unwrap_model(self.model).save_pretrained(
                        output_dir,
                        is_main_process=self.args.should_save,
                        state_dict=self.model.state_dict(),
                        save_function=xm.save,
                        **kwargs,
                    )
                else:
                    if is_main_worker():
                        logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                    state_dict = self.model.state_dict()
                    xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
            else:
                kwargs = {} if isinstance(self.model, PreTrainedModel) else {"async_save": self.args.async_save}
                self.model.save_pretrained(
                    output_dir, is_main_process=self.args.should_save, save_function=xm.save, **kwargs
                )

It seems that the LoRA/PEFT saving path is not covered by the changes from that PR, which may explain why the consolidated adapter's lm_head still has 64128 rows (a single TP-2 shard) instead of the full 128256 after sharding and consolidation.

Could you provide guidance or suggest a solution for consolidating and compiling the adapter model successfully for inference? Please let me know if you need any more information from me.

Thank you in advance for your assistance!

Expected behavior

I expected the adapter to merge seamlessly into the base model without any errors, and the merged model to be functional and ready for inference without further issues.
