@@ -77,7 +77,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
7777
7878 const int num_tokens = new_topk_idx.size (0 );
7979 const int num_topk = new_topk_idx.size (1 );
80- const int local_ranksize = A2_LOCAL_RANK_SIZE ;
80+ const int local_ranksize = LOCAL_RANK_SIZE ;
8181 auto server_num = num_ranks / local_ranksize;
8282
8383 auto device = new_topk_idx.device ();
@@ -88,14 +88,22 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
8888 num_experts * EXPERT_DATA_SIZE + server_num + MAX_BATCH_SIZE * (1 + 2 * server_num + num_topk);
8989 /*
9090 The output parameters are ordered as follows:
91- 1. the number of the tokens that every expert received from this NPU. size:[numExpert]
92- 2. The number of tokens received by each server from this NPU (deduplicated). size:[serverNum]
93- 3. The number of tokens sent from this NPU to each server (without deduplication). size:[MAX_BS, serverNum]
94- 4. The number of servers each token is sent to by this NPU. size:[MAX_BS]
95- 5. The order in which each token of this NPU is sent to various servers. size:[MAX_BS, serverNum]
96- 6. The order in which each token is sent to the expert. size:[MAX_BS, numTopk]
97- 7. The server offset of tokens received by each expert from this NPU. size:[numExpert, MAX_BS]
98- 8. The origin offset of the token received by each expert on the original NPU. size:[numExpert, MAX_BS]
91+ 1. The number of tokens that each expert receives from this NPU.
92+ size:[numExpert]
93+ 2. The number of tokens received by each server from this NPU (deduplicated).
94+ size:[serverNum]
95+ 3. The number of tokens sent from this NPU to each server (without deduplication).
96+ size:[MAX_BS, serverNum]
97+ 4. The number of servers each token is sent to by this NPU.
98+ size:[MAX_BS]
99+ 5. The order in which each token of this NPU is sent to various servers.
100+ size:[MAX_BS, serverNum]
101+ 6. The order in which each token is sent to the expert.
102+ size:[MAX_BS, numTopk]
103+ 7. The server offset of tokens received by each expert from this NPU.
104+ size:[numExpert, MAX_BS]
105+ 8. The origin offset of the token received by each expert on the original NPU.
106+ size:[numExpert, MAX_BS]
99107 */
100108 auto notify_send_data = at::zeros ({notify_send_data_size}, at::dtype (at::kInt ).device (device));
101109 notify_send_data
0 commit comments