@@ -116,11 +116,13 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
     7. The server offset of tokens received by each expert from this NPU.
        size: [numExpert, MAX_BS]
     */
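+    // send_token_idx_small[i][j]: running index of token i's j-th routed copy among the tokens
+    // this rank sends to the selected expert, produced on-device by aclnnDispatchLayout (it is
+    // later fed to the dispatch kernel in place of a host-computed send_token_idx).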
+    auto send_token_idx_small = at::zeros({num_tokens, num_topk}, at::dtype(at::kInt).device(device));
     auto notify_send_data = at::zeros({notify_send_data_size}, at::dtype(at::kInt).device(device));
     EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, local_ranksize,
-                 num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank, notify_send_data);
+                 num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank, notify_send_data, send_token_idx_small);

     this->notify_send_data = notify_send_data;
+    this->send_token_idx_small = send_token_idx_small;
     this->notify_send_data_size = notify_send_data_size;

     std::optional<torch::Tensor> num_tokens_per_rdma_rank = std::nullopt;
@@ -161,6 +163,19 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
     EP_HOST_ASSERT(config.num_sms % 2 == 0);
     int num_channels = config.num_sms / 2;

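+    // Dispatch parameters and output holders, declared up front; they are filled in below or
+    // returned with their default/empty values.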
+    at::Tensor expert_ids = new_topk_idx.to(at::kInt);
+    int64_t tp_size = 1;
+    int64_t tp_rank = 0;
+    int64_t quant_mode = use_quant ? DYNAMIC_SCALES : NO_SCALES;
+    auto recv_topk_idx = std::optional<at::Tensor>();
+    auto recv_topk_weights = std::optional<at::Tensor>();
+    // Wait streams
+    std::optional<EventHandle> event;
+    auto rank_prefix_matrix = at::empty({num_ranks, num_ranks}, at::dtype(at::kInt).device(x.device()));
+    auto channel_prefix_matrix = at::empty({num_ranks, num_channels}, at::dtype(at::kInt).device(x.device()));
+    auto recv_channel_prefix_matrix = at::empty({num_ranks, num_channels}, at::dtype(at::kInt).device(x.device()));
+    std::vector<int> num_recv_tokens_per_expert_list;
+
     at::Tensor new_x = x;
     // for padding
     if (topk_idx->size(0) < PADDING_SIZE) {
@@ -240,7 +255,11 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>

     auto send_data_offset = torch::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
     at::Tensor recv_data = torch::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
-
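+    // Device-side outputs expected to be filled by the call below: total received token count,
+    // per-expert receive counts and offsets, the maximum per-rank batch size, and per-local-expert
+    // received-token counts.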
+    at::Tensor total_recv_token_ = torch::empty({1}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor recv_count_ = torch::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor recv_offset_ = torch::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor max_bs_ = torch::empty({1}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor recv_tokens_per_expert_ = torch::empty({num_local_experts}, at::dtype(at::kLong).device(x.device()));
     // get ep name
     char hcom_ep_name[HCOMM_NAME_LEN];
     if (!moe_all_to_all_group_name.empty()) {
@@ -257,95 +276,33 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
                  hcom_ep_name, // commGroup
                  num_ranks,    // rankSize
                  rank,         // rankId
-                 local_rank_size, local_rank_id, send_data_offset, recv_data);
-
-    auto options_cpu = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
-    std::vector<int32_t> local_expert_acc(num_experts, 0);
-    auto send_token_idx_cpu = torch::empty({num_tokens, num_topk}, options_cpu);
-    auto send_token_idx_ptr = send_token_idx_cpu.data_ptr<int>();
-
-    auto topk_idx_cpu = new_topk_idx.to(at::kCPU);
-    auto topk_idx_ptr = topk_idx_cpu.data_ptr<int64_t>();
-    for (int i = 0; i < num_tokens; ++i) {
-        for (int j = 0; j < num_topk; ++j) {
-            int64_t expert_idx = topk_idx_ptr[i * num_topk + j];
-            if (expert_idx >= 0) {
-                int32_t cnt = local_expert_acc[expert_idx];
-                send_token_idx_ptr[i * num_topk + j] = cnt;
-                local_expert_acc[expert_idx]++;
-            }
-        }
-    }
-
-    EP_HOST_ASSERT(recv_data.dim() == 1 and recv_data.is_contiguous());
-    EP_HOST_ASSERT(recv_data.size(0) % num_experts == 0);
-    at::Tensor recv_offset_cpu = torch::empty({num_experts}, options_cpu);
-    at::Tensor recv_count_cpu = torch::empty({num_experts}, options_cpu);
-    auto recv_data_cpu = recv_data.to(at::kCPU);
-    auto recv_data_ptr = recv_data_cpu.data_ptr<int>();
-    auto recv_count_ptr = recv_count_cpu.data_ptr<int>();
-    auto recv_offset_ptr = recv_offset_cpu.data_ptr<int>();
-    int total_recv_tokens = 0;
-    int num_max_dispatch_tokens_per_rank = 0;
-    std::vector<int> num_recv_tokens_per_expert_list;
-
-    for (int64_t local_e = 0; local_e < num_local_experts; ++local_e) {
-        int64_t local_expert_recv_tokens = 0;
-        for (int64_t src_rank = 0; src_rank < num_ranks; ++src_rank) {
-            int64_t index = local_e * num_ranks + src_rank;
-            int64_t pair_idx = send_per_group * (src_rank * num_local_experts + local_e);
-
-            int recv_cnt = recv_data_ptr[pair_idx];            // count from this src_rank for this global_expert
-            int recv_off = recv_data_ptr[pair_idx + 1];        // offset in that src_rank's window
-            int send_num_tokens = recv_data_ptr[pair_idx + 2]; // all bs from rank
-
-            total_recv_tokens += recv_cnt;
-            recv_count_ptr[index] = total_recv_tokens;
-            recv_offset_ptr[index] = recv_off;
-            num_max_dispatch_tokens_per_rank = std::max(num_max_dispatch_tokens_per_rank, send_num_tokens);
-
-            local_expert_recv_tokens += recv_cnt;
-        }
-        num_recv_tokens_per_expert_list.push_back(local_expert_recv_tokens);
-    }
-
-    at::Tensor expert_ids = new_topk_idx.to(at::kInt);
-    int64_t tp_size = 1;
-    int64_t tp_rank = 0;
-    int64_t quant_mode = use_quant ? DYNAMIC_SCALES : NO_SCALES;
-    int64_t global_bs = static_cast<int64_t>(
-        std::max(num_max_dispatch_tokens_per_rank * num_ranks, static_cast<int64_t>(num_worst_tokens)));
-
-    auto send_token_idx = send_token_idx_cpu.to(x.device());
-    auto recv_offset = recv_offset_cpu.to(x.device());
-    auto recv_count = recv_count_cpu.to(x.device());
-
-    int num_recv_tokens = (total_recv_tokens == 0) ? 1 : total_recv_tokens;
+                 local_rank_size, local_rank_id, send_data_offset, recv_data, total_recv_token_, recv_count_,
+                 recv_offset_, max_bs_, recv_tokens_per_expert_);
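+    // max_bs_ and total_recv_token_ are single-element tensors; read them back once to derive the
+    // global batch size and to size the receive-side output buffers.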
+    auto send_token_idx_small = this->send_token_idx_small;
+    int64_t gBs = max_bs_.item<int>() * num_ranks;
+    int64_t trt = total_recv_token_.item<int>();
+    int num_recv_tokens = (trt == 0) ? 1 : trt;
     auto expandx_out = use_quant ? torch::empty({num_recv_tokens, hidden}, at::dtype(at::kChar).device(x.device()))
                                  : torch::empty({num_recv_tokens, hidden}, x.options());
     auto dynamic_scales_out = torch::empty({num_recv_tokens}, at::dtype(at::kFloat).device(x.device()));
     auto expand_idx_out = torch::empty({num_recv_tokens * 3}, at::dtype(at::kInt).device(x.device()));
+    if (topk_idx.has_value()) {
+        recv_topk_idx = at::empty({trt, num_topk}, topk_idx->options());
+        recv_topk_weights = at::empty({trt, num_topk}, topk_weights->options());
+    }

-    EXEC_NPU_CMD(aclnnCamMoeDispatchNormal, new_x, expert_ids, send_data_offset, send_token_idx, recv_offset,
-                 recv_count, hcom_ep_name,
+    EXEC_NPU_CMD(aclnnCamMoeDispatchNormal, new_x, expert_ids, send_data_offset, send_token_idx_small, recv_offset_,
+                 recv_count_, hcom_ep_name,
                  num_ranks, // rankSize
                  rank,      // rankId
-                 hcom_ep_name, tp_size, tp_rank, num_experts, quant_mode, global_bs, expandx_out, dynamic_scales_out,
+                 hcom_ep_name, tp_size, tp_rank, num_experts, quant_mode, gBs, expandx_out, dynamic_scales_out,
                  expand_idx_out, dispatch_wait_recv_cost_stats_out);
-
-    auto recv_topk_idx = std::optional<at::Tensor>();
-    auto recv_topk_weights = std::optional<at::Tensor>();
-    if (topk_idx.has_value()) {
-        recv_topk_idx = at::empty({total_recv_tokens, num_topk}, topk_idx->options());
-        recv_topk_weights = at::empty({total_recv_tokens, num_topk}, topk_weights->options());
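+    // Rebuild num_recv_tokens_per_expert_list from the per-local-expert counts returned in
+    // recv_tokens_per_expert_ (one device-to-host copy, then a per-expert append).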
+    auto recv_token_per_exp_cpu = recv_tokens_per_expert_.to(at::kCPU);
+    auto recv_token_per_exp_ptr = recv_token_per_exp_cpu.data_ptr<int64_t>();
+    for (int64_t local_e = 0; local_e < num_local_experts; ++local_e) {
+        int token_cnt = static_cast<int>(recv_token_per_exp_ptr[local_e]);
+        num_recv_tokens_per_expert_list.emplace_back(token_cnt);
     }
-    // Wait streams
-    std::optional<EventHandle> event;
-
-    auto rank_prefix_matrix = at::empty({num_ranks, num_ranks}, at::dtype(at::kInt).device(x.device()));
-    auto channel_prefix_matrix = at::empty({num_ranks, num_channels}, at::dtype(at::kInt).device(x.device()));
-    auto recv_channel_prefix_matrix = at::empty({num_ranks, num_channels}, at::dtype(at::kInt).device(x.device()));
-
     // Return values
     return {expandx_out,
             dynamic_scales_out,
@@ -356,7 +313,7 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
             channel_prefix_matrix,
             recv_channel_prefix_matrix,
             expand_idx_out,
-            recv_count,
+            recv_count_,
             event};
 }
