
[BUG] DeepSpeed MoE hangs with DDP inference #7141

Open
@JessePrince

Description

Describe the bug
I'm trying to combine DDP with a trained MoE model using deepspeed moe. I set the ep_size to 1 and there is no tensor parallel. The way to enable DDP is to launch with deepspeed and initialize process groups, the data is sliced manually to different ranks, so each rank will have an identical model with different data batches.

However, I noticed that all to all communication inside the MoE layer hangs even the ep_size is 1, my goal is to run processes independently so I can do inference in parallel, the data is distributed to different ranks and I can gather them after all processes finish their generation.

By the way, the inference runs fine if I use just one GPU for inference.
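
For context, here is a simplified sketch of the setup described above. The hidden size, expert count, and expert module are placeholders, not the real model; only the launcher, deepspeed.init_distributed, and the MoE layer with ep_size=1 reflect the actual setup:

# Simplified setup sketch; hidden size, expert count, and the expert MLP are illustrative.
import os
import torch
import deepspeed
from deepspeed.moe.layer import MoE

# Launched via the deepspeed launcher, e.g. `deepspeed --num_gpus 4 eval.py`.
deepspeed.init_distributed(dist_backend="nccl")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

hidden_size = 1024  # placeholder value

# Placeholder expert block standing in for the real expert MLP.
expert = torch.nn.Sequential(
    torch.nn.Linear(hidden_size, 4 * hidden_size),
    torch.nn.GELU(),
    torch.nn.Linear(4 * hidden_size, hidden_size),
)

moe_layer = MoE(
    hidden_size=hidden_size,
    expert=expert,
    num_experts=8,  # placeholder value
    ep_size=1,      # expert-parallel group of size 1: experts fully replicated on every rank
    k=1,
).cuda()

x = torch.randn(2, 16, hidden_size, device="cuda")
out, l_aux, exp_counts = moe_layer(x)  # the all-to-all inside this forward is where the hang shows up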

To Reproduce

# NOTE: local_rank, args, and logger are assumed to be module-level globals set
# elsewhere in the script (local_rank from the launcher env, args from argparse).
import torch.distributed as dist
from tqdm import tqdm


def distributed_eval(loader, model, generation_config, tokenizer):
    # Materialize the full dataset on every rank, then slice it per rank.
    all_batches = [data for data in loader]
    dist.barrier()

    output = []
    cnt = 0
    num_proc = dist.get_world_size()
    batch_for_this_rank = all_batches[local_rank::num_proc]

    pbar = tqdm(total=len(batch_for_this_rank), desc=f"[rank{local_rank}]Inference")
    for each in batch_for_this_rank:
        input_ids, graphs = each["input_ids"].to(args.device), each["graphs"].to(args.device)
        # generate() is called independently on every rank; this is where the
        # all-to-all inside the MoE layer hangs.
        output_ids = model.generate(
            input_ids,
            graphs=graphs,
            do_sample=True,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            repetition_penalty=args.repetition_penalty,
            use_cache=True,
            attention_mask=each["attention_mask"].to(args.device),
            this_task_ids=each["this_task_ids"].to(args.device),
            generation_config=generation_config,
        )

        for result, input_id, prompt, gt in zip(output_ids, input_ids, each["prompt"], each["gt"]):
            this_output = {
                "prompt": prompt,
                "gt": gt,
                # Strip the prompt tokens and decode only the newly generated part.
                "pred": tokenizer.decode(result[input_id.shape[0]:]),
            }
            output.append(this_output)
            if cnt < 10:
                pbar.write(f"\n[rank {local_rank}]{this_output}\n")

        cnt += 1
        pbar.update(1)

    logger.info("Gathering object from processes...")
    # all_gather_object fills the pre-allocated list in place and returns None.
    all_output = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(all_output, output)
    all_output = [element for each_out in all_output for element in each_out]
    dist.barrier()

    return all_output
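
The gather at the end is the standard all_gather_object pattern; here is a stand-alone sketch of just that step, runnable on its own under a distributed launcher (file name and demo payload are illustrative):

# Stand-alone sketch of the gather step above; run under a distributed launcher,
# e.g. `deepspeed gather_demo.py` or torchrun.
import os
import torch
import torch.distributed as dist

def gather_results(per_rank_output):
    # all_gather_object is a collective: every rank must call it, and it fills
    # the pre-allocated list in place (its return value is None).
    gathered = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(gathered, per_rank_output)
    # Flatten the per-rank lists into a single list containing results from all ranks.
    return [item for rank_out in gathered for item in rank_out]

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    demo = [{"rank": dist.get_rank(), "idx": i} for i in range(2)]
    print(f"[rank {dist.get_rank()}] {gather_results(demo)}")
    dist.destroy_process_group()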

Expected behavior
I expect each rank to run generation independently and finish without hanging.

System info (please complete the following information):

  • OS: Ubuntu 22.04
  • One node with 4xRTX 3090
  • DeepSpeed 0.15.1
  • torch 2.1
  • CUDA 12.1
