RuntimeError in _group_tensors_by_device_and_dtype (torch/optim/optimizer.py) when training with FSDP on N>1 GPUs. #34730

julien-piet commented Nov 14, 2024

System Info

  • transformers version: 4.46.2
  • Platform: Linux-5.4.0-125-generic-x86_64-with-glibc2.31
  • Python version: 3.10.15
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: Yes (FSDP)
  • Using GPU in script?: Yes
  • GPU type: NVIDIA RTX A5000

Who can help?

@muellerzr @sunm

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Error output

[rank1]:   File "/data/julien_piet/llm-attack-detect/scripts/bug.py", line 254, in <module>
[rank1]:     train(
[rank1]:   File "/data/julien_piet/llm-attack-detect/scripts/bug.py", line 247, in train
[rank1]:     trainer.train()
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/transformers/trainer.py", line 2123, in train
[rank1]:     return inner_training_loop(
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/transformers/trainer.py", line 2534, in _inner_training_loop
[rank1]:     self.optimizer.step()
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/accelerate/optimizer.py", line 171, in step
[rank1]:     self.optimizer.step(closure)
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
[rank1]:     return func.__get__(opt, opt.__class__)(*args, **kwargs)
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper
[rank1]:     out = func(*args, **kwargs)
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
[rank1]:     ret = func(self, *args, **kwargs)
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/adamw.py", line 220, in step
[rank1]:     adamw(
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/adamw.py", line 782, in adamw
[rank1]:     func(
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/adamw.py", line 480, in _multi_tensor_adamw
[rank1]:     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 516, in _group_tensors_by_device_and_dtype
[rank1]:     return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)  # type: ignore[return-value, arg-type]
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/utils/_foreach_utils.py", line 37, in _group_tensors_by_device_and_dtype
[rank1]:     return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
[rank1]: RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/julien_piet/llm-attack-detect/scripts/bug.py", line 254, in <module>
[rank0]:     train(
[rank0]:   File "/data/julien_piet/llm-attack-detect/scripts/bug.py", line 247, in train
[rank0]:     trainer.train()
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/transformers/trainer.py", line 2123, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/transformers/trainer.py", line 2534, in _inner_training_loop
[rank0]:     self.optimizer.step()
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/accelerate/optimizer.py", line 171, in step
[rank0]:     self.optimizer.step(closure)
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
[rank0]:     return func.__get__(opt, opt.__class__)(*args, **kwargs)
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper
[rank0]:     out = func(*args, **kwargs)
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
[rank0]:     ret = func(self, *args, **kwargs)
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/adamw.py", line 220, in step
[rank0]:     adamw(
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/adamw.py", line 782, in adamw
[rank0]:     func(
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/adamw.py", line 480, in _multi_tensor_adamw
[rank0]:     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 516, in _group_tensors_by_device_and_dtype
[rank0]:     return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)  # type: ignore[return-value, arg-type]
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/data/julien_piet/llm-attack-detect/env/lib/python3.10/site-packages/torch/utils/_foreach_utils.py", line 37, in _group_tensors_by_device_and_dtype
[rank0]:     return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
[rank0]: RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding

Original Code

import os
from dataclasses import dataclass, field
from typing import Any, Optional

import numpy as np
import torch
from datasets import Dataset
from sklearn.utils.class_weight import compute_class_weight
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)


@dataclass
class CustomTrainingArguments(TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=1024)
    lr_scheduler_type: Optional[str] = field(default="cosine_with_restarts")
    per_device_train_batch_size: int = field(default=4)
    per_device_eval_batch_size: int = field(default=4)
    output_dir: Optional[str] = field(default="output")
    remove_unused_columns: bool = False


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(
        default="meta-llama/Llama-3.2-1B-Instruct"
    )
    pad_token: str = field(
        default="<|finetune_right_pad_id|>", metadata={"help": "Padding token."}
    )
    unk_token: str = field(
        default="<|reserved_special_token_0|>",
        metadata={"help": "Unknown token."},
    )


class SupervisedDataset:
    def __init__(self, data, tokenizer, training_args):
        data_dict = SupervisedDataset._preprocess(
            data, tokenizer, training_args
        )
        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.attention_mask = data_dict["attention_mask"]
        self.classification_labels = [
            d["messages"][-1]["content"] for d in data
        ]

        # Compute class weights for imbalanced classes
        self.class_weights, self.class_values, self.class_indices = (
            SupervisedDataset.get_class_weights(
                self.classification_labels, tokenizer
            )
        )
        self.classification_labels = [
            self.class_indices[label] for label in self.classification_labels
        ]

    @staticmethod
    def get_class_weights(labels, tokenizer):
        classes = sorted(list(set(labels)))
        class_indices = {label: idx for idx, label in enumerate(classes)}
        label_indices = [class_indices[label] for label in labels]

        class_values = []
        for class_name in classes:
            class_values.append(
                tokenizer.encode(class_name, add_special_tokens=False)[0]
            )

        class_weights = compute_class_weight(
            class_weight="balanced",
            classes=np.unique(label_indices),
            y=label_indices,
        )
        return class_weights, class_values, class_indices

    @staticmethod
    def _preprocess(data, tokenizer, training_args):
        formatted_inputs = [
            tokenizer.apply_chat_template(d["messages"], tokenize=False)
            for d in data
        ]
        formatted_prompts = [
            tokenizer.apply_chat_template(
                d["messages"][:-1], tokenize=False, add_generation_prompt=True
            )
            for d in data
        ]
        tokenized_inputs = tokenizer(
            formatted_inputs,
            padding=True,
            padding_side="left",
            return_tensors="pt",
            add_special_tokens=False,
        )
        tokenized_prompts = tokenizer(
            formatted_prompts,
            padding=True,
            padding_side="left",
            return_tensors="pt",
            add_special_tokens=False,
        )

        attention_mask = tokenized_prompts["attention_mask"]
        input_ids = tokenized_prompts["input_ids"]
        labels = tokenized_inputs["input_ids"][
            :, tokenized_prompts["input_ids"].shape[1]
        ]

        attention_mask = attention_mask[:, -training_args.model_max_length :]
        input_ids = input_ids[:, -training_args.model_max_length :]

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    def convert_to_dataset(self):
        return Dataset.from_dict(
            {
                "input_ids": self.input_ids,
                "labels": self.labels,
                "attention_mask": self.attention_mask,
            }
        )


# Custom Trainer with weighted loss
class WeightedLoss:
    def __init__(self, class_weights=None, class_values=None):
        self.class_weights = torch.tensor(class_weights).cuda()
        self.class_values = class_values

    def compute_loss(self, outputs, labels, **kwargs):
        logits = outputs.get("logits")

        # Compute loss based on last token logits
        logits = logits[:, -1, self.class_values].reshape(
            -1, len(self.class_values)
        )

        ce_labels = torch.tensor(
            [self.class_values.index(v) for v in labels]
        ).to(labels.device)

        if self.class_weights.dtype != logits.dtype:
            self.class_weights = self.class_weights.to(logits.dtype)

        loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        loss = loss_fct(logits, ce_labels)

        return loss


# Load and prepare the dataset
def load_and_prepare_data(training_args, tokenizer):
    dataset = [
        {
            "messages": [
                {
                    "role": "user",
                    "content": (
                        "Please respond with " + ("no" if i % 2 else "yes")
                    ),
                },
                {"role": "assistant", "content": "no" if i % 2 else "yes"},
            ]
        }
        for i in range(1000)
    ]

    dataset = SupervisedDataset(
        dataset,
        tokenizer,
        training_args,
    )

    class_weights, class_values, class_indices = (
        dataset.class_weights,
        dataset.class_values,
        dataset.class_indices,
    )

    dataset = dataset.convert_to_dataset()

    return (
        dataset,
        None,
        class_weights,
        class_values,
        class_indices,
    )


# Training function
def train(model_args, training_args):

    if training_args.lr_scheduler_type == "cosine_with_restarts":
        training_args.lr_scheduler_kwargs = {
            "num_cycles": 1 + training_args.num_train_epochs // 10
        }

    # Load the pretrained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        truncation_side="left",
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    # Augment tokenizer
    if tokenizer.pad_token is None:
        tokenizer.pad_token = model_args.pad_token
    if tokenizer.unk_token is None:
        tokenizer.unk_token = model_args.unk_token

    # Load and prepare data
    train_dataset, _, class_weights, class_values, _ = load_and_prepare_data(
        training_args, tokenizer
    )

    # Loss function
    custom_loss = WeightedLoss(class_weights, class_values)

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        compute_loss_func=lambda x, y, **kwargs: custom_loss.compute_loss(x, y),
    )

    # Start training
    trainer.train()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, CustomTrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()
    train(
        model_args,
        training_args,
    )

Command

Command that triggers the error (assuming the code above is saved in a file called bug.py):

torchrun --nproc_per_node=2 --master_port=19527 bug.py \
    --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \
    --output_dir outputs/test/ \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --model_max_length 1024 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --save_total_limit 1 \
    --learning_rate 2.5e-6 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine_with_restarts" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap "LlamaDecoderLayer" \
    --bf16 True \
    --tf32 True

Expected behavior

I'm trying to fine-tune a model with the Trainer API, using torchrun and FSDP to distribute training over multiple GPUs. If I run the provided code in a single process, it works fine. However, as soon as I set nproc_per_node to more than 1, I get the error shown above.
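
For reference, launching the same script as a single process (a sketch, keeping the same arguments as in the command above) completes without error:

torchrun --nproc_per_node=1 --master_port=19527 bug.py <same remaining arguments as in the command above>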

This error first seemed to be a PyTorch error, for which I created an issue here. However, as pointed out by @JoyceZhangSS and @jiaqiw09, it is actually tied to transformers version 4.46.2: 4.46.1 does not have this bug, and training proceeds as expected.
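
Until this is resolved, the observation above suggests a temporary workaround (not an official fix, simply what the version comparison implies): pin the previous release.

pip install "transformers==4.46.1"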

I reproduced the error in a standalone file with a dummy dataset, provided in this issue. However, it occurs with any dataset and with the standard loss: the default Alpaca training code leads to the same error, with both Llama and OPT models. I did some investigation that might be helpful:

  • I was able to reproduce this error in multiple environments.
  • This error is triggered during the optimization step. Specifically, it is triggered by this snippet of code in torch/optim/adamw.py:480:
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
    [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
)
  • The problem stems from grads being bf16 while the rest are float32; in transformers==4.46.1, all groups are float32 (a small diagnostic sketch for checking this is included at the end of this report).
  • I looked at the diff between the two versions and found the change responsible for the bug. In trainer.py, the following change was made:
2473 -                   with self.accelerator.accumulate(model):
2474 +                   # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
2475 +                   context = (
2476 +                        functools.partial(self.accelerator.no_sync, model=model)
2477 +                        if i == len(batch_samples) - 1
2478 +                        else contextlib.nullcontext
2479 +                   )
2480 +                   with context():

This context seems responsible for syncing the gradients across devices, so I tried reverting this change, and the error stops happening. I don't know enough about this code to say exactly what the context does, or why you no longer want to rely on accelerator.accumulate, but removing it seems to be what broke training.
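
To make the dtype mismatch easier to confirm, here is a minimal diagnostic sketch (my own helper, not part of transformers or torch) that prints the dtype and device of each parameter, its gradient, and its AdamW exp_avg buffer. Calling it right before self.optimizer.step() (for example via a quick local edit in trainer.py's _inner_training_loop) should show the bf16 gradients next to the float32 parameters described above when running with FSDP on more than one GPU.

import torch

def dump_optimizer_dtypes(optimizer: torch.optim.Optimizer) -> None:
    # Report dtype/device of each parameter, its gradient, and its first-moment
    # (exp_avg) state; these are among the tensors that
    # _group_tensors_by_device_and_dtype later tries to group together.
    for group_idx, group in enumerate(optimizer.param_groups):
        for param in group["params"]:
            if param.grad is None:
                continue
            state = optimizer.state.get(param, {})
            exp_avg = state.get("exp_avg")
            print(
                f"group={group_idx} "
                f"param={param.dtype}/{param.device} "
                f"grad={param.grad.dtype}/{param.grad.device} "
                f"exp_avg={exp_avg.dtype if exp_avg is not None else 'n/a'}"
            )

# Usage (hypothetical): dump_optimizer_dtypes(self.optimizer) just before the
# self.optimizer.step() call shown in the traceback above.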
