
[BUG] AutoTP training gets AssertionError: Data inconsistency within the TP group. #7199


Closed
liuteng opened this issue Apr 3, 2025 · 8 comments
Labels: bug (Something isn't working), training

Comments

@liuteng

liuteng commented Apr 3, 2025

Describe the bug
When I use AutoTP for SFT training, I get a "Data inconsistency" error.
I am using accelerate (1.6.0) + deepspeed (0.16.5) + trl (0.16.0) + transformers (4.51.0.dev0 from the main branch) on a machine with 8x H20.

I launch the training with:
accelerate launch --config_file sft.yaml \
sft.py \
--deepspeed ds_config.json \
--model_name_or_path /path/to/Qwen__QwQ-32B/ \
--dataset_name /path/to/my_data \
--attn_implementation flash_attention_2 \
--bf16 \
--eval_steps 400 \
--gradient_checkpointing \
--learning_rate 3.0e-6 \
--logging_dir /home/admin/ \
--log_level info \
--logging_steps 2 \
--max_length 16384 \
--num_train_epochs 2 \
--output_dir /output \
--run_name qwq-sft \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--save_strategy steps \
--save_steps 1000000 \
--save_total_limit 20 \
--save_only_model \
--seed 42 \
--warmup_ratio 0.1

The sft.yaml:
compute_environment: LOCAL_MACHINE
debug: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

The ds_config.json:
{
  "zero_optimization": {
    "stage": 2,
    "gather_16bit_weights_on_model_save": true
  },
  "tensor_parallel": {
    "autotp_size": 2
  },
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 2
}

The sft.py looks like this:

import argparse
import os

from datasets import load_dataset
from transformers import AutoTokenizer
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


def main(script_args, training_args, model_args):
    quantization_config = get_quantization_config(model_args)
    # Model loading is deferred to SFTTrainer via model_init_kwargs.
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Build the dataset from all JSON files in the dataset directory.
    data_files = [os.path.join(script_args.dataset_name, file) for file in os.listdir(script_args.dataset_name) if file.endswith('.json')]
    dataset = load_dataset('json', data_files=data_files)['train'].train_test_split(test_size=0.1)

    trainer = SFTTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    trainer.train()


def make_parser(subparsers: argparse._SubParsersAction = None):
    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)

And finally I got this error:

[rank1]: Traceback (most recent call last):
[rank1]: File "/ossfs/workspace/train/sft.py", line 133, in
[rank1]: main(script_args, training_args, model_args)
[rank1]: File "/ossfs/workspace/train/sft.py", line 113, in main
[rank1]: trainer.train()
[rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2245, in train
[rank1]: return inner_training_loop(
[rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2555, in _inner_training_loop
[rank1]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3731, in training_step
[rank1]: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 495, in compute_loss
[rank1]: (loss, outputs) = super().compute_loss(
[rank1]: File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3796, in compute_loss
[rank1]: outputs = model(**inputs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]: ret_val = func(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2032, in forward
[rank1]: loss = self.module(*inputs, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank1]: return inner()
[rank1]: File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1772, in inner
[rank1]: args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
[rank1]: File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 477, in check_dataloader_inputs_same_across_ranks
[rank1]: broadcast_and_check(kwargs, bcast_rank, bcast_group)
[rank1]: File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 466, in broadcast_and_check
[rank1]: assert torch.equal(
[rank1]: AssertionError: Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency.
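
For context, the failing hook broadcasts the dataloader inputs within the TP group and asserts that every rank received identical tensors. A rough conceptual sketch of the idea (illustration only, not DeepSpeed's actual engine code):

# Conceptual sketch of the TP-group dataloader consistency check (illustration only).
import torch
import torch.distributed as dist

def check_batch_consistent(batch: dict, bcast_rank: int, bcast_group) -> None:
    for key, value in batch.items():
        if not torch.is_tensor(value):
            continue
        reference = value.clone()
        # Every rank overwrites `reference` with the copy held by bcast_rank.
        dist.broadcast(reference, src=bcast_rank, group=bcast_group)
        # If this rank's dataloader produced different data, the check fails.
        assert torch.equal(value, reference), (
            "Data inconsistency within the TP group. "
            "Please check the Dataloader implementation to ensure consistency."
        )

So if ranks in the same TP group draw different samples, the assertion fires.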

ds_report output
[2025-04-03 14:04:18,867] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)

DeepSpeed C++/CUDA extension op report

NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.

JIT compiled ops requires ninja
ninja .................. [OKAY]

op name ................ installed .. compatible

[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-devel package with yum
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
[WARNING] FP Quantizer is using an untested triton version (3.2.0), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
/opt/conda/compiler_compat/ld: warning: libm.so.6, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `log2f@GLIBC_2.2.5'
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlopen'
/opt/conda/compiler_compat/ld: /opt/conda/lib/python3.10/site-packages/aistudio_common/reader/libs//libstdc++.so.6: undefined reference to `floor@GLIBC_2.2.5'
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlclose'
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlerror'
/opt/conda/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlsym'
collect2: error: ld returned 1 exit status
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.6
[WARNING] using untested triton version (3.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]

DeepSpeed general environment info:
torch install path ............... ['/opt/conda/lib/python3.10/site-packages/torch']
torch version .................... 2.6.0+cu124
deepspeed install path ........... ['/opt/conda/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.16.5, unknown, unknown
torch cuda version ............... 12.4
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.6, cuda 12.4
shared memory (/dev/shm) size .... 768.06 GB

System info (please complete the following information):

  • OS: Linux
  • One machine with 8x H20
  • Python version: 3.10
liuteng added the bug (Something isn't working) and training labels on Apr 3, 2025
@hwchen2017
Contributor

cc: @inkcherry

@ekg

ekg commented Apr 4, 2025

Is there a working example that uses AutoTP, bf16, and gradient clipping?

@inkcherry
Contributor

Thanks for the report, I'm looking into it. @liuteng @hwchen2017 @ekg

@inkcherry
Contributor

inkcherry commented Apr 7, 2025

Is there a working example that uses AutoTP, bf16, and gradient clipping?

Hi @ekg, you can try this one:
deepspeedai/DeepSpeedExamples#964

@liuteng
Author

liuteng commented Apr 8, 2025

I'm sorry, I finally found that the cause of my error was the following code:
dataset = load_dataset('json', data_files=data_files)['train'].train_test_split(test_size=0.1)
This made the data differ across processes.
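
For anyone hitting the same thing, a minimal fix sketch is to make the split deterministic by passing a fixed seed (assuming every process uses the same training_args.seed; splitting once offline and loading the result from disk would also work):

# Sketch: pass a fixed seed so every rank produces an identical train/test split.
# (training_args.seed is assumed to be the same on all processes.)
dataset = load_dataset('json', data_files=data_files)['train'].train_test_split(
    test_size=0.1,
    seed=training_args.seed,  # without this, each rank shuffles the split differently
)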

But there is another thing I want to know. I run into out-of-memory during training on 8x H20 because I need a long context, so I think AutoTP can help me solve this problem. I also found that when using AutoTP with per_device_train_batch_size=1, it fetches 2 samples in one step. Should we support a real batch size of 1 when using AutoTP?

@inkcherry

@inkcherry
Contributor

inkcherry commented Apr 8, 2025

Thanks for debugging it, @liuteng.
per_device_train_batch_size is equivalent to the micro batch size, i.e. the batch size on each rank. You can check the input shapes of the computation ops; the batch dimension should actually be 1. The true training batch size can be determined from the data parallel (DP) size and the gradient accumulation steps (GAS). Fetching two samples might be related to prefetching behavior for fast data loading, but they shouldn't be treated as a batch of size 2 during computation.
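
For example, a back-of-the-envelope way to reason about the sizes in this setup (assuming 8 processes, autotp_size=2, micro batch 1, GAS 2; the variable names below are just for illustration, not DeepSpeed API):

# Rough batch-size arithmetic for this configuration (illustration only).
world_size = 8                      # 8x H20, one process per GPU
tp_size = 2                         # "autotp_size": 2
dp_size = world_size // tp_size     # = 4 data-parallel replicas
micro_batch_per_gpu = 1             # per_device_train_batch_size
gas = 2                             # gradient_accumulation_steps

# Ranks in the same TP group consume the same micro batch, so only DP replicas count.
global_train_batch_size = micro_batch_per_gpu * dp_size * gas   # = 8 samples per optimizer step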

If you're facing memory limitations, you might consider increasing the tp_size.
Currently, reduce_scatter + allgather for tensor parallelism (i.e. TP with sequence parallelism) is not supported, and contributions are welcome.
Another approach is DeepSpeed Ulysses with ZeRO: ZeRO partitions the model, while Ulysses partitions the activations.

@liuteng
Author

liuteng commented Apr 8, 2025

Thanks for the suggestions, I will try them.

@liuteng liuteng closed this as completed Apr 8, 2025
@lyx564

lyx564 commented Apr 16, 2025

I'm sorry, I finally found that the cause of my error was the following code: dataset = load_dataset('json', data_files=data_files)['train'].train_test_split(test_size=0.1). This made the data differ across processes.

But there is another thing I want to know. I run into out-of-memory during training on 8x H20 because I need a long context, so I think AutoTP can help me solve this problem. I also found that when using AutoTP with per_device_train_batch_size=1, it fetches 2 samples in one step. Should we support a real batch size of 1 when using AutoTP?

@inkcherry

May I ask how this problem was finally solved? @liuteng
