Skip to content

Commit 11f8090

Browse files
author
linzihan
committed
integrate A2 and A3 interface
1 parent 85d9980 commit 11f8090

18 files changed

+173
-576
lines changed

csrc/deepep/deep_ep.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ constexpr int PADDING_SIZE = 3;
1212
constexpr size_t HCOMM_NAME_LEN = 128;
1313
constexpr uint32_t NO_SCALES = 0;
1414
constexpr uint32_t DYNAMIC_SCALES = 2;
15+
// A2 dispatch-layout sizing constants, defined locally in this file (TODO: move to a shared header).
16+
constexpr int A2_LOCAL_RANK_SIZE = 8;
17+
constexpr int A2_MAX_BATCH_SIZE = 4096;
18+
constexpr int A2_EXPERT_DATA_SIZE = 1 + 2 * A2_MAX_BATCH_SIZE; // 8193
1519

1620
Buffer::Buffer(int64_t rank, int64_t num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode,
1721
std::string moe_all_to_all_group_name)
@@ -46,7 +50,7 @@ bool Buffer::is_available() const
4650
return available;
4751
}
4852

49-
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
53+
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
5054
Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std::optional<EventHandle> &previous_event,
5155
bool async, bool allocate_on_comm_stream)
5256
{
@@ -73,30 +77,27 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
7377

7478
const int num_tokens = new_topk_idx.size(0);
7579
const int num_topk = new_topk_idx.size(1);
76-
const int local_ranksize = 8;
80+
const int local_ranksize = A2_LOCAL_RANK_SIZE;
7781
auto server_num = num_ranks / local_ranksize;
7882

7983
auto device = new_topk_idx.device();
8084
auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device));
8185
auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device));
8286
auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device));
83-
auto local_token_server_offset = at::zeros({num_tokens * server_num}, at::dtype(at::kInt).device(device));
84-
auto local_token_server_uniq_count = at::zeros({server_num}, at::dtype(at::kInt).device(device));
85-
auto local_token_server_total_count = at::zeros({num_tokens * server_num}, at::dtype(at::kInt).device(device));
86-
auto local_token_server_num = at::zeros({num_tokens}, at::dtype(at::kInt).device(device));
87-
const int total_size = num_experts * 8193 + server_num + num_tokens * (1 + 2 * server_num + num_topk);
88-
auto expert_rank_token_idx = at::zeros({total_size}, at::dtype(at::kInt).device(device));
87+
const int total_size =
88+
num_experts * A2_EXPERT_DATA_SIZE + server_num + num_tokens * (1 + 2 * server_num + num_topk);
89+
auto total_data = at::zeros({total_size}, at::dtype(at::kInt).device(device));
8990

90-
EXEC_NPU_CMD(aclnnDispatchLayoutA2, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, local_ranksize, num_tokens_per_rank,
91-
num_tokens_per_expert, is_token_in_rank, local_token_server_offset, local_token_server_uniq_count,
92-
local_token_server_total_count, local_token_server_num, expert_rank_token_idx);
91+
EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, local_ranksize,
92+
num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank, total_data);
9393

94+
this->send_data = total_data;
9495
std::optional<torch::Tensor> num_tokens_per_rdma_rank = std::nullopt;
9596
std::optional<EventHandle> output_event = std::nullopt;
9697
auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool);
9798

9899
return std::make_tuple(num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank_bool,
99-
expert_rank_token_idx, output_event);
100+
output_event);
100101
}
101102

102103
std::tuple<at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>,

csrc/deepep/deep_ep.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct Buffer {
2626
at::Tensor ori_x;
2727
at::Tensor new_topk_idx;
2828
at::Tensor new_scales;
29+
at::Tensor send_data;
2930

3031
int64_t shared_expert_rank_num;
3132
int64_t shared_expert_num = 1;
@@ -47,7 +48,7 @@ struct Buffer {
4748

4849
bool is_available() const;
4950

50-
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
51+
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
5152
get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std::optional<EventHandle> &previous_event,
5253
bool async, bool allocate_on_comm_stream);
5354

csrc/deepep/ops/CMakePresets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
},
2828
"ASCEND_COMPUTE_UNIT": {
2929
"type": "STRING",
30-
"value": "ascend910b"
30+
"value": "ascend910_93"
3131
},
3232
"ENABLE_TEST": {
3333
"type": "BOOL",

csrc/deepep/ops/op_host/dispatch_layout.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class DispatchLayout : public OpDef
1616
this->Attr("num_ranks").Int();
1717
this->Attr("num_experts").Int();
1818
this->Attr("num_topk").Int();
19+
this->Attr("local_ranksize").Int();
1920

2021
this->Output("numTokensPerRank")
2122
.ParamType(REQUIRED)
@@ -32,9 +33,14 @@ class DispatchLayout : public OpDef
3233
.DataType({ge::DT_INT32})
3334
.Format({ge::FORMAT_ND})
3435
.UnknownShapeFormat({ge::FORMAT_ND});
36+
this->Output("totalData")
37+
.ParamType(REQUIRED)
38+
.DataType({ge::DT_INT32})
39+
.Format({ge::FORMAT_ND})
40+
.UnknownShapeFormat({ge::FORMAT_ND});
3541

36-
OpAICoreConfig aicore_config;
37-
aicore_config.DynamicCompileStaticFlag(true)
42+
OpAICoreConfig a3_config;
43+
a3_config.DynamicCompileStaticFlag(true)
3844
.DynamicFormatFlag(true)
3945
.DynamicRankSupportFlag(true)
4046
.DynamicShapeSupportFlag(true)
@@ -44,7 +50,19 @@ class DispatchLayout : public OpDef
4450
.ExtendCfgInfo("jitCompile.flag", "static_true")
4551
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
4652

47-
this->AICore().AddConfig("ascend910_93", aicore_config);
53+
OpAICoreConfig a2_config;
54+
a2_config.DynamicCompileStaticFlag(true)
55+
.DynamicFormatFlag(true)
56+
.DynamicRankSupportFlag(true)
57+
.DynamicShapeSupportFlag(true)
58+
.NeedCheckSupportFlag(false)
59+
.PrecisionReduceFlag(true)
60+
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
61+
.ExtendCfgInfo("jitCompile.flag", "static_false")
62+
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
63+
64+
this->AICore().AddConfig("ascend910_93", a3_config);
65+
this->AICore().AddConfig("ascend910b", a2_config);
4866
}
4967
};
5068

csrc/deepep/ops/op_host/dispatch_layout_a2.cpp

Lines changed: 0 additions & 78 deletions
This file was deleted.

csrc/deepep/ops/op_host/dispatch_layout_tiling.cc

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,23 @@ constexpr uint32_t INPUT_TOPK_IDX_INDEX = 0;
2626
constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0;
2727
constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1;
2828
constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2;
29+
constexpr uint32_t OUTPUT_TOTAL_DATA_INDEX = 3;
2930

3031
constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0;
3132
constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1;
3233
constexpr uint32_t ATTR_NUM_EXPERTS_INDEX = 2;
3334
constexpr uint32_t ATTR_NUM_TOPK_INDEX = 3;
35+
constexpr uint32_t ATTR_LOCAL_RANKSIZE_INDEX = 4;
3436
const int64_t MAX_COMM_WORLD_SIZE = 384;
3537
const int64_t MAX_MOE_EXPERTS_NUM = 512;
38+
const int64_t MAX_A2_LOCAL_RANKSIZE = 8;
3639
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
3740
constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024;
3841
constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024;
3942

43+
constexpr static int TILING_KEY_INT = 23;
44+
constexpr static int TILING_KEY_A2_TYPE = 100;
45+
4046
constexpr uint32_t TWO_DIMS = 2;
4147
constexpr uint32_t K_MAX = 16;
4248
} // namespace
@@ -48,9 +54,24 @@ static void PrintTilingDataInfo(const char *nodeName, DispatchLayoutTilingData &
4854
OP_LOGD(nodeName, "numRanks is %u.", tilingData.dispatchLayoutInfo.numRanks);
4955
OP_LOGD(nodeName, "numExperts is %u.", tilingData.dispatchLayoutInfo.numExperts);
5056
OP_LOGD(nodeName, "numTopk is %u.", tilingData.dispatchLayoutInfo.numTopk);
57+
OP_LOGD(nodeName, "localRankSize is %u.", tilingData.dispatchLayoutInfo.localRankSize);
5158
OP_LOGD(nodeName, "totalUbSize is %lu.", tilingData.dispatchLayoutInfo.totalUbSize);
5259
}
5360

61+
/**
 * Reports whether the platform behind the tiling context is an Ascend910B ("A2") SoC.
 *
 * Reads the "Short_SoC_version" entry of the "version" platform resource and
 * compares it with "Ascend910B".
 *
 * @param context Tiling context used to obtain the platform information.
 * @return true only when the short SoC version equals "Ascend910B"; false otherwise,
 *         including when the context or its platform information is null.
 */
static bool CheckIfA2Machine(gert::TilingContext *context)
{
    if (context == nullptr) {
        return false;
    }

    fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo();
    if (platformInfoPtr == nullptr) {
        // Without platform information the SoC cannot be identified; do not take the A2 path
        // (previously this pointer was dereferenced unconditionally).
        return false;
    }

    std::string socVersion;
    // The lookup status is deliberately ignored, as before. A failed lookup is expected to
    // leave socVersion different from "Ascend910B" -- TODO confirm against the platform API.
    (void)platformInfoPtr->GetPlatformResWithLock("version", "Short_SoC_version", socVersion);

    return socVersion == "Ascend910B";
}
74+
5475
static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
5576
DispatchLayoutTilingData &tilingData)
5677
{
@@ -61,11 +82,14 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
6182
auto numRanksPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_RANKS_INDEX));
6283
auto numExpertsPtr = attrs->GetAttrPointer<int64_t>(ATTR_NUM_EXPERTS_INDEX);
6384
auto numTopkPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_TOPK_INDEX));
85+
auto localRankSizePtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_LOCAL_RANKSIZE_INDEX));
6486

6587
OP_TILING_CHECK(numTokensPtr == nullptr, OP_LOGE(nodeName, "numTokensPtr is null."), return ge::GRAPH_FAILED);
6688
OP_TILING_CHECK(numRanksPtr == nullptr, OP_LOGE(nodeName, "numRanksPtr is null."), return ge::GRAPH_FAILED);
6789
OP_TILING_CHECK(numExpertsPtr == nullptr, OP_LOGE(nodeName, "numExpertsPtr is null."), return ge::GRAPH_FAILED);
6890
OP_TILING_CHECK(numTopkPtr == nullptr, OP_LOGE(nodeName, "numTopkPtr is null."), return ge::GRAPH_FAILED);
91+
OP_TILING_CHECK(localRankSizePtr == nullptr, OP_LOGE(nodeName, "localRankSizePtr is null."),
92+
return ge::GRAPH_FAILED);
6993

7094
OP_TILING_CHECK((*numRanksPtr <= 0) || (*numRanksPtr > MAX_COMM_WORLD_SIZE),
7195
OP_LOGE(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.",
@@ -80,10 +104,19 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
80104
OP_LOGE(nodeName, "numTopkPtr is invalid, only support (0, %u], but got numTopk=%ld.", K_MAX, *numTopkPtr),
81105
return ge::GRAPH_FAILED);
82106

107+
if (CheckIfA2Machine(context)) {
108+
OP_TILING_CHECK(
109+
(*localRankSizePtr <= 0) || (*localRankSizePtr > MAX_A2_LOCAL_RANKSIZE),
110+
OP_LOGE(nodeName, "localRankSizePtr is invalid, only support (0, %ld], but got localRankSize=%ld.",
111+
MAX_A2_LOCAL_RANKSIZE, *localRankSizePtr),
112+
return ge::GRAPH_FAILED);
113+
}
114+
83115
tilingData.dispatchLayoutInfo.numTokens = static_cast<uint32_t>(*numTokensPtr);
84116
tilingData.dispatchLayoutInfo.numRanks = static_cast<uint32_t>(*numRanksPtr);
85117
tilingData.dispatchLayoutInfo.numExperts = static_cast<uint32_t>(*numExpertsPtr);
86118
tilingData.dispatchLayoutInfo.numTopk = static_cast<uint32_t>(*numTopkPtr);
119+
tilingData.dispatchLayoutInfo.localRankSize = static_cast<uint32_t>(*localRankSizePtr);
87120

88121
return ge::GRAPH_SUCCESS;
89122
}
@@ -102,11 +135,13 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
102135
auto numTokensPerRank = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_RANK_INDEX);
103136
auto numTokensPerExpert = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX);
104137
auto isTokenInRank = context->GetOutputDesc(OUTPUT_IS_TOKEN_IN_RANK_INDEX);
138+
auto totalData = context->GetOutputDesc(OUTPUT_TOTAL_DATA_INDEX);
105139

106140
OP_TILING_CHECK(topkIdx == nullptr, OP_LOGE(nodeName, "topkIdx is null."), return false);
107141
OP_TILING_CHECK(numTokensPerRank == nullptr, OP_LOGE(nodeName, "numTokensPerRank is null."), return false);
108142
OP_TILING_CHECK(numTokensPerExpert == nullptr, OP_LOGE(nodeName, "numTokensPerExpert is null."), return false);
109143
OP_TILING_CHECK(isTokenInRank == nullptr, OP_LOGE(nodeName, "isTokenInRank is null."), return false);
144+
OP_TILING_CHECK(totalData == nullptr, OP_LOGE(nodeName, "totalData is null."), return false);
110145

111146
OP_TILING_CHECK((topkIdx->GetDataType() != ge::DT_INT64),
112147
OP_LOGE(nodeName, "topkIdx datatype is invalid, datatype should be int, but is %d.",
@@ -124,6 +159,10 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
124159
OP_LOGE(nodeName, "isTokenInRank datatype is invalid, datatype should be int, but is %d.",
125160
static_cast<ge::DataType>(isTokenInRank->GetDataType())),
126161
return false);
162+
OP_TILING_CHECK((totalData->GetDataType() != ge::DT_INT32),
163+
OP_LOGE(nodeName, "totalData datatype is invalid, datatype should be int, but is %d.",
164+
static_cast<ge::DataType>(totalData->GetDataType())),
165+
return false);
127166

128167
return true;
129168
}
@@ -169,11 +208,11 @@ static ge::graphStatus DispatchLayoutTilingFuncImpl(gert::TilingContext *context
169208
OP_TILING_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
170209
OP_LOGE(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
171210

172-
fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo();
173-
fe::PlatFormInfos &platformInfo = *platformInfoPtr;
174-
175-
std::string socVersion;
176-
(void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion);
211+
int tilingKey = TILING_KEY_INT;
212+
if (CheckIfA2Machine(context)) {
213+
tilingKey = tilingKey + TILING_KEY_A2_TYPE;
214+
}
215+
context->SetTilingKey(tilingKey);
177216

178217
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
179218
uint32_t blockDim;

0 commit comments

Comments
 (0)