@@ -26,17 +26,23 @@ constexpr uint32_t INPUT_TOPK_IDX_INDEX = 0;
// Output positions handed to context->GetOutputDesc() when the output
// descriptors are fetched and their data types validated.
2626constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0 ;
2727constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1 ;
2828constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2 ;
// Output 3 ("totalData") is required to be non-null and of type DT_INT32.
29+ constexpr uint32_t OUTPUT_TOTAL_DATA_INDEX = 3 ;
2930
// Attribute positions handed to attrs->GetAttrPointer<int64_t>().
3031constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0 ;
3132constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1 ;
3233constexpr uint32_t ATTR_NUM_EXPERTS_INDEX = 2 ;
3334constexpr uint32_t ATTR_NUM_TOPK_INDEX = 3 ;
// Attribute 4 is localRankSize; it is copied into the tiling data as uint32_t.
35+ constexpr uint32_t ATTR_LOCAL_RANKSIZE_INDEX = 4 ;
// Inclusive upper bound for numRanks; valid range is (0, MAX_COMM_WORLD_SIZE].
3436const int64_t MAX_COMM_WORLD_SIZE = 384 ;
// NOTE(review): presumably the upper bound for numExperts; the check that uses
// it is not visible in this excerpt - confirm.
3537const int64_t MAX_MOE_EXPERTS_NUM = 512 ;
// Inclusive upper bound for localRankSize. The range check (0, 8] is applied
// only when CheckIfA2Machine() reports true.
38+ const int64_t MAX_A2_LOCAL_RANKSIZE = 8 ;
// NOTE(review): the three sizes below look like byte counts used for workspace
// sizing; their consumer is not visible in this excerpt - confirm.
3639constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024 ;
3740constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024 ;
3841constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024 ;
3942
// Base tiling key passed to context->SetTilingKey().
43+ constexpr static int TILING_KEY_INT = 23 ;
// Added to TILING_KEY_INT when CheckIfA2Machine() reports true (23 + 100 = 123).
44+ constexpr static int TILING_KEY_A2_TYPE = 100 ;
45+
// NOTE(review): TWO_DIMS is not referenced in the visible excerpt - confirm usage.
4046constexpr uint32_t TWO_DIMS = 2 ;
// Reported as the inclusive upper bound for numTopk in the validation error
// message ("only support (0, %u]").
4147constexpr uint32_t K_MAX = 16 ;
4248} // namespace
@@ -48,9 +54,24 @@ static void PrintTilingDataInfo(const char *nodeName, DispatchLayoutTilingData &
4854 OP_LOGD (nodeName, " numRanks is %u." , tilingData.dispatchLayoutInfo .numRanks );
4955 OP_LOGD (nodeName, " numExperts is %u." , tilingData.dispatchLayoutInfo .numExperts );
5056 OP_LOGD (nodeName, " numTopk is %u." , tilingData.dispatchLayoutInfo .numTopk );
57+ OP_LOGD (nodeName, " localRankSize is %u." , tilingData.dispatchLayoutInfo .localRankSize );
5158 OP_LOGD (nodeName, " totalUbSize is %lu." , tilingData.dispatchLayoutInfo .totalUbSize );
5259}
5360
61+ static bool CheckIfA2Machine (gert::TilingContext *context)
62+ {
63+ fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo ();
64+ fe::PlatFormInfos &platformInfo = *platformInfoPtr;
65+
66+ std::string socVersion;
67+ (void )platformInfo.GetPlatformResWithLock (" version" , " Short_SoC_version" , socVersion);
68+
69+ if (socVersion == " Ascend910B" ) {
70+ return true ;
71+ }
72+ return false ;
73+ }
74+
5475static ge::graphStatus GetAttrAndSetTilingData (gert::TilingContext *context, const char *nodeName,
5576 DispatchLayoutTilingData &tilingData)
5677{
@@ -61,11 +82,14 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
6182 auto numRanksPtr = attrs->GetAttrPointer <int64_t >(static_cast <int >(ATTR_NUM_RANKS_INDEX));
6283 auto numExpertsPtr = attrs->GetAttrPointer <int64_t >(ATTR_NUM_EXPERTS_INDEX);
6384 auto numTopkPtr = attrs->GetAttrPointer <int64_t >(static_cast <int >(ATTR_NUM_TOPK_INDEX));
85+ auto localRankSizePtr = attrs->GetAttrPointer <int64_t >(static_cast <int >(ATTR_LOCAL_RANKSIZE_INDEX));
6486
6587 OP_TILING_CHECK (numTokensPtr == nullptr , OP_LOGE (nodeName, " numTokensPtr is null." ), return ge::GRAPH_FAILED);
6688 OP_TILING_CHECK (numRanksPtr == nullptr , OP_LOGE (nodeName, " numRanksPtr is null." ), return ge::GRAPH_FAILED);
6789 OP_TILING_CHECK (numExpertsPtr == nullptr , OP_LOGE (nodeName, " numExpertsPtr is null." ), return ge::GRAPH_FAILED);
6890 OP_TILING_CHECK (numTopkPtr == nullptr , OP_LOGE (nodeName, " numTopkPtr is null." ), return ge::GRAPH_FAILED);
91+ OP_TILING_CHECK (localRankSizePtr == nullptr , OP_LOGE (nodeName, " localRankSizePtr is null." ),
92+ return ge::GRAPH_FAILED);
6993
7094 OP_TILING_CHECK ((*numRanksPtr <= 0 ) || (*numRanksPtr > MAX_COMM_WORLD_SIZE),
7195 OP_LOGE (nodeName, " rankSize is invalid, only support (0, %ld], but got rankSize=%ld." ,
@@ -80,10 +104,19 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
80104 OP_LOGE (nodeName, " numTopkPtr is invalid, only support (0, %u], but got numTopk=%ld." , K_MAX, *numTopkPtr),
81105 return ge::GRAPH_FAILED);
82106
107+ if (CheckIfA2Machine (context)) {
108+ OP_TILING_CHECK (
109+ (*localRankSizePtr <= 0 ) || (*localRankSizePtr > MAX_A2_LOCAL_RANKSIZE),
110+ OP_LOGE (nodeName, " localRankSizePtr is invalid, only support (0, %ld], but got localRankSize=%ld." ,
111+ MAX_A2_LOCAL_RANKSIZE, *localRankSizePtr),
112+ return ge::GRAPH_FAILED);
113+ }
114+
83115 tilingData.dispatchLayoutInfo .numTokens = static_cast <uint32_t >(*numTokensPtr);
84116 tilingData.dispatchLayoutInfo .numRanks = static_cast <uint32_t >(*numRanksPtr);
85117 tilingData.dispatchLayoutInfo .numExperts = static_cast <uint32_t >(*numExpertsPtr);
86118 tilingData.dispatchLayoutInfo .numTopk = static_cast <uint32_t >(*numTopkPtr);
119+ tilingData.dispatchLayoutInfo .localRankSize = static_cast <uint32_t >(*localRankSizePtr);
87120
88121 return ge::GRAPH_SUCCESS;
89122}
@@ -102,11 +135,13 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
102135 auto numTokensPerRank = context->GetOutputDesc (OUTPUT_NUM_TOKEN_PER_RANK_INDEX);
103136 auto numTokensPerExpert = context->GetOutputDesc (OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX);
104137 auto isTokenInRank = context->GetOutputDesc (OUTPUT_IS_TOKEN_IN_RANK_INDEX);
138+ auto totalData = context->GetOutputDesc (OUTPUT_TOTAL_DATA_INDEX);
105139
106140 OP_TILING_CHECK (topkIdx == nullptr , OP_LOGE (nodeName, " topkIdx is null." ), return false );
107141 OP_TILING_CHECK (numTokensPerRank == nullptr , OP_LOGE (nodeName, " numTokensPerRank is null." ), return false );
108142 OP_TILING_CHECK (numTokensPerExpert == nullptr , OP_LOGE (nodeName, " numTokensPerExpert is null." ), return false );
109143 OP_TILING_CHECK (isTokenInRank == nullptr , OP_LOGE (nodeName, " isTokenInRank is null." ), return false );
144+ OP_TILING_CHECK (totalData == nullptr , OP_LOGE (nodeName, " totalData is null." ), return false );
110145
111146 OP_TILING_CHECK ((topkIdx->GetDataType () != ge::DT_INT64),
112147 OP_LOGE (nodeName, " topkIdx datatype is invalid, datatype should be int, but is %d." ,
@@ -124,6 +159,10 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
124159 OP_LOGE (nodeName, " isTokenInRank datatype is invalid, datatype should be int, but is %d." ,
125160 static_cast <ge::DataType>(isTokenInRank->GetDataType ())),
126161 return false );
162+ OP_TILING_CHECK ((totalData->GetDataType () != ge::DT_INT32),
163+ OP_LOGE (nodeName, " totalData datatype is invalid, datatype should be int, but is %d." ,
164+ static_cast <ge::DataType>(totalData->GetDataType ())),
165+ return false );
127166
128167 return true ;
129168}
@@ -169,11 +208,11 @@ static ge::graphStatus DispatchLayoutTilingFuncImpl(gert::TilingContext *context
169208 OP_TILING_CHECK (SetWorkSpace (context, nodeName) != ge::GRAPH_SUCCESS,
170209 OP_LOGE (nodeName, " Tiling set workspace failed." ), return ge::GRAPH_FAILED);
171210
172- fe::PlatFormInfos *platformInfoPtr = context-> GetPlatformInfo () ;
173- fe::PlatFormInfos &platformInfo = *platformInfoPtr;
174-
175- std::string socVersion;
176- ( void )platformInfo. GetPlatformResWithLock ( " version " , " Short_SoC_version " , socVersion );
211+ int tilingKey = TILING_KEY_INT ;
212+ if ( CheckIfA2Machine (context)) {
213+ tilingKey = tilingKey + TILING_KEY_A2_TYPE;
214+ }
215+ context-> SetTilingKey (tilingKey );
177216
178217 auto ascendcPlatform = platform_ascendc::PlatformAscendC (context->GetPlatformInfo ());
179218 uint32_t blockDim;
0 commit comments