Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions csrc/deepep/ops/op_host/notify_dispatch_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
// Verify the size of the win area
NotifyDispatchTilingData *tilingData = context->GetTilingData<NotifyDispatchTilingData>();
uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize();
uint64_t actualSize = dataSize * tilingData->notifyDispatchInfo.sendCount;
uint64_t actualSize = dataSize * tilingData->notifyDispatchInfo.sendCount + 2 * 1024 * 1024; // 2MB flag位
if (actualSize > maxWindowSize) {
OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %lu", actualSize);
OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %luMB.", actualSize / MB_SIZE);
return false;
}
return true;
Expand Down
4 changes: 0 additions & 4 deletions csrc/deepep/ops/op_kernel/notify_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ extern "C" __global__ __aicore__ void notify_dispatch(GM_ADDR sendData, GM_ADDR
REGISTER_TILING_DEFAULT(NotifyDispatchTilingData);
GET_TILING_DATA_WITH_STRUCT(NotifyDispatchTilingData, tilingData, tiling);

// hcomm will set magic later in init
uint32_t magic = 1;
GM_ADDR commArgs = nullptr;

int localRank = tilingData.notifyDispatchInfo.localRankId;
int localRankSize = tilingData.notifyDispatchInfo.localRankSize;
int rank = tilingData.notifyDispatchInfo.rankId;
Expand Down
172 changes: 88 additions & 84 deletions csrc/deepep/ops/op_kernel/notify_dispatch.h

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions csrc/deepep/ops2/op_host/notify_dispatch_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
// Verify the size of the win area
NotifyDispatchTilingData *tilingData = context->GetTilingData<NotifyDispatchTilingData>();
uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize();
uint64_t actualSize = dataSize * tilingData->notifyDispatchInfo.sendCount;
uint64_t actualSize = dataSize * tilingData->notifyDispatchInfo.sendCount + 2 * 1024 * 1024; // 2MB flag位
if (actualSize > maxWindowSize) {
OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %lu", actualSize);
OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %luMB.", actualSize / MB_SIZE);
return false;
}
return true;
Expand Down
29 changes: 19 additions & 10 deletions csrc/deepep/ops2/op_host/notify_dispatch_tiling_a2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ constexpr uint32_t ATTR_RANK_ID_INDEX = 6;
constexpr uint32_t ATTR_LOCAL_RANK_SIZE_INDEX = 7;
constexpr uint32_t ATTR_LOCAL_RANK_ID_INDEX = 8;

const size_t MAX_GROUP_NAME_LENGTH = 128UL;
const int64_t MAX_COMM_WORLD_SIZE = 384;
constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL;
constexpr int64_t MAX_COMM_WORLD_SIZE = 384;
constexpr int64_t MAX_A2_WORLD_SIZE = 64;
constexpr int64_t MAX_COMM_LOCAL_SIZE = 16;
constexpr int64_t MAX_A2_LOCAL_SIZE = 8;

constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024;
Expand All @@ -93,8 +96,6 @@ constexpr static int TILING_KEY_BFLOAT16 = 21;
constexpr static int TILING_KEY_FLOAT = 22;
constexpr static int TILING_KEY_INT = 23;
constexpr static int TILING_KEY_A2_TYPE = 100;

constexpr static int ALL_TO_ALL_CORE_NUM = 32;
} // namespace

namespace optiling {
Expand Down Expand Up @@ -141,14 +142,23 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
return ge::GRAPH_FAILED);
OP_TILING_CHECK(localRankIdPtr == nullptr, OP_LOGE(nodeName, "localRankIdPtr is null."), return ge::GRAPH_FAILED);

OP_TILING_CHECK((*rankSizePtr <= 0) || (*rankSizePtr > MAX_COMM_WORLD_SIZE),
OP_TILING_CHECK((*rankSizePtr <= 0) || (*rankSizePtr > MAX_A2_WORLD_SIZE),
OP_LOGE(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.",
MAX_COMM_WORLD_SIZE, *rankSizePtr),
MAX_A2_WORLD_SIZE, *rankSizePtr),
return ge::GRAPH_FAILED);
OP_TILING_CHECK(
(*rankIdPtr < 0) || (*rankIdPtr >= *rankSizePtr),
OP_LOGE(nodeName, "rankId is invalid, only support [0, %ld), but got rankId=%ld.", *rankSizePtr, *rankIdPtr),
return ge::GRAPH_FAILED);
OP_TILING_CHECK((*localRankSizePtr <= 0) || (*localRankSizePtr > MAX_A2_LOCAL_SIZE),
OP_LOGE(nodeName, "localRankSize is invalid, A2 only support (0, %ld], but got localRankSize=%ld.",
MAX_A2_LOCAL_SIZE, *localRankSizePtr),
return ge::GRAPH_FAILED);
OP_TILING_CHECK((*localRankIdPtr < 0) || (*localRankIdPtr >= *localRankSizePtr),
OP_LOGE(nodeName, "localRankId is invalid, only support [0, %ld), but got localRankId=%ld.",
*localRankSizePtr, *localRankIdPtr),
return ge::GRAPH_FAILED);

OP_TILING_CHECK((*sendCountPtr <= 0),
OP_LOGE(nodeName, "sendCount is invalid, only support > 0, but got sendCount=%ld.", *sendCountPtr),
return ge::GRAPH_FAILED);
Expand Down Expand Up @@ -187,8 +197,7 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
{
size_t *workSpaces = context->GetWorkspaceSizes(1);
OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE +
KERNEL_A2_ARG_SIZE; // TODO: 多预留空间,dispatch和combine同步要改?
workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE + KERNEL_A2_ARG_SIZE;
return ge::GRAPH_SUCCESS;
}

Expand Down Expand Up @@ -353,9 +362,9 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
// Verify the size of the win area
NotifyDispatchA2TilingData *tilingData = context->GetTilingData<NotifyDispatchA2TilingData>();
uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize();
uint64_t actualSize = dataSize * tilingData->notifyDispatchInfoA2.sendCount;
uint64_t actualSize = 2 * dataSize * tilingData->notifyDispatchInfoA2.sendCount + 2 * 1024 * 1024; // 2MB flag位
if (actualSize > maxWindowSize) {
OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %lu", actualSize);
OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %luMB", actualSize / MB_SIZE);
return false;
}
return true;
Expand Down
4 changes: 0 additions & 4 deletions csrc/deepep/ops2/op_kernel/notify_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ extern "C" __global__ __aicore__ void notify_dispatch(GM_ADDR sendData, GM_ADDR
REGISTER_TILING_DEFAULT(NotifyDispatchTilingData);
GET_TILING_DATA_WITH_STRUCT(NotifyDispatchTilingData, tilingData, tilingGM);

// hcomm will set magic later in init
uint32_t magic = 1;
GM_ADDR commArgs = nullptr;

int localRank = tilingData.notifyDispatchInfo.localRankId;
int localRankSize = tilingData.notifyDispatchInfo.localRankSize;
int rank = tilingData.notifyDispatchInfo.rankId;
Expand Down
Loading