Discrepancy in Training Loss Behavior with Gradient Accumulation using DeepSpeed #34694

Closed
@kmchiti

System Info

Accelerate version: 1.1.0
transformers version: 4.46.2
DeepSpeed version: 0.14.4
Platform: Linux 5.15.0-101-generic #111-Ubuntu SMP x86_64 GNU/Linux
Python version: 3.10.14
PyTorch version (GPU?): 2.1.2+cu118 True
GPU type: NVIDIA A100

Who can help?

@muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The script below is a simplified reproduction that trains a small Llama model with the Hugging Face Trainer. It loads and tokenizes a dataset, initializes the model and tokenizer, and configures the Trainer from command-line flags for the per-device batch size (--bs), the number of gradient accumulation steps (--ga), and whether DeepSpeed is enabled (--deepspeed).

import argparse
import torch
from datasets import load_dataset
from transformers import (set_seed,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          LlamaForCausalLM,
                          LlamaConfig,
                          AutoTokenizer
                          )


DEEPSPEED_CONFIG = {
  "zero_optimization": {
    "stage": 0
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto"
}

TRAIN_ARGS = {'output_dir': './test_GA',
              'bf16': True,
              'learning_rate': 6e-4,
              'lr_scheduler_type': 'cosine',
              'max_steps': 200,
              'optim': 'adamw_torch',
              'weight_decay': 0.1,
              'per_device_train_batch_size': 128,
              'gradient_accumulation_steps': 1,
              'logging_steps': 1,
              'report_to': 'none'}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', default=128, type=int, help='batch size')
    parser.add_argument('--ga', default=1, type=int, help='number of gradient accumulation steps')
    parser.add_argument('--deepspeed', action='store_true', help='use deepspeed')
    args = parser.parse_args()

    set_seed(42)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # Initialize dataset
    CONTEXT_LENGTH = 512  # Small context length as specified
    def preprocess_data(examples, tokenizer, max_length=CONTEXT_LENGTH):
        """Tokenizes the input data and truncates/pads to the max context length."""
        return tokenizer(examples["sentence"], truncation=True, padding="max_length", max_length=max_length, add_special_tokens=True)

    # Load the dataset from Hugging Face
    dataset = load_dataset("ptb_text_only", trust_remote_code=True, split='train')

    # Load and configure the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_prefix_space=True, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    # Preprocess the dataset
    column_names = list(dataset.features)
    train_dataset = dataset.map(lambda x: preprocess_data(x, tokenizer), batched=True, remove_columns=column_names)

    # Initialize model
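    # LlamaConfig takes max_position_embeddings; n_positions is not one of its parameters and would be silently ignored
    model_cfg = LlamaConfig(max_position_embeddings=CONTEXT_LENGTH, hidden_size=512, num_attention_heads=8, num_hidden_layers=4,
                            vocab_size=tokenizer.vocab_size, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id)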
    model = LlamaForCausalLM(model_cfg)
    
    # Initialize trainer
    if args.deepspeed:
        TRAIN_ARGS.update({"deepspeed": DEEPSPEED_CONFIG})
    TRAIN_ARGS.update({"per_device_train_batch_size": args.bs, "gradient_accumulation_steps": args.ga})
    trainer = Trainer(model=model, args=TrainingArguments(**TRAIN_ARGS), train_dataset=train_dataset,
                      data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False))
    trainer.train()
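
Runs with different --bs and --ga values are only directly comparable when the effective batch size per optimizer step is the same. The sketch below shows that arithmetic. The helper name `effective_batch_size` and the flag combinations are hypothetical illustrations; the report does not state which combinations were actually run.

```python
def effective_batch_size(per_device_bs: int, grad_accum_steps: int, num_devices: int = 1) -> int:
    """Number of samples that contribute to one optimizer step."""
    return per_device_bs * grad_accum_steps * num_devices


# Hypothetical (--bs, --ga) pairs, chosen so the product stays at 128 on one device.
configs = [(128, 1), (64, 2), (32, 4)]
sizes = [effective_batch_size(bs, ga) for bs, ga in configs]

for (bs, ga), size in zip(configs, sizes):
    print(f"--bs {bs} --ga {ga} -> effective batch size {size}")
```

When the product is held constant like this, the runs see the same number of samples per optimizer step, so their loss curves can be compared directly.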

Expected behavior

The training loss should remain consistent across different numbers of gradient accumulation steps, both with and without DeepSpeed enabled. However, the figure shows the loss curves diverging when DeepSpeed is enabled:

[Figure: gradient_accumulation_issue]
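
The expectation rests on a general property of gradient accumulation: if every micro-batch contributes the same number of loss terms and each micro-batch gradient is scaled by one over the number of accumulation steps, the accumulated gradient equals the full-batch gradient. The toy sketch below checks this for a scalar least-squares model in plain Python. It does not use the Trainer or DeepSpeed and is not a diagnosis of the divergence reported here. All names and data in it are made up for illustration.

```python
def mse_grad(w, xs, ys):
    """Gradient of the mean squared error of y ~ w * x with respect to scalar w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, micro_bs):
    """Accumulate micro-batch gradients, each scaled by 1 / number of steps."""
    chunks = [(xs[i:i + micro_bs], ys[i:i + micro_bs]) for i in range(0, len(xs), micro_bs)]
    steps = len(chunks)
    total = 0.0
    for chunk_x, chunk_y in chunks:
        total += mse_grad(w, chunk_x, chunk_y) / steps
    return total


# Made-up data: 8 samples, so micro-batch sizes 1, 2, 4, and 8 split it evenly.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]
w = 0.5

full = mse_grad(w, xs, ys)
accumulated = {bs: accumulated_grad(w, xs, ys, bs) for bs in (1, 2, 4, 8)}

for bs, grad in accumulated.items():
    print(f"micro-batch size {bs}: |accumulated - full| = {abs(grad - full):.2e}")
```

The equality holds only up to floating-point rounding, and only when the micro-batches are equally weighted. If micro-batches contain different numbers of loss terms, averaging their mean gradients no longer matches the full-batch mean.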
