
Commit 3085bab

calculate dispatch normal input parameters using npu instead of cpu (#177)
* move cpu calculation to notify_dispatch & intranode_dispatch adaptation
* calculate send_token_idx {bs * topk} in layout
* codecheck fix & remove redundant code
* sync layout modifications to ops2
* remove kernel printf
* change recv_num_tokens_per_expert_list dtype to int64_t
* change magicVal from int32_t to uint64_t
* change recv_num_tokens_per_expert_list to List[int]
1 parent 7fe4eeb commit 3085bab

26 files changed: +449 -197 lines
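For context, the quantity this commit moves onto the NPU is the per-(token, top-k slot) send index: for each slot routed to expert e, the number of earlier slots already routed to e. A minimal host-side sketch of that semantics, mirroring the CPU loop the commit deletes from Buffer::intranode_dispatch (see the deep_ep.cpp diff below; the free function and its name here are illustrative, not part of the repo):

    #include <cstdint>
    #include <vector>

    // Reference for sendTokenIdxSmall with shape {num_tokens, num_topk}:
    // out[i][j] is the running count of slots already routed to expert
    // topk_idx[i][j]; negative expert ids mean the slot is unused.
    std::vector<int32_t> send_token_idx_reference(const std::vector<int64_t> &topk_idx,
                                                  int num_tokens, int num_topk, int num_experts)
    {
        std::vector<int32_t> out(static_cast<size_t>(num_tokens) * num_topk, 0);
        std::vector<int32_t> expert_acc(num_experts, 0);  // per-expert running counter
        for (int i = 0; i < num_tokens; ++i) {
            for (int j = 0; j < num_topk; ++j) {
                const int64_t expert_idx = topk_idx[static_cast<size_t>(i) * num_topk + j];
                if (expert_idx >= 0) {
                    out[static_cast<size_t>(i) * num_topk + j] = expert_acc[expert_idx]++;
                }
            }
        }
        return out;
    }

After this change, aclnnDispatchLayout writes these values directly into the new send_token_idx_small output tensor, so intranode_dispatch no longer copies topk_idx to the CPU.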

csrc/deepep/deep_ep.cpp

Lines changed: 40 additions & 83 deletions
@@ -116,11 +116,13 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
     7. The server offset of tokens received by each expert from this NPU.
        size:[numExpert, MAX_BS]
     */
+    auto send_token_idx_small = at::zeros({num_tokens, num_topk}, at::dtype(at::kInt).device(device));
     auto notify_send_data = at::zeros({notify_send_data_size}, at::dtype(at::kInt).device(device));
     EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, local_ranksize,
-                 num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank, notify_send_data);
+                 num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank, notify_send_data, send_token_idx_small);

     this->notify_send_data = notify_send_data;
+    this->send_token_idx_small = send_token_idx_small;
     this->notify_send_data_size = notify_send_data_size;

     std::optional<torch::Tensor> num_tokens_per_rdma_rank = std::nullopt;
@@ -161,6 +163,19 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
     EP_HOST_ASSERT(config.num_sms % 2 == 0);
     int num_channels = config.num_sms / 2;

+    at::Tensor expert_ids = new_topk_idx.to(at::kInt);
+    int64_t tp_size = 1;
+    int64_t tp_rank = 0;
+    int64_t quant_mode = use_quant ? DYNAMIC_SCALES : NO_SCALES;
+    auto recv_topk_idx = std::optional<at::Tensor>();
+    auto recv_topk_weights = std::optional<at::Tensor>();
+    // Wait streams
+    std::optional<EventHandle> event;
+    auto rank_prefix_matrix = at::empty({num_ranks, num_ranks}, at::dtype(at::kInt).device(x.device()));
+    auto channel_prefix_matrix = at::empty({num_ranks, num_channels}, at::dtype(at::kInt).device(x.device()));
+    auto recv_channel_prefix_matrix = at::empty({num_ranks, num_channels}, at::dtype(at::kInt).device(x.device()));
+    std::vector<int> num_recv_tokens_per_expert_list;
+
     at::Tensor new_x = x;
     // for padding
     if (topk_idx->size(0) < PADDING_SIZE) {
@@ -240,7 +255,11 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>

     auto send_data_offset = torch::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
     at::Tensor recv_data = torch::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
-
+    at::Tensor total_recv_token_ = torch::empty({1}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor recv_count_ = torch::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor recv_offset_ = torch::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor max_bs_ = torch::empty({1}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor recv_tokens_per_expert_ = torch::empty({num_local_experts}, at::dtype(at::kLong).device(x.device()));
     // get ep name
     char hcom_ep_name[HCOMM_NAME_LEN];
     if (!moe_all_to_all_group_name.empty()) {
@@ -257,95 +276,33 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
                  hcom_ep_name,  // commGroup
                  num_ranks,     // rankSize
                  rank,          // rankId
-                 local_rank_size, local_rank_id, send_data_offset, recv_data);
-
-    auto options_cpu = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
-    std::vector<int32_t> local_expert_acc(num_experts, 0);
-    auto send_token_idx_cpu = torch::empty({num_tokens, num_topk}, options_cpu);
-    auto send_token_idx_ptr = send_token_idx_cpu.data_ptr<int>();
-
-    auto topk_idx_cpu = new_topk_idx.to(at::kCPU);
-    auto topk_idx_ptr = topk_idx_cpu.data_ptr<int64_t>();
-    for (int i = 0; i < num_tokens; ++i) {
-        for (int j = 0; j < num_topk; ++j) {
-            int64_t expert_idx = topk_idx_ptr[i * num_topk + j];
-            if (expert_idx >= 0) {
-                int32_t cnt = local_expert_acc[expert_idx];
-                send_token_idx_ptr[i * num_topk + j] = cnt;
-                local_expert_acc[expert_idx]++;
-            }
-        }
-    }
-
-    EP_HOST_ASSERT(recv_data.dim() == 1 and recv_data.is_contiguous());
-    EP_HOST_ASSERT(recv_data.size(0) % num_experts == 0);
-    at::Tensor recv_offset_cpu = torch::empty({num_experts}, options_cpu);
-    at::Tensor recv_count_cpu = torch::empty({num_experts}, options_cpu);
-    auto recv_data_cpu = recv_data.to(at::kCPU);
-    auto recv_data_ptr = recv_data_cpu.data_ptr<int>();
-    auto recv_count_ptr = recv_count_cpu.data_ptr<int>();
-    auto recv_offset_ptr = recv_offset_cpu.data_ptr<int>();
-    int total_recv_tokens = 0;
-    int num_max_dispatch_tokens_per_rank = 0;
-    std::vector<int> num_recv_tokens_per_expert_list;
-
-    for (int64_t local_e = 0; local_e < num_local_experts; ++local_e) {
-        int64_t local_expert_recv_tokens = 0;
-        for (int64_t src_rank = 0; src_rank < num_ranks; ++src_rank) {
-            int64_t index = local_e * num_ranks + src_rank;
-            int64_t pair_idx = send_per_group * (src_rank * num_local_experts + local_e);
-
-            int recv_cnt = recv_data_ptr[pair_idx];             // count from this src_rank for this global_expert
-            int recv_off = recv_data_ptr[pair_idx + 1];         // offset in that src_rank's window
-            int send_num_tokens = recv_data_ptr[pair_idx + 2];  // all bs from rank
-
-            total_recv_tokens += recv_cnt;
-            recv_count_ptr[index] = total_recv_tokens;
-            recv_offset_ptr[index] = recv_off;
-            num_max_dispatch_tokens_per_rank = std::max(num_max_dispatch_tokens_per_rank, send_num_tokens);
-
-            local_expert_recv_tokens += recv_cnt;
-        }
-        num_recv_tokens_per_expert_list.push_back(local_expert_recv_tokens);
-    }
-
-    at::Tensor expert_ids = new_topk_idx.to(at::kInt);
-    int64_t tp_size = 1;
-    int64_t tp_rank = 0;
-    int64_t quant_mode = use_quant ? DYNAMIC_SCALES : NO_SCALES;
-    int64_t global_bs = static_cast<int64_t>(
-        std::max(num_max_dispatch_tokens_per_rank * num_ranks, static_cast<int64_t>(num_worst_tokens)));
-
-    auto send_token_idx = send_token_idx_cpu.to(x.device());
-    auto recv_offset = recv_offset_cpu.to(x.device());
-    auto recv_count = recv_count_cpu.to(x.device());
-
-    int num_recv_tokens = (total_recv_tokens == 0) ? 1 : total_recv_tokens;
+                 local_rank_size, local_rank_id, send_data_offset, recv_data, total_recv_token_, recv_count_,
+                 recv_offset_, max_bs_, recv_tokens_per_expert_);
+    auto send_token_idx_small = this->send_token_idx_small;
+    int64_t gBs = max_bs_.item<int>() * num_ranks;
+    int64_t trt = total_recv_token_.item<int>();
+    int num_recv_tokens = (trt == 0) ? 1 : trt;
     auto expandx_out = use_quant ? torch::empty({num_recv_tokens, hidden}, at::dtype(at::kChar).device(x.device()))
                                  : torch::empty({num_recv_tokens, hidden}, x.options());
     auto dynamic_scales_out = torch::empty({num_recv_tokens}, at::dtype(at::kFloat).device(x.device()));
     auto expand_idx_out = torch::empty({num_recv_tokens * 3}, at::dtype(at::kInt).device(x.device()));
+    if (topk_idx.has_value()) {
+        recv_topk_idx = at::empty({trt, num_topk}, topk_idx->options());
+        recv_topk_weights = at::empty({trt, num_topk}, topk_weights->options());
+    }

-    EXEC_NPU_CMD(aclnnCamMoeDispatchNormal, new_x, expert_ids, send_data_offset, send_token_idx, recv_offset,
-                 recv_count, hcom_ep_name,
+    EXEC_NPU_CMD(aclnnCamMoeDispatchNormal, new_x, expert_ids, send_data_offset, send_token_idx_small, recv_offset_,
+                 recv_count_, hcom_ep_name,
                  num_ranks,  // rankSize
                  rank,       // rankId
-                 hcom_ep_name, tp_size, tp_rank, num_experts, quant_mode, global_bs, expandx_out, dynamic_scales_out,
+                 hcom_ep_name, tp_size, tp_rank, num_experts, quant_mode, gBs, expandx_out, dynamic_scales_out,
                  expand_idx_out, dispatch_wait_recv_cost_stats_out);
-
-    auto recv_topk_idx = std::optional<at::Tensor>();
-    auto recv_topk_weights = std::optional<at::Tensor>();
-    if (topk_idx.has_value()) {
-        recv_topk_idx = at::empty({total_recv_tokens, num_topk}, topk_idx->options());
-        recv_topk_weights = at::empty({total_recv_tokens, num_topk}, topk_weights->options());
+    auto recv_token_per_exp_cpu = recv_tokens_per_expert_.to(at::kCPU);
+    auto recv_token_per_exp_ptr = recv_token_per_exp_cpu.data_ptr<int64_t>();
+    for (int64_t local_e = 0; local_e < num_local_experts; ++local_e) {
+        int token_cnt = static_cast<int>(recv_token_per_exp_ptr[local_e]);
+        num_recv_tokens_per_expert_list.emplace_back(token_cnt);
     }
-    // Wait streams
-    std::optional<EventHandle> event;
-
-    auto rank_prefix_matrix = at::empty({num_ranks, num_ranks}, at::dtype(at::kInt).device(x.device()));
-    auto channel_prefix_matrix = at::empty({num_ranks, num_channels}, at::dtype(at::kInt).device(x.device()));
-    auto recv_channel_prefix_matrix = at::empty({num_ranks, num_channels}, at::dtype(at::kInt).device(x.device()));
-
     // Return values
     return {expandx_out,
             dynamic_scales_out,
@@ -356,7 +313,7 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
             channel_prefix_matrix,
             recv_channel_prefix_matrix,
             expand_idx_out,
-            recv_count,
+            recv_count_,
             event};
 }
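The five new NotifyDispatch outputs replace the host-side reduction over recv_data deleted above. A sketch of the correspondence, reconstructed from the removed loop (the function and parameter names are illustrative; the removed code reads three ints per (src_rank, local_expert) pair: a count, an offset, and the sender's batch size):

    #include <algorithm>
    #include <cstdint>

    // What the kernel now produces on device: total_recv_token_{1},
    // recv_count_{num_experts} (cumulative), recv_offset_{num_experts},
    // max_bs_{1}, and recv_tokens_per_expert_{num_local_experts}.
    void notify_reduce_reference(const int32_t *recv_data, int64_t num_local_experts,
                                 int64_t num_ranks, int64_t send_per_group,
                                 int32_t *recv_count, int32_t *recv_offset,
                                 int64_t *recv_tokens_per_expert,
                                 int32_t *total_recv_tokens, int32_t *max_bs)
    {
        int32_t total = 0, max_send = 0;
        for (int64_t local_e = 0; local_e < num_local_experts; ++local_e) {
            int64_t expert_total = 0;
            for (int64_t src_rank = 0; src_rank < num_ranks; ++src_rank) {
                const int64_t index = local_e * num_ranks + src_rank;
                const int64_t pair_idx = send_per_group * (src_rank * num_local_experts + local_e);
                total += recv_data[pair_idx];                            // tokens from src_rank for this expert
                recv_count[index] = total;                               // running cumulative count
                recv_offset[index] = recv_data[pair_idx + 1];            // offset in the sender's window
                max_send = std::max(max_send, recv_data[pair_idx + 2]);  // largest sender batch size
                expert_total += recv_data[pair_idx];
            }
            recv_tokens_per_expert[local_e] = expert_total;
        }
        *total_recv_tokens = total;  // trt in the new code; gBs = max_bs * num_ranks
        *max_bs = max_send;
    }

In the new flow only the small recv_tokens_per_expert_ tensor is copied back to the host (to build the List[int] return value); the rest stays on the NPU and feeds aclnnCamMoeDispatchNormal directly.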

csrc/deepep/deep_ep.hpp

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ struct Buffer {
     at::Tensor new_topk_idx;
     at::Tensor new_scales;
     at::Tensor notify_send_data;  // only for internode notify
-    int notify_send_data_size;    // only for internode notify
+    at::Tensor send_token_idx_small;
+    int notify_send_data_size;  // only for internode notify

     int64_t shared_expert_rank_num;
     int64_t shared_expert_num = 1;
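The new member makes the layout result reusable: get_dispatch_layout caches the NPU tensor on the Buffer, and intranode_dispatch later reads it back instead of recomputing the indices on the host. Schematically (names from the diff, other members elided):

    struct Buffer {
        at::Tensor send_token_idx_small;  // {num_tokens, num_topk}, int32, written by aclnnDispatchLayout
        // ...
    };

    // In get_dispatch_layout:
    //     this->send_token_idx_small = send_token_idx_small;       // cache the NPU output
    // In intranode_dispatch:
    //     auto send_token_idx_small = this->send_token_idx_small;  // reuse, no host round-trip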

csrc/deepep/ops/op_host/dispatch_layout.cpp

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,11 @@ class DispatchLayout : public OpDef
             .DataType({ge::DT_INT32})
             .Format({ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Output("sendTokenIdxSmall")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});

         OpAICoreConfig a3_config;
         a3_config.DynamicCompileStaticFlag(true)

csrc/deepep/ops/op_host/notify_dispatch.cpp

Lines changed: 25 additions & 1 deletion
@@ -26,7 +26,31 @@ class NotifyDispatch : public OpDef
             .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
-
+        this->Output("totalRecvTokens")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("recvCount")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("recvOffset")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("maxBs")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("recvTokensPerExpert")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
         this->Attr("sendCount").Int();
         this->Attr("num_tokens").Int();
         this->Attr("comm_group").String();

csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.cpp

Lines changed: 3 additions & 2 deletions
@@ -18,11 +18,12 @@ aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(const aclTensor *topkIdx, int64_
                                                 int64_t numExperts, int64_t numTopk, int64_t localRankSize,
                                                 const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert,
                                                 const aclTensor *isTokenInRank, const aclTensor *notifySendData,
-                                                uint64_t *workspaceSize, aclOpExecutor **executor)
+                                                const aclTensor *sendTokenIdxSmall, uint64_t *workspaceSize,
+                                                aclOpExecutor **executor)
 {
     return aclnnInnerDispatchLayoutGetWorkspaceSize(topkIdx, numTokens, numRanks, numExperts, numTopk, localRankSize,
                                                     numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData,
-                                                    workspaceSize, executor);
+                                                    sendTokenIdxSmall, workspaceSize, executor);
 }

 aclnnStatus aclnnDispatchLayout(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.h

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ extern "C" {
 __attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(
     const aclTensor *topkIdx, int64_t numTokens, int64_t numRanks, int64_t numExperts, int64_t numTopk,
     int64_t localRankSize, const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert,
-    const aclTensor *isTokenInRank, const aclTensor *notifySendData, uint64_t *workspaceSize, aclOpExecutor **executor);
+    const aclTensor *isTokenInRank, const aclTensor *notifySendData, const aclTensor *sendTokenIdxSmall,
+    uint64_t *workspaceSize, aclOpExecutor **executor);

 /* function: aclnnDispatchLayout
  * workspace : workspace memory addr(input).

csrc/deepep/ops/op_host/op_api/aclnn_notify_dispatch.cpp

Lines changed: 6 additions & 2 deletions
@@ -20,11 +20,15 @@ aclnnStatus aclnnNotifyDispatchGetWorkspaceSize(const aclTensor *sendData, const
                                                 int64_t sendCount, int64_t numTokens, char *commGroup, int64_t rankSize,
                                                 int64_t rankId, int64_t localRankSize, int64_t localRankId,
                                                 const aclTensor *sendDataOffset, const aclTensor *recvData,
-                                                uint64_t *workspaceSize, aclOpExecutor **executor)
+                                                const aclTensor *totalRecvTokens, const aclTensor *recvCount,
+                                                const aclTensor *recvOffset, const aclTensor *maxBs,
+                                                const aclTensor *recvTokensPerExpert, uint64_t *workspaceSize,
+                                                aclOpExecutor **executor)
 {
     return aclnnInnerNotifyDispatchGetWorkspaceSize(sendData, tokenPerExpertData, sendCount, numTokens, commGroup,
                                                     rankSize, rankId, localRankSize, localRankId, sendDataOffset,
-                                                    recvData, workspaceSize, executor);
+                                                    recvData, totalRecvTokens, recvCount, recvOffset, maxBs,
+                                                    recvTokensPerExpert, workspaceSize, executor);
 }

 aclnnStatus aclnnNotifyDispatch(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

csrc/deepep/ops/op_host/op_api/aclnn_notify_dispatch.h

Lines changed: 3 additions & 1 deletion
@@ -27,7 +27,9 @@ extern "C" {
 __attribute__((visibility("default"))) aclnnStatus aclnnNotifyDispatchGetWorkspaceSize(
     const aclTensor *sendData, const aclTensor *tokenPerExpertData, int64_t sendCount, int64_t numTokens,
     char *commGroup, int64_t rankSize, int64_t rankId, int64_t localRankSize, int64_t localRankId,
-    const aclTensor *sendDataOffset, const aclTensor *recvData, uint64_t *workspaceSize, aclOpExecutor **executor);
+    const aclTensor *sendDataOffset, const aclTensor *recvData, const aclTensor *totalRecvTokens,
+    const aclTensor *recvCount, const aclTensor *recvOffset, const aclTensor *maxBs,
+    const aclTensor *recvTokensPerExpert, uint64_t *workspaceSize, aclOpExecutor **executor);

 /* function: aclnnNotifyDispatch
  * parameters :

csrc/deepep/ops/op_kernel/dispatch_layout.cpp

Lines changed: 6 additions & 5 deletions
@@ -8,7 +8,8 @@

 extern "C" __global__ __aicore__ void dispatch_layout(GM_ADDR topkIdx, GM_ADDR numTokensPerRank,
                                                       GM_ADDR numTokensPerExpert, GM_ADDR isTokenInRank,
-                                                      GM_ADDR notifySendData, GM_ADDR workspace, GM_ADDR tiling)
+                                                      GM_ADDR notifySendData, GM_ADDR sendTokenIdxSmall,
+                                                      GM_ADDR workspace, GM_ADDR tiling)
 {
     REGISTER_TILING_DEFAULT(DispatchLayoutTilingData);
     GET_TILING_DATA_WITH_STRUCT(DispatchLayoutTilingData, tilingData, tiling);
@@ -17,13 +18,13 @@ extern "C" __global__ __aicore__ void dispatch_layout(GM_ADDR topkIdx, GM_ADDR n

     if (TILING_KEY_IS(TILING_KEY_INT)) {
         MoeDispatchLayout::DispatchLayout<int32_t> op;
-        op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData, workspace, &pipe,
-                &tilingData);
+        op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData, sendTokenIdxSmall,
+                workspace, &pipe, &tilingData);
         op.Process();
     } else if (TILING_KEY_IS(TILING_KEY_A2_INT)) {
         MoeDispatchLayoutA2::DispatchLayoutA2<int32_t> op;
-        op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData, workspace, &pipe,
-                &tilingData);
+        op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData, sendTokenIdxSmall,
+                workspace, &pipe, &tilingData);
         op.Process();
     }
 }
