diff --git a/F2LLM/GRADIENT_ACCUMULATION_README.md b/F2LLM/GRADIENT_ACCUMULATION_README.md
new file mode 100644
index 0000000..3f43124
--- /dev/null
+++ b/F2LLM/GRADIENT_ACCUMULATION_README.md
@@ -0,0 +1,53 @@
+# Gradient Accumulation in F2LLM
+
+## How Gradient Accumulation Works in This Codebase
+
+1. Set `gradient_accumulation_steps` in `configs/config.json`; the default (defined in `arguments.py`) is 1, meaning no accumulation.
+   - e.g., `"gradient_accumulation_steps": 4` accumulates gradients over 4 micro-batches.
+
+2. `utils.py` scales the loss and defers the optimizer step:
+   ```python
+   # Scale the loss so the accumulated gradient matches a single large-batch update
+   loss_total = loss_total / args.gradient_accumulation_steps
+
+   # Update parameters only every gradient_accumulation_steps micro-batches
+   # (or at the final training step)
+   if (completed_steps + 1) % args.gradient_accumulation_steps == 0 or (completed_steps + 1) == args.train_steps:
+       optimizer.step()
+       lr_scheduler.step()
+       optimizer.zero_grad()
+   ```
+   - Without accumulation: process 1 batch of size N → compute loss → update parameters.
+   - With accumulation: process 4 micro-batches of size N/4 → accumulate gradients → update parameters.
+
+   Because the loss is divided by `gradient_accumulation_steps`, both procedures yield the same parameter update.
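+
+As a quick sanity check, the equivalence can be demonstrated with a small self-contained PyTorch snippet; the linear model and MSE loss below are illustrative stand-ins, not the F2LLM model or training loop:
+
+```python
+import torch
+
+torch.manual_seed(0)
+model = torch.nn.Linear(4, 1)
+x, y = torch.randn(32, 4), torch.randn(32, 1)
+
+# One full batch of 32
+model.zero_grad()
+torch.nn.functional.mse_loss(model(x), y).backward()
+full_grad = model.weight.grad.clone()
+
+# Four micro-batches of 8: scale each loss by 1/4 and accumulate gradients
+model.zero_grad()
+for xb, yb in zip(x.chunk(4), y.chunk(4)):
+    (torch.nn.functional.mse_loss(model(xb), yb) / 4).backward()
+
+# The accumulated gradient matches the full-batch gradient
+print(torch.allclose(model.weight.grad, full_grad, atol=1e-6))  # True
+```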
+
+## Example
+
+Let's say you have:
+- Desired effective batch size: 32
+- GPU memory that only allows 8 samples per batch
+
+**Without Gradient Accumulation**:
+- You are limited to batch size 8
+- Effective batch size = 8
+- May result in suboptimal training dynamics
+
+**With Gradient Accumulation (steps=4)**:
+- Process 4 micro-batches of size 8 each
+- Effective batch size = 32 (4 × 8)
+- Nearly the same training dynamics as a true batch size of 32
+- Better gradient estimates due to the larger effective batch size
+
+Note: for contrastive objectives with in-batch negatives, as used in F2LLM, each micro-batch only sees negatives from within that micro-batch, so accumulation approximates but is not exactly identical to training with a true batch of 32.
+
+## Configuration Example
+
+To use gradient accumulation, modify your config file:
+```json
+{
+    "train_batch_size": 8,
+    "gradient_accumulation_steps": 4
+}
+```
+This gives an effective batch size of 32 (8 × 4) while only using memory for 8 samples at a time.
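+
+For reference, the update condition in `utils.py` also forces an optimizer step on the final micro-batch, so a trailing partial accumulation window is not dropped. A condensed view of that logic, using hypothetical values for the two settings:
+
+```python
+gradient_accumulation_steps = 4   # from config.json
+train_steps = 10                  # hypothetical total number of micro-batches
+
+for completed_steps in range(train_steps):
+    do_step = ((completed_steps + 1) % gradient_accumulation_steps == 0
+               or (completed_steps + 1) == train_steps)
+    print(completed_steps + 1, "optimizer step" if do_step else "accumulate")
+
+# Micro-batches 4 and 8 complete full accumulation windows; micro-batch 10
+# triggers the forced final step for the partial window (9-10).
+```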
diff --git a/F2LLM/README.md b/F2LLM/README.md
index 6b79819..b0adba9 100644
--- a/F2LLM/README.md
+++ b/F2LLM/README.md
@@ -27,11 +27,15 @@ In this repo we provide a streamlined and efficient script for training embeddin
 - Setup environment following `requirements.txt`. We note that transformers>=4.51.0 is required for training Qwen3 models.
 - Download data and backbone models from Hugging Face (we use Qwen3 models).
 - Run `tokenize_data_qwen.py` to tokenize the downloaded data
-- Modify model path, data path, and other arguments in `configs/config.json`.
+- Modify model path, data path, and other arguments in `configs/config.json`, including `gradient_accumulation_steps` if you want a larger effective batch size (see the Gradient Accumulation section below).
 - Start training with `accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json`.
 
 Note: we recommend setting `num_processes` to 1 in `configs/accelerate_config.yaml` and launch the training code once to generate cache for training data before starting the actual training.
 
+### Gradient Accumulation
+
+The training script supports gradient accumulation, which simulates large-batch training on resource-constrained hardware by accumulating gradients over multiple micro-batches before each optimizer step. Enable it by setting `gradient_accumulation_steps` in your config file (the default is 1, i.e., no accumulation). For example, with `train_batch_size=8` and `gradient_accumulation_steps=4`, the effective batch size becomes 32. See `GRADIENT_ACCUMULATION_README.md` for details.
+
 For multi-node training, run on the main node:
 ```
diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py
index b967c8f..77d1a01 100644
--- a/F2LLM/arguments.py
+++ b/F2LLM/arguments.py
@@ -27,6 +27,8 @@ class Args:
     log_interval: int = 20
     checkpointing_steps: int = 100
     validation_steps: int = 100
+    # gradient accumulation
+    gradient_accumulation_steps: int = 1
 
     # just placeholder, for logging purpose
     num_processes: int=0
diff --git a/F2LLM/configs/config.json b/F2LLM/configs/config.json
index 2ac3708..7b8505b 100644
--- a/F2LLM/configs/config.json
+++ b/F2LLM/configs/config.json
@@ -15,5 +15,6 @@
     "warmup_steps": 500,
     "train_epochs": 2,
     "log_interval": 100,
-    "num_hard_neg": 7
+    "num_hard_neg": 7,
+    "gradient_accumulation_steps": 1
 }
diff --git a/F2LLM/run.py b/F2LLM/run.py
index e40b707..0731f58 100644
--- a/F2LLM/run.py
+++ b/F2LLM/run.py
@@ -134,7 +134,9 @@ def __iter__(self):
                                           num_warmup_steps=args.warmup_steps,
                                           num_training_steps=args.train_steps)
 
-AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
+if AcceleratorState().deepspeed_plugin is not None:
+    AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
+    AcceleratorState().deepspeed_plugin.deepspeed_config['gradient_accumulation_steps'] = args.gradient_accumulation_steps
 model.lm, optimizer, lr_scheduler = accelerator.prepare(
     model.lm, optimizer, lr_scheduler
 )
diff --git a/F2LLM/utils.py b/F2LLM/utils.py
index b167d3c..4d48beb 100644
--- a/F2LLM/utils.py
+++ b/F2LLM/utils.py
@@ -124,7 +124,8 @@ def accelerate_train(args,
     accelerator.print(f" Num train samples = {num_train_samples}")
     accelerator.print(f" Num epochs = {args.train_epochs}")
     accelerator.print(f" Per device batch size = {args.train_batch_size}")
-    accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes}")
+    accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
+    accelerator.print(f" Global batch size = {args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps}")
     accelerator.print(f" Step per epoch = {len(train_dataloader)}")
     accelerator.print(f" Total training steps = {args.train_steps}")
     accelerator.print("************************************************************************************************")
@@ -165,14 +166,20 @@ def accelerate_train(args,
             loss_total = loss + loss_hard
 
-            # backward, optimizer, scheduler
+            # Scale the loss so gradients accumulated over gradient_accumulation_steps micro-batches match one large-batch update
+            loss_total = loss_total / args.gradient_accumulation_steps
+
+            # backward
             accelerator.backward(loss_total)
-            optimizer.step()
-            lr_scheduler.step()
-            optimizer.zero_grad()
-            if optimizer.param_groups[0]['lr'] < args.min_lr:
-                for i in range(len(optimizer.param_groups)):
-                    optimizer.param_groups[i]['lr'] = args.min_lr
+
+            # Update parameters only every gradient_accumulation_steps micro-batches (or at the final training step)
+            if (completed_steps + 1) % args.gradient_accumulation_steps == 0 or (completed_steps + 1) == args.train_steps:
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+                if optimizer.param_groups[0]['lr'] < args.min_lr:
+                    for i in range(len(optimizer.param_groups)):
+                        optimizer.param_groups[i]['lr'] = args.min_lr
 
             # log
             completed_steps += 1
@@ -180,14 +187,15 @@ def accelerate_train(args,
                 pbar.update(args.log_interval)
 
                 train_log_dict = {"lr": optimizer.param_groups[0]['lr']}
+                # Multiply logged losses back by gradient accumulation steps so logs show the unscaled per-batch loss
                 for k in loss_dict.keys():
                     count = accelerator.gather(count_dict[k]).sum()
                     if count > 0:
-                        train_log_dict[f"{k}/training_loss_in_batch"] = accelerator.gather(loss_dict[k]).sum() / count
+                        train_log_dict[f"{k}/training_loss_in_batch"] = (accelerator.gather(loss_dict[k]).sum() / count) * args.gradient_accumulation_steps
                 for k in loss_hard_dict.keys():
                     count = accelerator.gather(count_hard_dict[k]).sum()
                     if count > 0:
-                        train_log_dict[f"{k}/training_loss_hard"] = accelerator.gather(loss_hard_dict[k]).sum() / count
+                        train_log_dict[f"{k}/training_loss_hard"] = (accelerator.gather(loss_hard_dict[k]).sum() / count) * args.gradient_accumulation_steps
                 train_log_dict['Avg/retrieval/training_loss_in_batch'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_in_batch')]).mean()
                 train_log_dict['Avg/retrieval/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in RETRIEVAL_DATASETS and k.endswith('training_loss_hard')]).mean()
                 train_log_dict['Avg/classification/training_loss_hard'] = torch.tensor([v for k, v in train_log_dict.items() if k.split('/')[0] in CLASSIFICATION_DATASETS]).mean()