9 changes: 9 additions & 0 deletions F2LLM/arguments.py
@@ -28,6 +28,15 @@ class Args:
checkpointing_steps: int = 100
validation_steps: int = 100
# just a placeholder, for logging purposes

# LoRA-specific arguments
use_lora: bool = False
lora_r: int = 8
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_target_modules: str = "q_proj"  # comma-separated list of module names, or "all-linear"

num_processes: int = 0

def dict(self):
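For reference, a hedged sketch of how these new fields might be set in code (assuming Args is a @dataclass; the values mirror configs/config.json below):

# Hypothetical usage -- field names come from the dataclass above.
args = Args(
    use_lora=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    lora_target_modules="q_proj,v_proj,o_proj",  # or "all-linear"
)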
7 changes: 6 additions & 1 deletion F2LLM/configs/config.json
@@ -15,5 +15,10 @@
"warmup_steps": 500,
"train_epochs": 2,
"log_interval": 100,
"num_hard_neg": 7
"num_hard_neg": 7,
"use_lora": true,
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_modules": "q_proj,v_proj,o_proj"
}
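A note on the r/alpha pair: LoRA scales its low-rank update by alpha/r, so this config keeps the conventional 2x ratio. A minimal sketch of the arithmetic, using the values above:

# LoRA applies W' = W + (lora_alpha / lora_r) * B @ A
lora_r, lora_alpha = 16, 32
scaling = lora_alpha / lora_r  # = 2.0; raising r without raising alpha shrinks the update scale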
30 changes: 30 additions & 0 deletions F2LLM/model.py
@@ -1,5 +1,6 @@
import torch
from transformers import AutoModel, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType


class F2LLM:
@@ -14,8 +15,37 @@ def __init__(self,
self.device = None # set after accelerator.prepare
self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
self.lm.config.use_cache = False

self.args = args  # _apply_lora reads the LoRA settings from self.args
if args and args.use_lora:
    self._apply_lora()

self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.max_seq_length = max_seq_length

def _apply_lora(self):
"""Apply LoRA to the model if enabled."""
# Process target modules
if self.args.lora_target_modules == "all-linear":
    # For decoder-only models, target every linear projection layer.
    # Note: AutoModel loads the base model without an lm_head, so it is not listed here.
    target_modules = [
        "q_proj", "v_proj", "k_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
else:
    target_modules = [module.strip() for module in self.args.lora_target_modules.split(",")]

lora_config = LoraConfig(
task_type=TaskType.FEATURE_EXTRACTION, # Feature extraction for embedding models
r=self.args.lora_r,
lora_alpha=self.args.lora_alpha,
target_modules=target_modules,
lora_dropout=self.args.lora_dropout,
bias="none",
modules_to_save=[], # We don't need to save any additional modules
)

self.lm = get_peft_model(self.lm, lora_config)

def set_device(self):
self.device = self.lm.device
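After construction it is easy to confirm the adapter attached where intended; a hedged sketch (the model path and args are placeholders, print_trainable_parameters is PEFT's built-in summary):

# Hypothetical check after wrapping -- not part of this PR.
model = F2LLM("path/to/base-model", max_seq_length=512, args=args)
model.lm.print_trainable_parameters()  # prints trainable vs. total parameter counts and the trainable %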
1 change: 1 addition & 0 deletions F2LLM/requirements.txt
@@ -5,3 +5,4 @@ flash-attn
torch
transformers
tensorboard
peft
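Context for the run.py hunk below, which switches gradient checkpointing to the non-reentrant implementation: once LoRA freezes the base weights, a checkpointed segment can receive inputs that do not require grad, and the reentrant variant cannot backpropagate through it. A minimal, self-contained repro (toy layer, not F2LLM code):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)   # trainable, like a LoRA adapter
x = torch.randn(2, 4)           # frozen activation: requires_grad=False
out = checkpoint(layer, x, use_reentrant=False)
out.sum().backward()            # fine: layer.weight.grad is populated
# With use_reentrant=True, out would not require grad and backward() would raise.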
22 changes: 17 additions & 5 deletions F2LLM/run.py
@@ -120,14 +120,26 @@ def __iter__(self):

accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
model = F2LLM(args.model_path, args.max_seq_length, args=args)
model.lm.gradient_checkpointing_enable()
model.lm.gradient_checkpointing_enable(
    # Non-reentrant checkpointing is required here: the reentrant variant
    # fails at backward time when a checkpointed segment's inputs are all
    # frozen, as happens once LoRA freezes the base weights.
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
# set seed again to make sure that different models share the same seed
set_seed(0)

optimizer = AdamW(model.lm.parameters(),
weight_decay=args.weight_decay,
lr=args.learning_rate,
betas=(0.9, 0.98))
# Choose which parameters the optimizer sees based on LoRA usage
if args.use_lora:
    # Only the LoRA adapter parameters require grad; pass just those to AdamW
    trainable_params = [p for p in model.lm.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params,
                      weight_decay=args.weight_decay,
                      lr=args.learning_rate,
                      betas=(0.9, 0.98))
    accelerator.print(f"Using LoRA - optimizing {model.lm.num_parameters(only_trainable=True)} trainable parameters out of {model.lm.num_parameters()}")
else:
    # Full fine-tuning: optimize all model parameters
    optimizer = AdamW(model.lm.parameters(),
                      weight_decay=args.weight_decay,
                      lr=args.learning_rate,
                      betas=(0.9, 0.98))

lr_scheduler = get_scheduler("cosine",
optimizer=optimizer,
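One practical follow-up once training finishes: with PEFT applied, checkpoints can store just the adapter, and the adapter can later be folded into the base weights for deployment. A hedged sketch (paths are placeholders; save_pretrained on a PeftModel writes only the adapter, and merge_and_unload is PEFT's merge API):

unwrapped = accelerator.unwrap_model(model.lm)
unwrapped.save_pretrained("ckpts/lora_adapter")    # adapter_config.json + adapter weights, a few MB
merged = unwrapped.merge_and_unload()              # fold the low-rank deltas into the base model
merged.save_pretrained("ckpts/merged_model")       # full model for adapter-free inference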