@@ -86,19 +86,22 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
8686 auto is_token_in_rank = at::zeros ({num_tokens, num_ranks}, at::dtype (at::kInt ).device (device));
8787 const int notify_send_data_size =
8888 num_experts * EXPERT_DATA_SIZE + server_num + MAX_BATCH_SIZE * (1 + 2 * server_num + num_topk);
89- /* The output parameters are ordered as follows
90- 1. the number of the tokens that every expert received from this NPU. size:[numExpert]
91- 2. The number of tokens received by each server from this NPU (deduplicated). size:[serverNum]
92- 3. The number of tokens sent from this NPU to each server (without deduplication). size:[MAX_BS, serverNum]
93- 4. The number of servers each token is sent to by this NPU. size:[MAX_BS]
94- 5. The order in which each token of this NPU is sent to various servers. size:[MAX_BS, serverNum]
95- 6. The order in which each token is sent to the expert. size:[MAX_BS, numTopk]
96- 7. The server offset of tokens received by each expert from this NPU. size:[numExpert, MAX_BS]
97- 8. The origin offset of the token received by each expert on the original NPU. size:[numExpert, MAX_BS]
98- */
89+ /*
90+ The output parameters are ordered as follows:
91+ 1. The number of tokens that each expert received from this NPU. size:[numExpert]
92+ 2. The number of tokens received by each server from this NPU (deduplicated). size:[serverNum]
93+ 3. The number of tokens sent from this NPU to each server (without deduplication). size:[MAX_BS, serverNum]
94+ 4. The number of servers each token is sent to by this NPU. size:[MAX_BS]
95+ 5. The order in which each token of this NPU is sent to various servers. size:[MAX_BS, serverNum]
96+ 6. The order in which each token is sent to the expert. size:[MAX_BS, numTopk]
97+ 7. The server offset of tokens received by each expert from this NPU. size:[numExpert, MAX_BS]
98+ 8. The origin offset of the token received by each expert on the original NPU. size:[numExpert, MAX_BS]
99+ */
99100 auto notify_send_data = at::zeros ({notify_send_data_size}, at::dtype (at::kInt ).device (device));
100- notify_send_data.index ({at::indexing::Slice (num_experts + server_num + MAX_BATCH_SIZE * (server_num + 1 ),
101- num_experts + server_num + MAX_BATCH_SIZE * (server_num * 2 + 1 ))}).fill_ (-1 );
101+ notify_send_data
102+ .index ({at::indexing::Slice (num_experts + server_num + MAX_BATCH_SIZE * (server_num + 1 ),
103+ num_experts + server_num + MAX_BATCH_SIZE * (server_num * 2 + 1 ))})
104+ .fill_ (-1 );
102105 // The order of each token sent to the server is set to -1.
103106 EXEC_NPU_CMD (aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, local_ranksize,
104107 num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank, notify_send_data);
0 commit comments