Who can help?
@michaelbenayoun

Information
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)
I was following the tutorial for Supervised Fine-Tuning of Llama 3 8B on one AWS Trainium instance. Since I used an ml.trn1.2xlarge AWS instance, which only has two Neuron cores, I modified the parameters and model accordingly. Below is the training script I used with my dataset:
Training Code:
from dataclasses import dataclass, field

from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainingArguments
from optimum.neuron.distributed import lazy_load_for_parallelism


def format_medchat(examples):
    output_text = []
    for i in range(len(examples["question"])):
        question = f"### Question\n{examples['question'][i]}"
        answer = f"### Answer\n{examples['answer'][i]}"
        prompt = "\n\n".join([text for text in [question, answer]])
        output_text.append(prompt)
    return output_text


def training_function(script_args, training_args):
    dataset = load_dataset("ngram/medchat-qa", split="train")

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
    tokenizer.pad_token = tokenizer.eos_token

    with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
        model = AutoModelForCausalLM.from_pretrained(script_args.model_id)

    config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=[
            "q_proj",
            "gate_proj",
            "v_proj",
            "o_proj",
            "k_proj",
            "up_proj",
            "down_proj",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    args = training_args.to_dict()
    sft_config = NeuronSFTConfig(
        max_seq_length=1024,
        packing=False,
        **args,
    )

    trainer = NeuronSFTTrainer(
        args=sft_config,
        model=model,
        peft_config=config,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=format_medchat,
    )

    # Start training
    trainer.train()
    trainer.save_model()  # Saves the tokenizer too for easy upload


@dataclass
class ScriptArguments:
    model_id: str = field(
        default="meta-llama/Meta-Llama-3-8B",
        metadata={"help": "The model that you want to train from the Hugging Face hub."},
    )


def main():
    parser = HfArgumentParser([ScriptArguments, NeuronTrainingArguments])
    script_args, training_args = parser.parse_args_into_dataclasses()

    # set seed
    set_seed(training_args.seed)

    # run training function
    training_function(script_args, training_args)


if __name__ == "__main__":
    main()
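Training Shell Script (illustrative only): the script above is launched with one process per Neuron core; assuming the usual torchrun launcher and a script file named train.py, a representative launch command for a trn1.2xlarge looks roughly like the following. The hyperparameter values are placeholders, not my exact settings.

torchrun --nproc_per_node=2 train.py \
  --model_id meta-llama/Meta-Llama-3-8B \
  --tensor_parallel_size 2 \
  --bf16 True \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --learning_rate 5e-5 \
  --num_train_epochs 1 \
  --logging_steps 10 \
  --output_dir output \
  --overwrite_output_dir True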
After training, the tutorial instructed me to consolidate the sharded checkpoints with optimum-cli, so I consolidated the output folder with this command:
optimum-cli neuron consolidate output output
After the consolidation, when trying to compile the adapter model for inference, I encountered the following error:
ValueError: The library name could not be automatically inferred. If using the command-line, please provide the argument --library {transformers,diffusers,timm,sentence_transformers}. Example: `--library diffusers`.
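For context, the compile step here is an optimum-cli export of the consolidated folder for Neuron inference; a sketch of that kind of command is shown below. The arguments (model path, sequence length, cast type, output directory) are placeholders rather than my exact invocation.

optimum-cli export neuron \
  --model output \
  --task text-generation \
  --batch_size 1 \
  --sequence_length 1024 \
  --num_cores 2 \
  --auto_cast_type bf16 \
  llama_neuron/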
So I attempted to merge the LoRA adapter model to the base model but received this error:
Traceback (most recent call last):
File "/opt/ml/code/load_model.py", line 25, in <module>
model = PeftModel.from_pretrained(model, adapter_model_path)
File "/usr/local/lib/python3.10/site-packages/peft/peft_model.py", line 586, in from_pretrained
model.load_adapter(
File "/usr/local/lib/python3.10/site-packages/peft/peft_model.py", line 1181, in load_adapter
load_result = set_peft_model_state_dict(
File "/usr/local/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 460, in set_peft_model_state_dict
load_result = model.load_state_dict(peft_model_state_dict, strict=False)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
    size mismatch for base_model.model.lm_head.weight: copying a param with shape torch.Size([64128, 2048]) from checkpoint, the shape in current model is torch.Size([128256, 2048]).
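The merge attempt in load_model.py follows the standard transformers/peft pattern sketched below (paths are placeholders, and merge_and_unload is the usual final step; this is a reconstruction of the failing call shown in the traceback, not my exact script):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Placeholder paths -- adjust to the actual base model and consolidated adapter locations.
base_model_id = "meta-llama/Meta-Llama-3-8B"
adapter_model_path = "output"  # consolidated LoRA adapter directory

model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
# This is the call that fails with the size-mismatch error above.
model = PeftModel.from_pretrained(model, adapter_model_path)
model = model.merge_and_unload()  # fold the LoRA weights into the base model
model.save_pretrained("merged_model")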
I noted issue #368 and PR #378, but these appear to address distributed training for plain (non-PEFT) models only. In the _save_xla method:
if (
    not isinstance(self.model, NeuronPeftModel)
    and self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
):
    if is_main_worker():
        logger.info(
            "Model parallelism is enabled, saving the model sharded state dict instead of the full state dict."
        )
    # TODO: how to handle pp?
    if isinstance(self.model, PreTrainedModel):
        from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size

        config = copy.deepcopy(self.model.config)
        if self.args.mp_plugin.parallelize_embeddings:
            config.vocab_size = config.vocab_size * get_tensor_model_parallel_size()
        config.save_pretrained(output_dir)

    # This mark_step is needed to avoid hang issues.
    xm.mark_step()

    Parallelizer.save_model_sharded_checkpoint(
        self.model,
        output_dir,
        optimizer=self.optimizer if not self.args.save_only_model else None,
        use_xser=self.accelerator.state.mp_plugin.use_xser,
        async_save=self.accelerator.state.mp_plugin.async_save,
        num_local_ranks_per_step=self.accelerator.state.mp_plugin.num_local_ranks_per_step,
    )
else:
    supported_classes = (PreTrainedModel, NeuronPeftModel)
    if not isinstance(self.model, supported_classes):
        if isinstance(unwrap_model(self.model), supported_classes):
            kwargs = (
                {}
                if isinstance(unwrap_model(self.model), PreTrainedModel)
                else {"async_save": self.args.async_save}
            )
            unwrap_model(self.model).save_pretrained(
                output_dir,
                is_main_process=self.args.should_save,
                state_dict=self.model.state_dict(),
                save_function=xm.save,
                **kwargs,
            )
        else:
            if is_main_worker():
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
    else:
        kwargs = {} if isinstance(self.model, PreTrainedModel) else {"async_save": self.args.async_save}
        self.model.save_pretrained(
            output_dir, is_main_process=self.args.should_save, save_function=xm.save, **kwargs
        )
It seems that LoRA saving is not covered by the changes in that PR: the `not isinstance(self.model, NeuronPeftModel)` check excludes PEFT models from the model-parallel saving path above, which may explain why the adapter checkpoint ends up with a vocabulary dimension of 64128 (128256 split across the 2 tensor-parallel ranks) instead of 128256 after sharding and consolidation.
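A possible (untested) workaround, assuming the LoRA adapter itself never targets lm_head, might be to drop the sharded base_model.model.lm_head.weight entry from the consolidated adapter file before loading it. A sketch is below; the file name assumes the usual PEFT layout.

from safetensors.torch import load_file, save_file

# Assumed path: the consolidated adapter file following the usual PEFT layout.
adapter_file = "output/adapter_model.safetensors"

state_dict = load_file(adapter_file)
# The LoRA config above only targets attention/MLP projection layers, so the
# sharded lm_head tensor (64128 x 2048) should not be needed by the adapter.
filtered = {k: v for k, v in state_dict.items() if "lm_head" not in k}
save_file(filtered, adapter_file, metadata={"format": "pt"})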
Could you provide guidance or suggest a way to consolidate and compile the adapter model successfully for inference? Please let me know if you need any more information from me.
Thank you in advance for your assistance!
Expected behavior
I expected the adapter model to merge seamlessly with the base model without any errors. Additionally, the newly combined model should be functional and ready for inference without further issues.