From cc759872e9805ca0ca72d53d7ca536b1396e1467 Mon Sep 17 00:00:00 2001
From: offline0806 <3337230449@qq.com>
Date: Thu, 23 Oct 2025 15:12:25 +0800
Subject: [PATCH 1/2] [BugFix]Check all expert maps when using multiple
 instances.

Signed-off-by: offline0806 <3337230449@qq.com>
---
 vllm_ascend/ops/common_fused_moe.py            |  1 +
 vllm_ascend/ops/expert_load_balancer.py        | 18 ++++++++++++++++++
 vllm_ascend/torchair/ops/torchair_fused_moe.py |  1 +
 3 files changed, 20 insertions(+)

diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 1335be5ada..604418cc57 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -192,6 +192,7 @@ def __init__(self, *args, **kwargs):
                 os.R_OK):
             self.expert_load_balancer = ExpertLoadBalancer(
                 self.expert_map_path, self.global_num_experts)
+            self.expert_load_balancer.check_expert_map_tensor()
             self.global_redundant_expert_num = (
                 self.expert_load_balancer.get_global_redundant_expert_num())
             try:
diff --git a/vllm_ascend/ops/expert_load_balancer.py b/vllm_ascend/ops/expert_load_balancer.py
index c6eec64a36..74b18d48e2 100644
--- a/vllm_ascend/ops/expert_load_balancer.py
+++ b/vllm_ascend/ops/expert_load_balancer.py
@@ -3,6 +3,7 @@
 from typing import Dict, List
 
 import torch
+import torch.distributed as dist
 
 
 class ExpertLoadBalancer(object):
@@ -97,3 +98,20 @@ def get_global_redundant_expert_num(self):
             len(self.expert_map_tensor[0][0]) * self.ranks_num -
             self.global_expert_num)
         return global_redundant_expert_num
+
+    def check_expert_map_tensor(self):
+        if dist.is_initialized():
+            try:
+                rank = dist.get_rank()
+                world_size = dist.get_world_size()
+                all_expert_maps = [None for _ in range(world_size)]
+                dist.all_gather_object(all_expert_maps, self.tensor_data)
+                for rank_id, expert_map_tensor in enumerate(all_expert_maps):
+                    if self.tensor_data != expert_map_tensor:
+                        raise ValueError(
+                            f"The expert map of rank{rank} is not equal to rank{rank_id}"
+                        )
+                return True
+            except Exception as e:
+                raise ValueError(
+                    f"The expert maps of all ranks are inconsistent: {e}")
diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py
index 9a07e8cae9..2e9e8fa288 100644
--- a/vllm_ascend/torchair/ops/torchair_fused_moe.py
+++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py
@@ -1042,6 +1042,7 @@ def __init__(
                 os.R_OK):
             self.expert_load_balancer = ExpertLoadBalancer(
                 self.expert_map_path, self.global_num_experts)
+            self.expert_load_balancer.check_expert_map_tensor()
             self.global_redundant_expert_num = (
                 self.expert_load_balancer.get_global_redundant_expert_num())
             try:

From 1eb902ebaad502491357bc2d28119f141af41ab4 Mon Sep 17 00:00:00 2001
From: offline0806 <3337230449@qq.com>
Date: Thu, 23 Oct 2025 15:18:02 +0800
Subject: [PATCH 2/2] [BugFix]change tensor data to class param.

Signed-off-by: offline0806 <3337230449@qq.com>
---
 vllm_ascend/ops/expert_load_balancer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm_ascend/ops/expert_load_balancer.py b/vllm_ascend/ops/expert_load_balancer.py
index 74b18d48e2..de6a7c55a4 100644
--- a/vllm_ascend/ops/expert_load_balancer.py
+++ b/vllm_ascend/ops/expert_load_balancer.py
@@ -11,8 +11,10 @@ class ExpertLoadBalancer(object):
     def __init__(self, expert_map_path, global_expert_num):
         self.expert_map_path = expert_map_path
         self.global_expert_num = global_expert_num
+        self.tensor_data = []
         self.expert_map_tensor, self.layers_num, self.ranks_num = (
             self._expert_file_to_tensor())
+        self.expert_placement_map = self.generate_expert_placement_map()
 
     def _expert_file_to_tensor(self):
         with open(self.expert_map_path, "r") as f:
@@ -20,13 +22,12 @@ def _expert_file_to_tensor(self):
             data = json.load(f)
             layers_num = data["moe_layer_count"]
             gpus_num = data["layer_list"][0]["device_count"]
-            tensor_data = []
             for layer in data["layer_list"]:
                 device_data = []
                 for device in layer["device_list"]:
                     device_data.append(device["device_expert"])
-                tensor_data.append(device_data)
-            expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
+                self.tensor_data.append(device_data)
+            expert_map_tensor = torch.tensor(self.tensor_data, dtype=torch.int32)
         return expert_map_tensor, layers_num, gpus_num
 
     def generate_index_dicts(self, tensor_2d):
@@ -82,8 +83,7 @@ def generate_log2phy_expert_map(self, layer_id):
         return log2phy_map
 
     def get_rank_placement_map(self, layer_id, rank_id):
-        expert_placement_map = self.generate_expert_placement_map()
-        layer_expert_map = expert_placement_map[layer_id]
+        layer_expert_map = self.expert_placement_map[layer_id]
         rank_expert_map = layer_expert_map[rank_id].to(
             torch.npu.current_device())
         rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()