From ea2848cc1998f1da5b654cc0f9a26d3eefb8cf1b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B3=BD=E5=90=89?=
Date: Wed, 10 Dec 2025 11:14:33 +0800
Subject: [PATCH] Support for LoRA PEFT methods

---
 F2LLM/arguments.py        |  9 +++++++++
 F2LLM/configs/config.json |  7 ++++++-
 F2LLM/model.py            | 30 ++++++++++++++++++++++++++++++
 F2LLM/requirements.txt    |  1 +
 F2LLM/run.py              | 22 +++++++++++++++++-----
 5 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py
index b967c8f..ab06fde 100644
--- a/F2LLM/arguments.py
+++ b/F2LLM/arguments.py
@@ -28,6 +28,15 @@ class Args:
     checkpointing_steps: int = 100
     validation_steps: int = 100
     # just placeholder, for logging purpose
+
+    # LoRA-specific arguments
+    use_lora: bool = False
+    lora_r: int = 8
+    lora_alpha: int = 16
+    lora_dropout: float = 0.05
+    lora_target_modules: str = "q_proj"
+    # lora_target_modules: str = "all-linear"  # Comma-separated list or "all-linear"
+
     num_processes: int=0
 
     def dict(self):
diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json
index 2ac3708..cdb6ee5 100644
--- a/F2LLM/configs/config.json
+++ b/F2LLM/configs/config.json
@@ -15,5 +15,10 @@
     "warmup_steps": 500,
     "train_epochs": 2,
     "log_interval": 100,
-    "num_hard_neg": 7
+    "num_hard_neg": 7,
+    "use_lora": true,
+    "lora_r": 16,
+    "lora_alpha": 32,
+    "lora_dropout": 0.05,
+    "lora_target_modules": "q_proj,v_proj,o_proj"
 }
diff --git a/F2LLM/model.py b/F2LLM/model.py
index d33ade7..565efc6 100644
--- a/F2LLM/model.py
+++ b/F2LLM/model.py
@@ -1,5 +1,6 @@
 import torch
 from transformers import AutoModel, AutoTokenizer
+from peft import LoraConfig, get_peft_model, TaskType
 
 
 class F2LLM:
@@ -14,8 +15,37 @@ def __init__(self,
         self.device = None  # set after accelerator.prepare
         self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
         self.lm.config.use_cache = False
+
+        if args and args.use_lora:
+            self._apply_lora()
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.max_seq_length = max_seq_length
+
+    def _apply_lora(self):
+        """Apply LoRA to the model if enabled."""
+        # Process target modules
+        if self.args.lora_target_modules == "all-linear":
+            # For decoder-only models, common target modules are linear layers
+            target_modules = [
+                "q_proj", "v_proj", "k_proj", "o_proj",
+                "gate_proj", "up_proj", "down_proj",
+                "lm_head"
+            ]
+        else:
+            target_modules = [module.strip() for module in self.args.lora_target_modules.split(",")]
+
+        lora_config = LoraConfig(
+            task_type=TaskType.FEATURE_EXTRACTION,  # Feature extraction for embedding models
+            r=self.args.lora_r,
+            lora_alpha=self.args.lora_alpha,
+            target_modules=target_modules,
+            lora_dropout=self.args.lora_dropout,
+            bias="none",
+            modules_to_save=[],  # We don't need to save any additional modules
+        )
+
+        self.lm = get_peft_model(self.lm, lora_config)
 
     def set_device(self):
         self.device = self.lm.device
diff --git a/F2LLM/requirements.txt b/F2LLM/requirements.txt
index 82fb447..5ef7b2e 100644
--- a/F2LLM/requirements.txt
+++ b/F2LLM/requirements.txt
@@ -5,3 +5,4 @@ flash-attn
 torch
 transformers
 tensorboard
+peft
\ No newline at end of file
diff --git a/F2LLM/run.py b/F2LLM/run.py
index e40b707..d5722ee 100644
--- a/F2LLM/run.py
+++ b/F2LLM/run.py
@@ -120,14 +120,26 @@ def __iter__(self):
 
 accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
 model = F2LLM(args.model_path, args.max_seq_length, args=args)
-model.lm.gradient_checkpointing_enable()
+model.lm.gradient_checkpointing_enable(
+    gradient_checkpointing_kwargs={"use_reentrant": False}  # non-reentrant checkpointing is required once LoRA freezes the base weights, otherwise gradients do not flow and the loss cannot be backpropagated
+)
 
 # set seed again to make sure that different models share the same seed
 set_seed(0)
-optimizer = AdamW(model.lm.parameters(),
-                  weight_decay=args.weight_decay,
-                  lr=args.learning_rate,
-                  betas=(0.9, 0.98))
+# Build the optimizer; with LoRA enabled, only the adapter parameters are trainable
+if args.use_lora:
+    # PEFT has frozen the base weights, so only the LoRA adapter parameters receive gradient updates
+    optimizer = AdamW(model.lm.parameters(),
+                      weight_decay=args.weight_decay,
+                      lr=args.learning_rate,
+                      betas=(0.9, 0.98))
+    accelerator.print(f"Using LoRA - optimizing {model.lm.num_parameters(only_trainable=True)} trainable parameters out of {model.lm.num_parameters()}")
+else:
+    # Optimize all model parameters
+    optimizer = AdamW(model.lm.parameters(),
+                      weight_decay=args.weight_decay,
+                      lr=args.learning_rate,
+                      betas=(0.9, 0.98))
 
 lr_scheduler = get_scheduler("cosine",
                              optimizer=optimizer,
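
A minimal standalone sketch (not part of the patch) of what the new _apply_lora path does, for reviewing the change without running the training script. It applies the same LoraConfig fields used in F2LLM/model.py with the values from configs/config.json. The base model name below is an illustrative placeholder, and the final merge_and_unload() call is only an assumed post-training export step, not something this patch performs.

# Standalone sketch of the LoRA wrapping added in F2LLM/model.py (assumptions noted above).
import torch
from transformers import AutoModel
from peft import LoraConfig, TaskType, get_peft_model

# Placeholder base model; F2LLM loads whatever args.model_path points to.
base = AutoModel.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype=torch.bfloat16)

lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,          # embedding model, as in _apply_lora
    r=16,                                           # lora_r in configs/config.json
    lora_alpha=32,                                  # lora_alpha
    lora_dropout=0.05,                              # lora_dropout
    target_modules=["q_proj", "v_proj", "o_proj"],  # lora_target_modules, split on ","
    bias="none",
)
model = get_peft_model(base, lora_config)

# Only the injected lora_A/lora_B matrices require grad; the frozen base weights
# receive no updates, which is why run.py can safely pass model.lm.parameters() to AdamW.
model.print_trainable_parameters()

# Assumed post-training export step: fold the adapter back into the base weights
# to obtain a plain checkpoint.
merged = model.merge_and_unload()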