diff --git a/csrc/deepep/deep_ep.cpp b/csrc/deepep/deep_ep.cpp index d9344ea2..4ee2de10 100644 --- a/csrc/deepep/deep_ep.cpp +++ b/csrc/deepep/deep_ep.cpp @@ -12,6 +12,10 @@ constexpr int PADDING_SIZE = 3; constexpr size_t HCOMM_NAME_LEN = 128; constexpr uint32_t NO_SCALES = 0; constexpr uint32_t DYNAMIC_SCALES = 2; +// In a shared header +constexpr int LOCAL_RANK_SIZE = 8; +constexpr int MAX_BATCH_SIZE = 4096; +constexpr int EXPERT_DATA_SIZE = 1 + 2 * MAX_BATCH_SIZE; // 8193 Buffer::Buffer(int64_t rank, int64_t num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, std::string moe_all_to_all_group_name) @@ -73,15 +77,44 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std: const int num_tokens = new_topk_idx.size(0); const int num_topk = new_topk_idx.size(1); + const int local_ranksize = LOCAL_RANK_SIZE; + auto server_num = num_ranks / local_ranksize; auto device = new_topk_idx.device(); auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device)); auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device)); - auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, at::dtype(at::kInt).device(device)); - - EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, num_tokens_per_rank, - num_tokens_per_expert, is_token_in_rank); - + auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device)); + const int notify_send_data_size = + num_experts * EXPERT_DATA_SIZE + server_num + MAX_BATCH_SIZE * (1 + 2 * server_num + num_topk); + /* + The notify send data is constructed by 8 parameters and the 8 parameters are ordered as follows: + 1. the number of the tokens that every expert received from this NPU. + size:[numExpert] + 2. The number of tokens received by each server from this NPU (deduplicated). + size:[serverNum] + 3. The number of tokens sent from this NPU to each server (without deduplication). + size:[MAX_BS, serverNum] + 4. The number of servers each token is sent to by this NPU. + size:[MAX_BS] + 5. The order in which each token of this NPU is sent to various servers. + size:[MAX_BS, serverNum] + 6. The order in which each token is sent to the expert. + size:[MAX_BS, numTopk] + 7. The server offset of tokens received by each expert from this NPU. + size:[numExpert, MAX_BS] + 8. The origin offset of the token received by each expert on the original NPU. + size:[numExpert, MAX_BS] + */ + auto notify_send_data = at::zeros({notify_send_data_size}, at::dtype(at::kInt).device(device)); + notify_send_data + .index({at::indexing::Slice(num_experts + server_num + MAX_BATCH_SIZE * (server_num + 1), + num_experts + server_num + MAX_BATCH_SIZE * (server_num * 2 + 1))}) + .fill_(-1); + // The order of each token sent to the server is set to -1. + EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, local_ranksize, + num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank, notify_send_data); + + this->notify_send_data = notify_send_data; std::optional num_tokens_per_rdma_rank = std::nullopt; std::optional output_event = std::nullopt; auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool); diff --git a/csrc/deepep/deep_ep.hpp b/csrc/deepep/deep_ep.hpp index e4a760a1..baa51d5a 100644 --- a/csrc/deepep/deep_ep.hpp +++ b/csrc/deepep/deep_ep.hpp @@ -26,6 +26,7 @@ struct Buffer { at::Tensor ori_x; at::Tensor new_topk_idx; at::Tensor new_scales; + at::Tensor notify_send_data; int64_t shared_expert_rank_num; int64_t shared_expert_num = 1; diff --git a/csrc/deepep/ops/op_host/dispatch_layout.cpp b/csrc/deepep/ops/op_host/dispatch_layout.cpp index 3091521c..9f2de4c6 100644 --- a/csrc/deepep/ops/op_host/dispatch_layout.cpp +++ b/csrc/deepep/ops/op_host/dispatch_layout.cpp @@ -16,6 +16,7 @@ class DispatchLayout : public OpDef this->Attr("num_ranks").Int(); this->Attr("num_experts").Int(); this->Attr("num_topk").Int(); + this->Attr("local_ranksize").Int(); this->Output("numTokensPerRank") .ParamType(REQUIRED) @@ -32,9 +33,14 @@ class DispatchLayout : public OpDef .DataType({ge::DT_INT32}) .Format({ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("notifySendData") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); - OpAICoreConfig aicore_config; - aicore_config.DynamicCompileStaticFlag(true) + OpAICoreConfig a3_config; + a3_config.DynamicCompileStaticFlag(true) .DynamicFormatFlag(true) .DynamicRankSupportFlag(true) .DynamicShapeSupportFlag(true) @@ -44,7 +50,19 @@ class DispatchLayout : public OpDef .ExtendCfgInfo("jitCompile.flag", "static_true") .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); - this->AICore().AddConfig("ascend910_93", aicore_config); + OpAICoreConfig a2_config; + a2_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false") + .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); + + this->AICore().AddConfig("ascend910_93", a3_config); + this->AICore().AddConfig("ascend910b", a2_config); } }; diff --git a/csrc/deepep/ops/op_host/dispatch_layout_tiling.cc b/csrc/deepep/ops/op_host/dispatch_layout_tiling.cc index efbd0227..6adaa211 100644 --- a/csrc/deepep/ops/op_host/dispatch_layout_tiling.cc +++ b/csrc/deepep/ops/op_host/dispatch_layout_tiling.cc @@ -26,17 +26,24 @@ constexpr uint32_t INPUT_TOPK_IDX_INDEX = 0; constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0; constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1; constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2; +constexpr uint32_t OUTPUT_NOTIFY_SEND_DATA_INDEX = 3; constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0; constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1; constexpr uint32_t ATTR_NUM_EXPERTS_INDEX = 2; constexpr uint32_t ATTR_NUM_TOPK_INDEX = 3; +constexpr uint32_t ATTR_LOCAL_RANKSIZE_INDEX = 4; const int64_t MAX_COMM_WORLD_SIZE = 384; const int64_t MAX_MOE_EXPERTS_NUM = 512; +const int64_t MAX_LOCAL_RANKSIZE = 8; + constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024; constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024; +constexpr static int TILING_KEY_INT = 23; +constexpr static int TILING_KEY_A2_TYPE = 100; + constexpr uint32_t TWO_DIMS = 2; constexpr uint32_t K_MAX = 16; } // namespace @@ -48,9 +55,24 @@ static void PrintTilingDataInfo(const char *nodeName, DispatchLayoutTilingData & OP_LOGD(nodeName, "numRanks is %u.", tilingData.dispatchLayoutInfo.numRanks); OP_LOGD(nodeName, "numExperts is %u.", tilingData.dispatchLayoutInfo.numExperts); OP_LOGD(nodeName, "numTopk is %u.", tilingData.dispatchLayoutInfo.numTopk); + OP_LOGD(nodeName, "localRankSize is %u.", tilingData.dispatchLayoutInfo.localRankSize); OP_LOGD(nodeName, "totalUbSize is %lu.", tilingData.dispatchLayoutInfo.totalUbSize); } +static bool CheckIfA2Machine(gert::TilingContext *context) +{ + fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo(); + fe::PlatFormInfos &platformInfo = *platformInfoPtr; + + std::string socVersion; + (void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion); + + if (socVersion == "Ascend910B") { + return true; + } + return false; +} + static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, DispatchLayoutTilingData &tilingData) { @@ -61,11 +83,14 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con auto numRanksPtr = attrs->GetAttrPointer(static_cast(ATTR_NUM_RANKS_INDEX)); auto numExpertsPtr = attrs->GetAttrPointer(ATTR_NUM_EXPERTS_INDEX); auto numTopkPtr = attrs->GetAttrPointer(static_cast(ATTR_NUM_TOPK_INDEX)); + auto localRankSizePtr = attrs->GetAttrPointer(static_cast(ATTR_LOCAL_RANKSIZE_INDEX)); OP_TILING_CHECK(numTokensPtr == nullptr, OP_LOGE(nodeName, "numTokensPtr is null."), return ge::GRAPH_FAILED); OP_TILING_CHECK(numRanksPtr == nullptr, OP_LOGE(nodeName, "numRanksPtr is null."), return ge::GRAPH_FAILED); OP_TILING_CHECK(numExpertsPtr == nullptr, OP_LOGE(nodeName, "numExpertsPtr is null."), return ge::GRAPH_FAILED); OP_TILING_CHECK(numTopkPtr == nullptr, OP_LOGE(nodeName, "numTopkPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(localRankSizePtr == nullptr, OP_LOGE(nodeName, "localRankSizePtr is null."), + return ge::GRAPH_FAILED); OP_TILING_CHECK((*numRanksPtr <= 0) || (*numRanksPtr > MAX_COMM_WORLD_SIZE), OP_LOGE(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.", @@ -80,10 +105,19 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con OP_LOGE(nodeName, "numTopkPtr is invalid, only support (0, %u], but got numTopk=%ld.", K_MAX, *numTopkPtr), return ge::GRAPH_FAILED); + if (CheckIfA2Machine(context)) { + OP_TILING_CHECK( + (*localRankSizePtr <= 0) || (*localRankSizePtr > MAX_LOCAL_RANKSIZE), + OP_LOGE(nodeName, "localRankSizePtr is invalid, only support (0, %ld], but got localRankSize=%ld.", + MAX_LOCAL_RANKSIZE, *localRankSizePtr), + return ge::GRAPH_FAILED); + } + tilingData.dispatchLayoutInfo.numTokens = static_cast(*numTokensPtr); tilingData.dispatchLayoutInfo.numRanks = static_cast(*numRanksPtr); tilingData.dispatchLayoutInfo.numExperts = static_cast(*numExpertsPtr); tilingData.dispatchLayoutInfo.numTopk = static_cast(*numTopkPtr); + tilingData.dispatchLayoutInfo.localRankSize = static_cast(*localRankSizePtr); return ge::GRAPH_SUCCESS; } @@ -102,11 +136,13 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa auto numTokensPerRank = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_RANK_INDEX); auto numTokensPerExpert = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX); auto isTokenInRank = context->GetOutputDesc(OUTPUT_IS_TOKEN_IN_RANK_INDEX); + auto notifySendData = context->GetOutputDesc(OUTPUT_NOTIFY_SEND_DATA_INDEX); OP_TILING_CHECK(topkIdx == nullptr, OP_LOGE(nodeName, "topkIdx is null."), return false); OP_TILING_CHECK(numTokensPerRank == nullptr, OP_LOGE(nodeName, "numTokensPerRank is null."), return false); OP_TILING_CHECK(numTokensPerExpert == nullptr, OP_LOGE(nodeName, "numTokensPerExpert is null."), return false); OP_TILING_CHECK(isTokenInRank == nullptr, OP_LOGE(nodeName, "isTokenInRank is null."), return false); + OP_TILING_CHECK(notifySendData == nullptr, OP_LOGE(nodeName, "notifySendData is null."), return false); OP_TILING_CHECK((topkIdx->GetDataType() != ge::DT_INT64), OP_LOGE(nodeName, "topkIdx datatype is invalid, datatype should be int, but is %d.", @@ -124,6 +160,10 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa OP_LOGE(nodeName, "isTokenInRank datatype is invalid, datatype should be int, but is %d.", static_cast(isTokenInRank->GetDataType())), return false); + OP_TILING_CHECK((notifySendData->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, "notifySendData datatype is invalid, datatype should be int, but is %d.", + static_cast(notifySendData->GetDataType())), + return false); return true; } @@ -169,11 +209,11 @@ static ge::graphStatus DispatchLayoutTilingFuncImpl(gert::TilingContext *context OP_TILING_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS, OP_LOGE(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); - fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo(); - fe::PlatFormInfos &platformInfo = *platformInfoPtr; - - std::string socVersion; - (void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion); + int tilingKey = TILING_KEY_INT; + if (CheckIfA2Machine(context)) { + tilingKey = tilingKey + TILING_KEY_A2_TYPE; + } + context->SetTilingKey(tilingKey); auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); uint32_t blockDim; diff --git a/csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.cpp b/csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.cpp index 780ab72e..1afc1e38 100644 --- a/csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.cpp +++ b/csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.cpp @@ -15,12 +15,14 @@ extern "C" { #endif aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(const aclTensor *topkIdx, int64_t numTokens, int64_t numRanks, - int64_t numExperts, int64_t numTopk, const aclTensor *numTokensPerRank, - const aclTensor *numTokensPerExpert, const aclTensor *isTokenInRank, + int64_t numExperts, int64_t numTopk, int64_t localRankSize, + const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert, + const aclTensor *isTokenInRank, const aclTensor *notifySendData, uint64_t *workspaceSize, aclOpExecutor **executor) { - return aclnnInnerDispatchLayoutGetWorkspaceSize(topkIdx, numTokens, numRanks, numExperts, numTopk, numTokensPerRank, - numTokensPerExpert, isTokenInRank, workspaceSize, executor); + return aclnnInnerDispatchLayoutGetWorkspaceSize(topkIdx, numTokens, numRanks, numExperts, numTopk, localRankSize, + numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData, + workspaceSize, executor); } aclnnStatus aclnnDispatchLayout(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream) diff --git a/csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.h b/csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.h index 3483c4f9..2560fa57 100644 --- a/csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.h +++ b/csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.h @@ -13,16 +13,18 @@ extern "C" { * numRanks : required * numExperts : required * numTopk : required + * localRankSize : required * numTokensPerRank : required * numTokensPerExpert : required * isTokenInRank : required + * notifySendData : required * workspaceSize : size of workspace(output). * executor : executor context(output). */ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayoutGetWorkspaceSize( const aclTensor *topkIdx, int64_t numTokens, int64_t numRanks, int64_t numExperts, int64_t numTopk, - const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert, const aclTensor *isTokenInRank, - uint64_t *workspaceSize, aclOpExecutor **executor); + int64_t localRankSize, const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert, + const aclTensor *isTokenInRank, const aclTensor *notifySendData, uint64_t *workspaceSize, aclOpExecutor **executor); /* function: aclnnDispatchLayout * workspace : workspace memory addr(input). diff --git a/csrc/deepep/ops/op_kernel/dispatch_layout.cpp b/csrc/deepep/ops/op_kernel/dispatch_layout.cpp index b0bbc2cf..ec6318e7 100644 --- a/csrc/deepep/ops/op_kernel/dispatch_layout.cpp +++ b/csrc/deepep/ops/op_kernel/dispatch_layout.cpp @@ -1,17 +1,29 @@ #include "kernel_operator.h" #include "dispatch_layout.h" +#include "dispatch_layout_a2.h" #include "dispatch_layout_tiling.h" +#define TILING_KEY_INT 23 +#define TILING_KEY_A2_INT 123 + extern "C" __global__ __aicore__ void dispatch_layout(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert, GM_ADDR isTokenInRank, - GM_ADDR workspace, GM_ADDR tiling) + GM_ADDR notifySendData, GM_ADDR workspace, GM_ADDR tiling) { REGISTER_TILING_DEFAULT(DispatchLayoutTilingData); GET_TILING_DATA_WITH_STRUCT(DispatchLayoutTilingData, tilingData, tiling); TPipe pipe; - DispatchLayout op; - op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, workspace, &pipe, &tilingData); - op.Process(); + if (TILING_KEY_IS(TILING_KEY_INT)) { + MoeDispatchLayout::DispatchLayout op; + op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData, workspace, &pipe, + &tilingData); + op.Process(); + } else if (TILING_KEY_IS(TILING_KEY_A2_INT)) { + MoeDispatchLayoutA2::DispatchLayoutA2 op; + op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData, workspace, &pipe, + &tilingData); + op.Process(); + } } diff --git a/csrc/deepep/ops/op_kernel/dispatch_layout.h b/csrc/deepep/ops/op_kernel/dispatch_layout.h index 597bf7e8..b67709f9 100644 --- a/csrc/deepep/ops/op_kernel/dispatch_layout.h +++ b/csrc/deepep/ops/op_kernel/dispatch_layout.h @@ -9,9 +9,7 @@ #include "sync_collectives.h" #include "moe_distribute_base.h" #include "dispatch_layout_tiling.h" - -using namespace AscendC; -using namespace Moe; +namespace MoeDispatchLayout { constexpr uint32_t UB_32_ALIGN = 32U; @@ -23,6 +21,8 @@ __aicore__ inline void SyncFunc() AscendC::WaitFlag(eventID); } +using namespace AscendC; +using namespace Moe; template class DispatchLayout { @@ -30,7 +30,7 @@ class DispatchLayout __aicore__ inline DispatchLayout(){}; __aicore__ inline void Init(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert, - GM_ADDR isTokenInRank, GM_ADDR workspace, TPipe *pipe, + GM_ADDR isTokenInRank, GM_ADDR notifySendData, GM_ADDR workspace, TPipe *pipe, const DispatchLayoutTilingData *tilingData) { numTokens_ = tilingData->dispatchLayoutInfo.numTokens; @@ -42,6 +42,7 @@ class DispatchLayout coreIdx_ = GetBlockIdx(); uint32_t maxAivNum = GetBlockNum(); aivNum_ = numTokens_ <= maxAivNum ? numTokens_ : maxAivNum; + if (coreIdx_ >= aivNum_) { return; } @@ -157,5 +158,6 @@ class DispatchLayout uint32_t numTokensPerExpert32AlignIntLen_{0}; uint32_t isTokenInRank32AlignIntLen_{0}; }; +} // namespace MoeDispatchLayout #endif // DISPATCH_LAYOUT_H diff --git a/csrc/deepep/ops/op_kernel/dispatch_layout_a2.h b/csrc/deepep/ops/op_kernel/dispatch_layout_a2.h new file mode 100644 index 00000000..47510b85 --- /dev/null +++ b/csrc/deepep/ops/op_kernel/dispatch_layout_a2.h @@ -0,0 +1,335 @@ +#ifndef DISPATCH_LAYOUT_A2_H +#define DISPATCH_LAYOUT_A2_H + +#include +#include "kernel_operator.h" + +#include "comm_args.h" +#include "data_copy.h" +#include "sync_collectives.h" +#include "moe_distribute_base.h" +#include "dispatch_layout_tiling.h" + +namespace MoeDispatchLayoutA2 { + +constexpr uint32_t UB_32_ALIGN = 32U; +constexpr uint32_t MAX_BATCH_SIZE = 4096U; +constexpr uint32_t TEMP_BATCH_SIZE = 8U; + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +using namespace AscendC; +using namespace Moe; +template +class DispatchLayoutA2 +{ +public: + __aicore__ inline DispatchLayoutA2(){}; + + __aicore__ inline void Init(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert, + GM_ADDR isTokenInRank, GM_ADDR notifySendData, GM_ADDR workspace, TPipe *pipe, + const DispatchLayoutTilingData *tilingData) + { + numTokens_ = tilingData->dispatchLayoutInfo.numTokens; + numRanks_ = tilingData->dispatchLayoutInfo.numRanks; + numExperts_ = tilingData->dispatchLayoutInfo.numExperts; + numTopk_ = tilingData->dispatchLayoutInfo.numTopk; + localRankSize_ = tilingData->dispatchLayoutInfo.localRankSize; + serverNum_ = numRanks_ / localRankSize_; + tpipe_ = pipe; + + coreIdx_ = GetBlockIdx(); + uint32_t maxAivNum = GetBlockNum() - 1; + aivNum_ = numTokens_ <= maxAivNum ? numTokens_ : maxAivNum; + uint32_t temp = numTokens_ / aivNum_; + uint32_t restNum = numTokens_ % aivNum_; + int64_t topkIdxOffset; + int64_t isTokenOffset; + int64_t serverOffsetOffset; + int64_t serverNumOffset; + tempTokens_ = temp; + + if (coreIdx_ < aivNum_) { + if (coreIdx_ < restNum) { + tempTokens_++; + } + topkIdx32AlignIntLen_ = Ceil(tempTokens_ * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN; + numTokensPerRank32AlignIntLen_ = Ceil(numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + numTokensPerExpert32AlignIntLen_ = Ceil(numExperts_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + isTokenInRank32AlignIntLen_ = Ceil(tempTokens_ * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerOffset32AlignIntLen_ = + Ceil(tempTokens_ * serverNum_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerUniqCount32AlignIntLen_ = Ceil(serverNum_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerTotalCount32AlignIntLen_ = + Ceil(tempTokens_ * serverNum_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerNum32AlignIntLen_ = Ceil(tempTokens_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + + if (coreIdx_ < restNum) { + topkIdxOffset = coreIdx_ * tempTokens_ * numTopk_ * sizeof(int64_t); + isTokenOffset = coreIdx_ * tempTokens_ * numRanks_ * sizeof(T); + serverOffsetOffset = coreIdx_ * tempTokens_ * serverNum_ * sizeof(T); + serverNumOffset = coreIdx_ * tempTokens_ * sizeof(T); + } else { + topkIdxOffset = (restNum + coreIdx_ * tempTokens_) * numTopk_ * sizeof(int64_t); + isTokenOffset = (restNum + coreIdx_ * tempTokens_) * numRanks_ * sizeof(T); + serverOffsetOffset = (restNum + coreIdx_ * tempTokens_) * serverNum_ * sizeof(T); + serverNumOffset = (restNum + coreIdx_ * tempTokens_) * sizeof(T); + } + + topkIdxGM_.SetGlobalBuffer((__gm__ int64_t *)(topkIdx + topkIdxOffset)); + numTokensPerRankGM_.SetGlobalBuffer((__gm__ T *)numTokensPerRank); + numTokensPerExpertSrcGM_.SetGlobalBuffer((__gm__ T *)numTokensPerExpert); + numTokensPerExpertGM_.SetGlobalBuffer((__gm__ T *)notifySendData); + isTokenInRankGM_.SetGlobalBuffer((__gm__ T *)(isTokenInRank + isTokenOffset)); + localTokenServerUniqCountGM_.SetGlobalBuffer((__gm__ T *)(notifySendData) + numExperts_); + localTokenServerTotalCountGM_.SetGlobalBuffer((__gm__ T *)(notifySendData + serverOffsetOffset) + + numExperts_ + serverNum_); + localTokenServerNumGM_.SetGlobalBuffer((__gm__ T *)(notifySendData + serverNumOffset) + numExperts_ + + serverNum_ * (MAX_BATCH_SIZE + 1)); + localTokenServerOffsetGM_.SetGlobalBuffer((__gm__ T *)(notifySendData + serverOffsetOffset) + numExperts_ + + serverNum_ + MAX_BATCH_SIZE * (serverNum_ + 1)); + } + if (coreIdx_ == aivNum_) { + expertRankTokenIdx32AlignIntLen_ = + Ceil(numExperts_ * TEMP_BATCH_SIZE * 2 * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerOffset32AlignIntLen_ = Ceil(numTokens_ * serverNum_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + topkIdxGM_.SetGlobalBuffer((__gm__ int64_t *)topkIdx); + localTokenServerOffsetGM_.SetGlobalBuffer((__gm__ T *)notifySendData + numExperts_ + serverNum_ + + MAX_BATCH_SIZE * (serverNum_ + 1)); + sendTokenIdxGM_.SetGlobalBuffer((__gm__ T *)notifySendData + numExperts_ + serverNum_ + + MAX_BATCH_SIZE * (1 + 2 * serverNum_)); + expertRankTokenIdxGM_.SetGlobalBuffer((__gm__ T *)notifySendData + numExperts_ + serverNum_ + + MAX_BATCH_SIZE * (1 + 2 * serverNum_ + numTopk_)); + } + } + + __aicore__ inline void Process() + { + if (coreIdx_ < aivNum_) { + MultiCoreCompute(); + } + SyncAll(); + if (coreIdx_ == aivNum_) { + ComputeServerOffset(); + } + } + +private: + __aicore__ inline void MultiCoreCompute() + { + tpipe_->Reset(); + tpipe_->InitBuffer(topkIdxBuf_, topkIdx32AlignIntLen_); + tpipe_->InitBuffer(numTokensPerRankBuf_, numTokensPerRank32AlignIntLen_); + tpipe_->InitBuffer(numTokensPerExpertBuf_, numTokensPerExpert32AlignIntLen_); + tpipe_->InitBuffer(isTokenInRankBuf_, isTokenInRank32AlignIntLen_); + tpipe_->InitBuffer(localTokenServerOffsetBuf_, localTokenServerOffset32AlignIntLen_); + tpipe_->InitBuffer(localTokenServerUniqCountBuf_, localTokenServerUniqCount32AlignIntLen_); + tpipe_->InitBuffer(localTokenServerTotalCountBuf_, localTokenServerTotalCount32AlignIntLen_); + tpipe_->InitBuffer(localTokenServerNumBuf_, localTokenServerNum32AlignIntLen_); + tpipe_->InitBuffer(seenRankBuf_, numRanks_ * sizeof(T)); + tpipe_->InitBuffer(seenServerBuf_, serverNum_ * sizeof(T)); + LocalTensor topkIdxTensor = topkIdxBuf_.AllocTensor(); + const DataCopyExtParams dataCopyParams{1U, topkIdx32AlignIntLen_, 0U, 0U, 0U}; + const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; + DataCopyPad(topkIdxTensor, topkIdxGM_, dataCopyParams, padParams); + SyncFunc(); + LocalTensor numTokensPerRankTensor = numTokensPerRankBuf_.AllocTensor(); + LocalTensor numTokensPerExpertTensor = numTokensPerExpertBuf_.AllocTensor(); + LocalTensor isTokenInRankTensor = isTokenInRankBuf_.AllocTensor(); + LocalTensor localTokenServerOffsetTensor = localTokenServerOffsetBuf_.AllocTensor(); + LocalTensor localTokenServerUniqCountTensor = localTokenServerUniqCountBuf_.AllocTensor(); + LocalTensor localTokenServerTotalCountTensor = localTokenServerTotalCountBuf_.AllocTensor(); + LocalTensor localTokenServerNumTensor = localTokenServerNumBuf_.AllocTensor(); + LocalTensor seenRankTensor = seenRankBuf_.AllocTensor(); + LocalTensor seenServerTensor = seenServerBuf_.AllocTensor(); + Duplicate(numTokensPerRankTensor, 0, numRanks_); + Duplicate(numTokensPerExpertTensor, 0, numExperts_); + Duplicate(isTokenInRankTensor, 0, tempTokens_ * numRanks_); + Duplicate(localTokenServerOffsetTensor, -1, tempTokens_ * serverNum_); + Duplicate(localTokenServerUniqCountTensor, 0, serverNum_); + Duplicate(localTokenServerTotalCountTensor, 0, tempTokens_ * serverNum_); + Duplicate(localTokenServerNumTensor, 0, tempTokens_); + SyncFunc(); + int experts_per_rank = numExperts_ / numRanks_; + for (int i = 0; i < tempTokens_; ++i) { + SyncFunc(); + Duplicate(seenRankTensor, 0, numRanks_); + Duplicate(seenServerTensor, 0, serverNum_); + SyncFunc(); + for (int j = 0; j < numTopk_; ++j) { + int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j); + uint32_t per_expert_num = numTokensPerExpertTensor.GetValue(expert_idx) + 1; + numTokensPerExpertTensor.SetValue(expert_idx, per_expert_num); + int rank_id = expert_idx / experts_per_rank; + int server_id = rank_id / localRankSize_; + if (!seenServerTensor.GetValue(server_id)) { + localTokenServerOffsetTensor.SetValue(i * serverNum_ + server_id, 1); + uint32_t uniqCount = localTokenServerUniqCountTensor.GetValue(server_id); + localTokenServerUniqCountTensor.SetValue(server_id, uniqCount + 1); + seenServerTensor.SetValue(server_id, 1); + uint32_t sendServerNum = localTokenServerNumTensor.GetValue(i); + localTokenServerNumTensor.SetValue(i, sendServerNum + 1); + } + uint32_t totalCount = localTokenServerTotalCountTensor.GetValue(i * serverNum_ + server_id) + 1; + localTokenServerTotalCountTensor.SetValue(i * serverNum_ + server_id, totalCount); + if (!seenRankTensor.GetValue(rank_id)) { + uint32_t per_rank_num = numTokensPerRankTensor.GetValue(rank_id) + 1; + isTokenInRankTensor.SetValue(i * numRanks_ + rank_id, 1); + seenRankTensor.SetValue(rank_id, 1); + numTokensPerRankTensor.SetValue(rank_id, per_rank_num); + } + } + } + uint32_t sendSize = tempTokens_ * numRanks_ * sizeof(T); + const DataCopyExtParams isTokenInRankDataCopyParams{1U, sendSize, 0U, 0U, 0U}; + sendSize = tempTokens_ * sizeof(T); + DataCopyPad(isTokenInRankGM_, isTokenInRankTensor, isTokenInRankDataCopyParams); + const DataCopyExtParams localTokenServerNumParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(localTokenServerNumGM_, localTokenServerNumTensor, localTokenServerNumParams); + sendSize = tempTokens_ * serverNum_ * sizeof(T); + const DataCopyExtParams localTokenServerOffsetParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(localTokenServerOffsetGM_, localTokenServerOffsetTensor, localTokenServerOffsetParams); + const DataCopyExtParams localTokenServerTotalCountParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(localTokenServerTotalCountGM_, localTokenServerTotalCountTensor, localTokenServerTotalCountParams); + sendSize = serverNum_ * sizeof(T); + AscendC::SetAtomicAdd(); + const DataCopyExtParams localTokenServerUniqCountParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(localTokenServerUniqCountGM_, localTokenServerUniqCountTensor, localTokenServerUniqCountParams); + const DataCopyExtParams numTokensPerRankDataCopyParams{1U, numTokensPerRank32AlignIntLen_, 0U, 0U, 0U}; + DataCopyPad(numTokensPerRankGM_, numTokensPerRankTensor, numTokensPerRankDataCopyParams); + const DataCopyExtParams numTokensPerExpertDataCopyParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U}; + DataCopyPad(numTokensPerExpertGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams); + DataCopyPad(numTokensPerExpertSrcGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams); + AscendC::SetAtomicNone(); + } + + __aicore__ inline void ComputeServerOffset() + { + tpipe_->Reset(); + tpipe_->InitBuffer(localTokenServerOffsetBuf_, localTokenServerOffset32AlignIntLen_); + tpipe_->InitBuffer(seenServerBuf_, serverNum_ * sizeof(T)); + tpipe_->InitBuffer(expertRankTokenIdxBuf_, expertRankTokenIdx32AlignIntLen_); + tpipe_->InitBuffer(countExpertBuf_, numExperts_ * sizeof(T)); + LocalTensor localTokenServerOffsetTensor = localTokenServerOffsetBuf_.AllocTensor(); + LocalTensor seenServerTensor = seenServerBuf_.AllocTensor(); + const DataCopyExtParams dataCopyParams{1U, localTokenServerOffset32AlignIntLen_, 0U, 0U, 0U}; + const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; + DataCopyPad(localTokenServerOffsetTensor, localTokenServerOffsetGM_, dataCopyParams, padParams); + SyncFunc(); + Duplicate(seenServerTensor, 0, serverNum_); + SyncFunc(); + for (int i = 0; i < numTokens_; i++) { + for (int j = 0; j < serverNum_; j++) { + int32_t value = localTokenServerOffsetTensor.GetValue(i * serverNum_ + j); + if (value > 0) { + int32_t offset = seenServerTensor.GetValue(j); + localTokenServerOffsetTensor.SetValue(i * serverNum_ + j, offset); + seenServerTensor.SetValue(j, offset + 1); + } + } + } + SyncFunc(); + DataCopyPad(localTokenServerOffsetGM_, localTokenServerOffsetTensor, dataCopyParams); + LocalTensor countExpertTensor = countExpertBuf_.AllocTensor(); + LocalTensor expertRankTokenIdxTensor = expertRankTokenIdxBuf_.AllocTensor(); + Duplicate(countExpertTensor, 0, numExperts_); + SyncFunc(); + int32_t experts_per_rank = numExperts_ / numRanks_; + for (int i = 0; i < numTokens_; i++) { + for (int j = 0; j < numTopk_; j++) { + int32_t expert_id = topkIdxGM_.GetValue(i * numTopk_ + j); + int32_t server_id = (expert_id / experts_per_rank) / localRankSize_; + int32_t offset = localTokenServerOffsetTensor.GetValue(i * serverNum_ + server_id); + int32_t count = countExpertTensor.GetValue(expert_id); + expertRankTokenIdxTensor.SetValue(expert_id * TEMP_BATCH_SIZE + count, offset); + expertRankTokenIdxTensor.SetValue((numExperts_ + expert_id) * TEMP_BATCH_SIZE + count, i); + sendTokenIdxGM_.SetValue(i * numTopk_ + j, count); + count++; + countExpertTensor.SetValue(expert_id, count); + if (count % TEMP_BATCH_SIZE == 0) { + SyncFunc(); + const DataCopyExtParams expertRankTokendataCopyParams{1U, TEMP_BATCH_SIZE * sizeof(T), 0U, 0U, 0U}; + DataCopyPad(expertRankTokenIdxGM_[expert_id * MAX_BATCH_SIZE + count - TEMP_BATCH_SIZE], + expertRankTokenIdxTensor[expert_id * TEMP_BATCH_SIZE], expertRankTokendataCopyParams); + DataCopyPad( + expertRankTokenIdxGM_[(numExperts_ + expert_id) * MAX_BATCH_SIZE + count - TEMP_BATCH_SIZE], + expertRankTokenIdxTensor[(numExperts_ + expert_id) * TEMP_BATCH_SIZE], + expertRankTokendataCopyParams); + SyncFunc(); + Duplicate(expertRankTokenIdxTensor[expert_id * TEMP_BATCH_SIZE], 0, TEMP_BATCH_SIZE); + Duplicate(expertRankTokenIdxTensor[(numExperts_ + expert_id) * TEMP_BATCH_SIZE], 0, + TEMP_BATCH_SIZE); + } + } + } + for (int i = 0; i < numExperts_; i++) { + int32_t count = countExpertTensor.GetValue(i); + uint32_t rest = count % TEMP_BATCH_SIZE; + if (rest) { + SyncFunc(); + const DataCopyExtParams expertRankTokendataCopyParams{1U, uint32_t(rest * sizeof(T)), 0U, 0U, 0U}; + DataCopyPad(expertRankTokenIdxGM_[i * MAX_BATCH_SIZE + count - rest], + expertRankTokenIdxTensor[i * TEMP_BATCH_SIZE], expertRankTokendataCopyParams); + DataCopyPad(expertRankTokenIdxGM_[(i + numExperts_) * MAX_BATCH_SIZE + count - rest], + expertRankTokenIdxTensor[(i + numExperts_) * TEMP_BATCH_SIZE], + expertRankTokendataCopyParams); + SyncFunc(); + } + } + } + + GlobalTensor topkIdxGM_; + GlobalTensor numTokensPerRankGM_; + GlobalTensor numTokensPerExpertGM_; + GlobalTensor numTokensPerExpertSrcGM_; + GlobalTensor isTokenInRankGM_; + GlobalTensor localTokenServerOffsetGM_; + GlobalTensor localTokenServerUniqCountGM_; + GlobalTensor localTokenServerTotalCountGM_; + GlobalTensor localTokenServerNumGM_; + GlobalTensor expertRankTokenIdxGM_; + GlobalTensor sendTokenIdxGM_; + + TBuf<> topkIdxBuf_; + TBuf<> numTokensPerRankBuf_; + TBuf<> numTokensPerExpertBuf_; + TBuf<> isTokenInRankBuf_; + TBuf<> localTokenServerOffsetBuf_; + TBuf<> localTokenServerUniqCountBuf_; + TBuf<> localTokenServerTotalCountBuf_; + TBuf<> localTokenServerNumBuf_; + TBuf<> seenRankBuf_; + TBuf<> seenServerBuf_; + TBuf<> countExpertBuf_; + TBuf<> expertRankTokenIdxBuf_; + + TPipe *tpipe_{nullptr}; + uint32_t numTokens_{0}; + uint32_t numRanks_{0}; + uint32_t numExperts_{0}; + uint32_t numTopk_{0}; + uint32_t localRankSize_{0}; + uint32_t serverNum_{0}; + uint32_t coreIdx_{0}; + uint32_t aivNum_{0}; + uint32_t tempTokens_{0}; + + uint32_t topkIdx32AlignIntLen_{0}; + uint32_t numTokensPerRank32AlignIntLen_{0}; + uint32_t numTokensPerExpert32AlignIntLen_{0}; + uint32_t isTokenInRank32AlignIntLen_{0}; + uint32_t localTokenServerOffset32AlignIntLen_{0}; + uint32_t localTokenServerUniqCount32AlignIntLen_{0}; + uint32_t localTokenServerTotalCount32AlignIntLen_{0}; + uint32_t localTokenServerNum32AlignIntLen_{0}; + uint32_t expertRankTokenIdx32AlignIntLen_{0}; +}; +} // namespace MoeDispatchLayoutA2 + +#endif // DISPATCH_LAYOUT_A2_H diff --git a/csrc/deepep/ops/op_kernel/dispatch_layout_tiling.h b/csrc/deepep/ops/op_kernel/dispatch_layout_tiling.h index bf56f45a..af1d0eae 100644 --- a/csrc/deepep/ops/op_kernel/dispatch_layout_tiling.h +++ b/csrc/deepep/ops/op_kernel/dispatch_layout_tiling.h @@ -8,6 +8,7 @@ struct DispatchLayoutInfo { uint32_t numRanks; uint32_t numExperts; uint32_t numTopk; + uint32_t localRankSize; uint64_t totalUbSize; };