From c36db0470da868efa5d8794eb81176b02b095adf Mon Sep 17 00:00:00 2001
From: GuoRen868 <1269192170@qq.com>
Date: Sat, 8 Nov 2025 18:18:51 +0800
Subject: [PATCH 1/7] fused

---
 .../fused_deep_moe/op_host/fused_deep_moe.cpp | 82 +
 .../op_host/fused_deep_moe_infer.cpp | 93 +
 .../op_host/fused_deep_moe_tiling.cpp | 331 +++
 .../op_kernel/fused_deep_moe.cpp | 33 +
 .../fused_deep_moe/op_kernel/fused_deep_moe.h | 447 ++++
 .../op_kernel/fused_deep_moe_base.h | 17 +
 .../op_kernel/fused_deep_moe_tiling.h | 73 +
 csrc/utils/op_host/error_log.h | 43 +
 .../op_kernel/a3/cam_moe_distribute_combine.h | 812 +++++++
 .../a3/cam_moe_distribute_dispatch.h | 1090 +++++++++
 .../op_kernel/operator/catlass/act/act.hpp | 37 +
 .../operator/catlass/act/arch/arch.hpp | 54 +
 .../catlass/act/arch/cross_core_sync.hpp | 115 +
 .../catlass/act/arch/local_tensor_buffer.hpp | 231 ++
 .../operator/catlass/act/arch/resource.hpp | 44 +
 .../op_kernel/operator/catlass/act/coord.hpp | 311 +++
 .../operator/catlass/act/detail/alignment.hpp | 57 +
 .../operator/catlass/act/detail/callback.hpp | 63 +
 .../catlass/act/detail/dependent_false.hpp | 22 +
 .../operator/catlass/act/detail/macros.hpp | 20 +
 .../catlass/act/detail/tag_to_layout.hpp | 80 +
 .../act/epilogue/block/block_epilogue.hpp | 29 +
 .../block_epilogue_per_token_dequant.hpp | 763 +++++++
 .../catlass/act/epilogue/dispatch_policy.hpp | 76 +
 .../act/epilogue/tile/copy_gm_to_ub.hpp | 156 ++
 .../act/epilogue/tile/copy_ub_to_gm.hpp | 115 +
 .../tile/tile_broadcast_inplace_by_column.hpp | 64 +
 .../tile/tile_broadcast_inplace_by_row.hpp | 57 +
 .../act/epilogue/tile/tile_broadcast_mul.hpp | 122 +
 .../epilogue/tile/tile_broadcast_one_blk.hpp | 51 +
 .../catlass/act/epilogue/tile/tile_cast.hpp | 45 +
 .../catlass/act/epilogue/tile/tile_copy.hpp | 104 +
 .../act/epilogue/tile/tile_elemwise_add.hpp | 48 +
 .../act/epilogue/tile/tile_elemwise_mul.hpp | 47 +
 .../act/epilogue/tile/tile_elemwise_muls.hpp | 38 +
 .../act/epilogue/tile/tile_swizzle.hpp | 92 +
 .../catlass/act/gemm/block/block_mmad.hpp | 57 +
 ...block_mmad_preload_async_with_callback.hpp | 410 ++++
 .../catlass/act/gemm/block/block_swizzle.hpp | 243 ++
 .../catlass/act/gemm/dispatch_policy.hpp | 88 +
 .../operator/catlass/act/gemm/gemm_type.hpp | 29 +
 .../operator/catlass/act/gemm/helper.hpp | 280 +++
 ...per_token_dequant_multistage_workspace.hpp | 358 +++
 .../catlass/act/gemm/tile/copy_gm_to_l1.hpp | 798 +++++++
 .../catlass/act/gemm/tile/copy_gm_to_ub.hpp | 53 +
 .../catlass/act/gemm/tile/copy_l0c_to_gm.hpp | 219 ++
 .../catlass/act/gemm/tile/copy_l1_to_l0a.hpp | 392 ++++
 .../catlass/act/gemm/tile/copy_l1_to_l0b.hpp | 537 +++++
 .../catlass/act/gemm/tile/copy_ub_to_gm.hpp | 80 +
 .../catlass/act/gemm/tile/tile_copy.hpp | 183 ++
 .../catlass/act/gemm/tile/tile_mmad.hpp | 110 +
 .../operator/catlass/act/gemm_coord.hpp | 159 ++
 .../operator/catlass/act/gemv_coord.hpp | 107 +
 .../operator/catlass/act/layout/layout.hpp | 20 +
 .../operator/catlass/act/layout/matrix.hpp | 1184 ++++++++++
 .../operator/catlass/act/layout/vector.hpp | 133 ++
 .../operator/catlass/act/matrix_coord.hpp | 115 +
 .../operator/catlass/tla/int_tuple.hpp | 173 ++
 .../op_kernel/operator/catlass/tla/layout.hpp | 371 +++
 .../catlass/tla/numeric/integer_sequence.hpp | 70 +
 .../catlass/tla/numeric/integral_constant.hpp | 176 ++
 .../operator/catlass/tla/numeric/math.hpp | 36 +
 .../op_kernel/operator/catlass/tla/tensor.hpp | 102 +
 .../op_kernel/operator/catlass/tla/tuple.hpp | 123 +
 .../operator/catlass/tla/type_traits.hpp | 45 +
 .../operator/epilogue/block/block_epilogue.h | 13 +
.../block_epilogue_per_token_dequant_swiglu.h | 326 +++ .../operator/epilogue/dispatch_policy.h | 22 + .../epilogue/tile/tile_stride_binary.h | 107 + .../operator/epilogue/tile/tile_stride_muls.h | 59 + .../operator/gemm/block/block_mmad.h | 13 + ...d_preload_async_with_callback_resident_a.h | 420 ++++ .../op_kernel/operator/gemm/dispatch_policy.h | 28 + ...equant_swiglu_quant_multistage_workspace.h | 2023 +++++++++++++++++ 74 files changed, 15524 insertions(+) create mode 100644 csrc/fused_deep_moe/op_host/fused_deep_moe.cpp create mode 100644 csrc/fused_deep_moe/op_host/fused_deep_moe_infer.cpp create mode 100644 csrc/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp create mode 100644 csrc/fused_deep_moe/op_kernel/fused_deep_moe.cpp create mode 100644 csrc/fused_deep_moe/op_kernel/fused_deep_moe.h create mode 100644 csrc/fused_deep_moe/op_kernel/fused_deep_moe_base.h create mode 100644 csrc/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h create mode 100644 csrc/utils/op_host/error_log.h create mode 100644 csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h create mode 100644 csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h create mode 100644 csrc/utils/op_kernel/operator/catlass/act/act.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/arch/arch.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/arch/cross_core_sync.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/arch/local_tensor_buffer.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/arch/resource.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/coord.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/detail/alignment.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/detail/callback.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/detail/dependent_false.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/detail/macros.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/detail/tag_to_layout.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue_per_token_dequant.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/dispatch_policy.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_gm_to_ub.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_ub_to_gm.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_cast.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_copy.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_add.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_mul.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_muls.hpp create mode 100644 
csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad_preload_async_with_callback.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/dispatch_policy.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/helper.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_l1.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_ub.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l0c_to_gm.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0a.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0b.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_ub_to_gm.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_copy.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_mmad.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemm_coord.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/gemv_coord.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/layout/layout.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/layout/matrix.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/layout/vector.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/act/matrix_coord.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/tla/int_tuple.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/tla/layout.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/tla/numeric/math.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/tla/tensor.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/tla/tuple.hpp create mode 100644 csrc/utils/op_kernel/operator/catlass/tla/type_traits.hpp create mode 100644 csrc/utils/op_kernel/operator/epilogue/block/block_epilogue.h create mode 100644 csrc/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h create mode 100644 csrc/utils/op_kernel/operator/epilogue/dispatch_policy.h create mode 100644 csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h create mode 100644 csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h create mode 100644 csrc/utils/op_kernel/operator/gemm/block/block_mmad.h create mode 100644 csrc/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h create mode 100644 csrc/utils/op_kernel/operator/gemm/dispatch_policy.h create mode 100644 csrc/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h diff --git a/csrc/fused_deep_moe/op_host/fused_deep_moe.cpp b/csrc/fused_deep_moe/op_host/fused_deep_moe.cpp new file mode 100644 index 00000000000..50cd95f9ce0 --- /dev/null +++ 
b/csrc/fused_deep_moe/op_host/fused_deep_moe.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ * Description: FusedDeepMoe operator definition file
+ * Author: WANG Qiankun
+ * Create: 2025-07-19
+ * Note:
+ * History: 2025-07-19 create FusedDeepMoe operator definition file
+ */
+#include "register/op_def_registry.h"
+
+namespace ops {
+class FusedDeepMoe : public OpDef
+{
+public:
+    explicit FusedDeepMoe(const char *name) : OpDef(name)
+    {
+        this->Input("x")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_BF16, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("expert_ids")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("gmm1_permuted_weight")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT8, ge::DT_INT8})
+            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
+            .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
+        this->Input("gmm1_permuted_weight_scale")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("gmm2_weight")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT8, ge::DT_INT8})
+            .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
+            .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
+        this->Input("gmm2_weight_scale")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("expert_smooth_scales")
+            .ParamType(OPTIONAL)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("expert_scales")
+            .ParamType(OPTIONAL)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("output")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_BF16, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("ep_recv_count")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Attr("group_ep").String();
+        this->Attr("ep_rank_size").Int();
+        this->Attr("ep_rank_id").Int();
+        this->Attr("moe_expert_num").Int();
+        this->Attr("share_expert_num").Int();
+        this->Attr("share_expert_rank_num").Int();
+        this->Attr("quant_mode").Int();
+        this->Attr("global_bs").Int();
+
+        this->MC2().HcclGroup({"group_ep"});
+        this->AICore().AddConfig("ascend910_93");
+    }
+};
+
+OP_ADD(FusedDeepMoe);
+} // namespace ops
diff --git a/csrc/fused_deep_moe/op_host/fused_deep_moe_infer.cpp b/csrc/fused_deep_moe/op_host/fused_deep_moe_infer.cpp
new file mode 100644
index 00000000000..1391b054393
--- /dev/null
+++ b/csrc/fused_deep_moe/op_host/fused_deep_moe_infer.cpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ * Description: FusedDeepMoe tiling function implementation file + * Author: Guo Ren + * Create: 2025-07-22 + * Note: + * History: 2025-07-13 create FusedDeepMoe infer function file + */ + +#include +#include "error_log.h" +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" + +namespace ge { +constexpr uint32_t EXPAND_X_INDEX = 0; +constexpr uint32_t EXPERT_IDS_INDEX = 1; +constexpr uint32_t OUTPUT_X_INDEX = 0; +constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1; + +constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; +constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1; +constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2; +constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3; +constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4; +constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5; +constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6; +constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7; + +static ge::graphStatus InferShape(gert::InferShapeContext *context) +{ + const char *nodeName = context->GetNodeName(); + // infer output shape + const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX); + const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX); + gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX); + gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX); + if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr || + recvCountOutShape == nullptr) { + return GRAPH_FAILED; + } + if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) { + return GRAPH_FAILED; + } + + int bs = expertIdsShape->GetDim(0); + int h = expandXShape->GetDim(1); + + expandXOutShape->SetDimNum(expandXShape->GetDimNum()); + expandXOutShape->SetDim(0, bs); + expandXOutShape->SetDim(1, h); + + // infer recvCount shape + auto attrs = context->GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + + auto epRankSizePtr = attrs->GetAttrPointer(ATTR_EP_RANK_SIZE_INDEX); + auto epRankIdPtr = attrs->GetAttrPointer(ATTR_EP_RANK_ID_INDEX); + auto moeExpertNumPtr = attrs->GetAttrPointer(ATTR_MOE_EXPERT_NUM_INDEX); + auto sharedExpertRankNumPtr = attrs->GetAttrPointer(ATTR_SHARE_EXPERT_RANK_NUM_INDEX); + + OP_TILING_CHECK(epRankIdPtr == nullptr, OP_LOGE(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(moeExpertNumPtr == nullptr, OP_LOGE(nodeName, "moeExpertNumPtr is nullptr."), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(epRankSizePtr == nullptr, OP_LOGE(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(sharedExpertRankNumPtr == nullptr, OP_LOGE(nodeName, "sharedExpertRankNumPtr is nullptr."), + return ge::GRAPH_FAILED); + uint32_t epRankSize = static_cast(*epRankSizePtr); + uint32_t moeExpertNum = static_cast(*moeExpertNumPtr); + uint32_t epRankId = static_cast(*epRankIdPtr); + uint32_t sharedExpertRankNum = static_cast(*sharedExpertRankNumPtr); + + recvCountOutShape->SetDimNum(1); + bool isShareExpert = (epRankId < sharedExpertRankNum); + if (isShareExpert) { + recvCountOutShape->SetDim(0, epRankSize); + } else { + recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum))); + } + + return GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) +{ + const auto expandXDataType = context->GetInputDataType(EXPAND_X_INDEX); + context->SetOutputDataType(OUTPUT_X_INDEX, expandXDataType); + 
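+    // Added note: ep_recv_count is always int32; its single dimension was fixed in InferShape above as
+    // epRankSize entries on a shared-expert rank, otherwise
+    // epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum)), i.e. epRankSize entries per local MoE expert.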
context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP(FusedDeepMoe).InferShape(InferShape).InferDataType(InferDataType); +} // namespace ge diff --git a/csrc/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp b/csrc/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp new file mode 100644 index 00000000000..c52ff4976a5 --- /dev/null +++ b/csrc/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp @@ -0,0 +1,331 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Description: FusedDeepMoe tiling function implementation file + * Author: WANG Qiankun + * Create: 2025-07-19 + * Note: + * History: 2025-07-19 create FusedDeepMoe tiling function implementation file + */ +#include +#include +#include + +#include "error_log.h" +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" +#include "../op_kernel/fused_deep_moe_tiling.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/hccl/hccl_tiling.h" + +#define GM_ALIGN_SIZE 512 +#define ENABLE_TILING_CHECK + +using namespace ge; +namespace { +constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8; +constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; +constexpr uint32_t TOKEN_DTYPE_BYTE_SIZE = 2; +constexpr uint32_t USE_CORE_NUM = 24; +constexpr uint32_t L1_TILE_BYTE_SIZE = 32 * 1024; +constexpr uint32_t CUBE_WORKSPACE_STAGE = 4; +constexpr uint32_t RESERVED_WORKSPACE_SIZE = 256 * 1024; + +constexpr uint32_t INPUT_X_INDEX = 0; +constexpr uint32_t INPUT_EXPERT_IDS_INDEX = 1; +constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2; +constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3; +constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4; +constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5; +constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6; +constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7; + +constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; +constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1; +constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2; +constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3; +constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4; +constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5; +constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6; +constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7; + +constexpr uint32_t MIN_BATCH_SIZE = 1; +constexpr uint32_t MAX_BATCH_SIZE = 256; +constexpr uint32_t MAX_MOE_EXERT_NUM = 512; +constexpr uint32_t RECV_AIV_NUM = 24; +constexpr uint32_t SUPPORT_TOP_K = 12; +constexpr uint32_t TWO_DIMS = 2; +constexpr uint32_t MIN_TOKEN_LENGTH = 512; +constexpr uint32_t MAX_TOKEN_LENGTH = 7168; +constexpr uint32_t MIN_GMM1_HIDDEN = 1024; +constexpr uint32_t MAX_GMM1_HIDDEN = 6144; +} // namespace + +namespace optiling { +static size_t CeilUp(size_t x, size_t y) +{ + return (x + y - 1) / y * y; +} + +static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName, + FusedDeepMoeTilingData &tilingData) +{ + uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; + uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + + uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 
1 : moeExpertNumPerRank; + const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX); + OP_TILING_CHECK(gmm1WeightStorageShape == nullptr, OP_LOGE(nodeName, "gmm1 weight shape is null."), + return ge::GRAPH_FAILED); + const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0); + OP_TILING_CHECK(gmm1WeightDim0 != localExpertNum, + OP_LOGE(nodeName, "gmm1Weight Dim0 must be expert number in current rank."), + return ge::GRAPH_FAILED); + + const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX); + OP_TILING_CHECK(gmm1WeightScaleStorageShape == nullptr, OP_LOGE(nodeName, "gmm1 weight scale shape is null."), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OP_LOGE(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.", + gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0); + OP_TILING_CHECK(gmm1WeightScaleDim0 != localExpertNum, + OP_LOGE(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."), + return ge::GRAPH_FAILED); + const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1); + OP_TILING_CHECK(gmm1WeightScaleDim1 != gmm1WeightDim2, + OP_LOGE(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2), + return ge::GRAPH_FAILED); + + const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX); + OP_TILING_CHECK(gmm2WeightStorageShape == nullptr, OP_LOGE(nodeName, "gmm2 weight shape is null."), + return ge::GRAPH_FAILED); + const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0); + OP_TILING_CHECK(gmm2WeightDim0 != localExpertNum, + OP_LOGE(nodeName, "gmm2Weight Dim0 must be expert number in current rank."), + return ge::GRAPH_FAILED); + + const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX); + OP_TILING_CHECK(gmm2WeightScaleStorageShape == nullptr, OP_LOGE(nodeName, "gmm2 weight scale shape is null."), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OP_LOGE(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.", + gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0); + OP_TILING_CHECK(gmm2WeightScaleDim0 != localExpertNum, + OP_LOGE(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."), + return ge::GRAPH_FAILED); + const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1); + OP_TILING_CHECK(gmm2WeightScaleDim1 != h, OP_LOGE(nodeName, "gmm2WeightScale Dim1 must be %u.", h), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckData(const char *nodeName, FusedDeepMoeTilingData &tilingData) +{ + uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs; + OP_TILING_CHECK(batchSize < MIN_BATCH_SIZE, OP_LOGE(nodeName, "batchSize(bs) must >= %d.", MIN_BATCH_SIZE), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(batchSize > MAX_BATCH_SIZE, OP_LOGE(nodeName, "batchSize(bs) must <= %d.", MAX_BATCH_SIZE), + return ge::GRAPH_FAILED); + uint32_t 
tokenLength = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; + OP_TILING_CHECK( + tokenLength < MIN_TOKEN_LENGTH || tokenLength > MAX_TOKEN_LENGTH, + OP_LOGE(nodeName, "tokenLength(h) is invalid. Only support [%u, %u].", MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH), + return ge::GRAPH_FAILED); + uint32_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + OP_TILING_CHECK( + gmm1HLen < MIN_GMM1_HIDDEN || gmm1HLen > MAX_GMM1_HIDDEN, + OP_LOGE(nodeName, "gmm1 hidden size is invalid. Only support [%u, %u].", MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN), + return ge::GRAPH_FAILED); + uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k; + OP_TILING_CHECK(topK > SUPPORT_TOP_K, OP_LOGE(nodeName, "topK(k) must <= %d.", SUPPORT_TOP_K), + return ge::GRAPH_FAILED); + uint32_t globalBatchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs; + uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + if (globalBatchSize == 0) { + globalBatchSize = epRankSize * batchSize; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = globalBatchSize; + } else { + OP_TILING_CHECK(globalBatchSize < 0, OP_LOGE(nodeName, "globalBatchSize must >= 0."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(globalBatchSize % epRankSize > 0, + OP_LOGE(nodeName, "globalBatchSize must be divisible by epRankSize."), return ge::GRAPH_FAILED); + } + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, + FusedDeepMoeTilingData &tilingData, std::string &groupEp) +{ + auto attrs = context->GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + + auto groupEpPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_EP_INDEX)); + auto epRankSizePtr = attrs->GetAttrPointer(ATTR_EP_RANK_SIZE_INDEX); + auto epRankIdPtr = attrs->GetAttrPointer(ATTR_EP_RANK_ID_INDEX); + auto moeExpertNumPtr = attrs->GetAttrPointer(ATTR_MOE_EXPERT_NUM_INDEX); + auto sharedExpertNumPtr = attrs->GetAttrPointer(ATTR_SHARE_EXPERT_NUM_INDEX); + auto sharedExpertRankNumPtr = attrs->GetAttrPointer(ATTR_SHARE_EXPERT_RANK_NUM_INDEX); + auto quantModePtr = attrs->GetAttrPointer(ATTR_QUANT_MODE_INDEX); + auto globalBsPtr = attrs->GetAttrPointer(ATTR_GLOBAL_BS_INDEX); + + uint32_t epRankSize = static_cast(*epRankSizePtr); + uint32_t epRankId = static_cast(*epRankIdPtr); + uint32_t moeExpertNum = static_cast(*moeExpertNumPtr); + uint32_t sharedExpertNum = static_cast(*sharedExpertNumPtr); + uint32_t sharedExpertRankNum = static_cast(*sharedExpertRankNumPtr); + uint32_t moeExpertNumPerRank = moeExpertNum / (epRankSize - sharedExpertRankNum); + +#ifdef ENABLE_TILING_CHECK + OP_TILING_CHECK(epRankId < 0, OP_LOGE(nodeName, "epRankId must >= 0."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(epRankId >= epRankSize, OP_LOGE(nodeName, "epRankId must < epRankSize."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(moeExpertNum > MAX_MOE_EXERT_NUM, OP_LOGE(nodeName, "moeExpertNum must <= %d.", MAX_MOE_EXERT_NUM), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(moeExpertNum <= 0, OP_LOGE(nodeName, "moeExpertNum must > 0."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(sharedExpertNum != 1, OP_LOGE(nodeName, "sharedExpertNum must be 1."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(moeExpertNum % (epRankSize - sharedExpertRankNum) != 0, + OP_LOGE(nodeName, "moeExpertNum must be divisible by (epRankSize - sharedExpertRankNum)."), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(moeExpertNumPerRank > RECV_AIV_NUM, + 
OP_LOGE(nodeName, "moeExpertNumPerRank must <= %d.", RECV_AIV_NUM), return ge::GRAPH_FAILED); +#endif + + groupEp = std::string(groupEpPtr); + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize = epRankSize; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId = epRankId; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum = moeExpertNum; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum = sharedExpertNum; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum = sharedExpertRankNum; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.quantMode = static_cast(*quantModePtr); + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = static_cast(*globalBsPtr); + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank = moeExpertNumPerRank; + return ge::GRAPH_SUCCESS; +} + +static void SetHcommCfg(const gert::TilingContext *context, FusedDeepMoeTilingData *tiling, const std::string groupEp) +{ + const char *nodeName = context->GetNodeName(); + OP_LOGD(nodeName, "FusedDeepMoe groupEp = %s", groupEp.c_str()); + uint32_t opType = OP_TYPE_ALL_TO_ALL; + std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise"; + std::string algConfigAllGatherStr = "AllGather=level0:ring"; + + AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType, algConfigAllToAllStr); + mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling); + mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling); +} + +static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName, + FusedDeepMoeTilingData &tilingData) +{ + size_t *workSpaces = context->GetWorkspaceSizes(1); + OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); + size_t maxTokenNum; + uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs; + uint32_t globalBs = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs; + uint32_t maxBatchSize = globalBs / epRankSize; + uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k; + uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; + uint64_t gmm2HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen / 2; + if (epRankId < sharedExpertRankNum) { + maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum; + } else { + maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank); + } + + size_t x2TokenSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE); + size_t x2ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE); + size_t CVSwapBufferSize = + CeilUp(USE_CORE_NUM * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE); + size_t swigluOutSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE); + size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE); + size_t expandIdxSize = CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE); + size_t epSendCountSize = CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t), GM_ALIGN_SIZE); + size_t x1TokenSize = CeilUp(maxTokenNum * h * sizeof(int8_t), GM_ALIGN_SIZE); + size_t x1ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE); + size_t 
gmm2DepOutSize = CeilUp(maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE); + size_t resveredSize = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE); + size_t usrSize = x2TokenSize + x2ScaleSize + CVSwapBufferSize + swigluOutSize + groupListSize + expandIdxSize + + epSendCountSize + x1TokenSize + x1ScaleSize + gmm2DepOutSize + resveredSize; + + workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize; + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus FusedDeepMoeTilingFuncImpl(gert::TilingContext *context) +{ + const char *nodeName = context->GetNodeName(); + FusedDeepMoeTilingData *tilingData = context->GetTilingData(); + OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); + std::string groupEp = ""; + + const gert::StorageShape *xStorageShape = context->GetInputShape(INPUT_X_INDEX); + OP_TILING_CHECK(xStorageShape == nullptr, OP_LOGE(nodeName, "x shape is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OP_LOGE(nodeName, "x shape dims must be 2, but current dim num is %lu.", + xStorageShape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + const int64_t batchSize = xStorageShape->GetStorageShape().GetDim(0); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs = batchSize; + const int64_t hiddenSize = xStorageShape->GetStorageShape().GetDim(1); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h = hiddenSize; + + const gert::StorageShape *expertIdsStorageShape = context->GetInputShape(INPUT_EXPERT_IDS_INDEX); + OP_TILING_CHECK(expertIdsStorageShape == nullptr, OP_LOGE(nodeName, "expertIds shape is null."), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(expertIdsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OP_LOGE(nodeName, "expertIds shape dims must be 2, but current dim num is %lu.", + expertIdsStorageShape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + const int64_t topK = expertIdsStorageShape->GetStorageShape().GetDim(1); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK; + OP_TILING_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED); + const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX); + OP_TILING_CHECK(gmm1WeightStorageShape == nullptr, OP_LOGE(nodeName, "gmm1Weight shape is null."), + return ge::GRAPH_FAILED); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS); +#ifdef ENABLE_TILING_CHECK + OP_TILING_CHECK(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OP_LOGE(nodeName, "CheckData failed."), + return ge::GRAPH_FAILED); +#endif + OP_TILING_CHECK(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); + SetHcommCfg(context, tilingData, groupEp); + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) { + context->SetTilingKey(0); + } else { + context->SetTilingKey(EXEC_FLAG_DEEP_FUSE); + } + context->SetBlockDim(USE_CORE_NUM); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus FusedDeepMoeTilingFunc(gert::TilingContext *context) +{ + ge::graphStatus ret = FusedDeepMoeTilingFuncImpl(context); + return ret; +} + +struct FusedDeepMoeCompileInfo {}; +ge::graphStatus TilingParseForFusedDeepMoe(gert::TilingParseContext *context) +{ + (void)context; + return 
ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(FusedDeepMoe) + .Tiling(FusedDeepMoeTilingFunc) + .TilingParse(TilingParseForFusedDeepMoe); +} // namespace optiling diff --git a/csrc/fused_deep_moe/op_kernel/fused_deep_moe.cpp b/csrc/fused_deep_moe/op_kernel/fused_deep_moe.cpp new file mode 100644 index 00000000000..8d25ddb6d44 --- /dev/null +++ b/csrc/fused_deep_moe/op_kernel/fused_deep_moe.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Description: FusedDeepMoe operator kernel function implementation file + * Author: WANG Qiankun + * Create: 2025-07-19 + * Note: + * History: 2025-07-19 create FusedDeepMoe operator kernel function implementation file + */ +#include "fused_deep_moe.h" +#include +#include "lib/matmul_intf.h" + +extern "C" __global__ __aicore__ void fused_deep_moe( + // input + GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, + // output + GM_ADDR output, GM_ADDR outputRecvCount, + // system + GM_ADDR workspace, GM_ADDR tiling) +{ + icache_preload(8); + // New output recvCount + REGISTER_TILING_DEFAULT(FusedDeepMoeTilingData); + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V + GET_TILING_DATA(tiling_data, tiling); + if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) { + FusedDeepMoe op; + op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, + expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data); + op.Process(); + } +} diff --git a/csrc/fused_deep_moe/op_kernel/fused_deep_moe.h b/csrc/fused_deep_moe/op_kernel/fused_deep_moe.h new file mode 100644 index 00000000000..2a6cbb68c1b --- /dev/null +++ b/csrc/fused_deep_moe/op_kernel/fused_deep_moe.h @@ -0,0 +1,447 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+ * Description: FusedDeepMoe operator kernel function header file, for a3 + * Author: WANG Qiankun + * Create: 2025-07-19 + * Note: + * History: 2025-07-19 create FusedDeepMoe operator kernel function header file, for a3 + */ +#ifndef FUSED_DEEP_MOE_H +#define FUSED_DEEP_MOE_H + +#include "lib/matmul_intf.h" +#include + +#include "../utils/op_kernel/operator/catlass/act/act.hpp" +#include "../utils/op_kernel/operator/catlass/act/arch/arch.hpp" +#include "../utils/op_kernel/operator/catlass/act/layout/layout.hpp" +#include "../utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp" +#include "../utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp" +#include "../utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp" +#include "../utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp" +#include "../utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp" +#include "../utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp" +#include "../utils/op_kernel/operator/epilogue/dispatch_policy.h" +#include "../utils/op_kernel/operator/gemm/dispatch_policy.h" +#include "../utils/op_kernel/operator/epilogue/block/block_epilogue.h" +#include "../utils/op_kernel/operator/gemm/block/block_mmad.h" +#include "../utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h" + +#include "../utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h" + +#include "fused_deep_moe_tiling.h" +#include "fused_deep_moe_base.h" + +#define ENABLE_GMM2_COMBINE + +using namespace Act; + +using MmadAtlasA2Custom = + Gemm::MmadAtlasA2PreloadAsyncWithCallback; + +using Gmm1L1TileShape = GemmShape; +using Gmm1L0TileShape = GemmShape; +using Gmm1EpilogueTileShape = MatrixShape; +using Gmm1BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle; + +using Gmm2L1TileShape = GemmShape; +using Gmm2L0TileShape = GemmShape; +using Gmm2EpilogueTileShape = MatrixShape; +using Gmm2BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle; +using Gmm2DispatchPolicy = + Gemm::MmadAtlasA2PreloadAsyncWithCallbackResidentA; + +template +ACT_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, + layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale, + layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, + layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, + GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, + GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx, + GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount, + uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum, + uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, + uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen) +{ + using ArchTag = Arch::AtlasA2; + using DispatchPolicy = DispatchPolicy_; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + + using XType = XType_; + using AType = Gemm::GemmType; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + + using BlockMmad = Gemm::Block::BlockMmad; + + constexpr uint32_t ubStages = 1; + using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantSwiglu; + using ScaleType = Gemm::GemmType; + using 
PerTokenScaleType = Gemm::GemmType; + using DType = Gemm::GemmType; + + using RowBroadcastMulType = Gemm::GemmType; + using BroadcastOneBlkType = Gemm::GemmType; + using OneBlkColumnBroadcastMulType = Gemm::GemmType; + + using EpilogueTileShape = EpilogueTileShape_; + using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul; + using TileBroadcastOneBlk = + Epilogue::Tile::TileBroadcastOneBlk; + using TileOneBlkColumnBroadcastMul = + Epilogue::Tile::TileOneBlkColumnBroadcastMul; + using TileCopy = Epilogue::Tile::TileCopy; + using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + using BlockEpilogue = Epilogue::Block::BlockEpilogue; + + using BlockScheduler = BlockScheduler_; + + // kernel level + using ElementGroupList = int64_t; + + using GemmKernel = typename std::conditional< + (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE), + Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace< + XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, + Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< + BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + typename GemmKernel::Params params{problemShape, + groupCount, + gmGroupList, + gmA, + layoutA, + gmB, + layoutB, + gmScale, + layoutScale, + gmPerTokenScale, + layoutPerTokenScale, + gmD, + layoutD, + gmDequantScale, + layoutDequantScale, + gmWorkspace, + gmX, + debugGm, + gmexpertIds, + gmExpandIdx, + gmEpSendCount, + gmResvered, + gmOutputRecvCount, + epRankSize, + epRankId, + moeExpertNum, + moeExpertNumPerRank, + sharedExpertNum, + sharedExpertRankNum, + quantMode, + globalBs, + bs, + topK, + tokenLen}; + // call a kernel + GemmKernel gemm; + gemm(params); + } else { + typename GemmKernel::Params params{problemShape, + groupCount, + gmGroupList, + gmA, + layoutA, + gmB, + layoutB, + gmScale, + layoutScale, + gmPerTokenScale, + layoutPerTokenScale, + gmD, + layoutD, + gmDequantScale, + layoutDequantScale, + gmWorkspace}; + // call a kernel + GemmKernel gemm; + gemm(params); + } +} + +template +ACT_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, + layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale, + layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, + layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, + GM_ADDR gmWorkspace, void *combiner) +{ + using ArchTag = Arch::AtlasA2; + using DispatchPolicy = DispatchPolicy_; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + + using AType = Gemm::GemmType; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + + using BlockMmad = Gemm::Block::BlockMmad; + + constexpr uint32_t ubStages = 1; + using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequant; + using ScaleType = Gemm::GemmType; + using PerTokenScaleType = Gemm::GemmType; + using DType = Gemm::GemmType; + + using RowBroadcastMulType = Gemm::GemmType; + using BroadcastOneBlkType = Gemm::GemmType; + using OneBlkColumnBroadcastMulType = Gemm::GemmType; + + using EpilogueTileShape = EpilogueTileShape_; + using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul; + using TileBroadcastOneBlk = + Epilogue::Tile::TileBroadcastOneBlk; + using TileOneBlkColumnBroadcastMul = + Epilogue::Tile::TileOneBlkColumnBroadcastMul; + using TileCopy = Epilogue::Tile::TileCopy; + using TileScheduler = 
Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + using BlockEpilogue = Epilogue::Block::BlockEpilogue; + + using BlockScheduler = BlockScheduler_; + + // kernel level + using ElementGroupList = int64_t; + using GemmKernel = Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + typename GemmKernel::Params params{ + problemShape, groupCount, gmGroupList, gmA, layoutA, gmB, layoutB, gmScale, + layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace, combiner}; + + // call a kernel + GemmKernel gemm; + gemm(params); +} + +template +class FusedDeepMoe +{ +public: + __aicore__ inline FusedDeepMoe(){}; + __aicore__ inline void Init( + // input + GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, + // output + GM_ADDR output, GM_ADDR outputRecvCount, + // system + GM_ADDR workspaceGM, AscendC::TPipe *pipe, const FusedDeepMoeTilingData *tilingData); + __aicore__ inline void Process(); + +private: + GM_ADDR gmX_; + GM_ADDR gmexpertIds_; + GM_ADDR gmPermuteWeight1_; + GM_ADDR gmPermuteScale1_; + GM_ADDR gmWeight2_; + GM_ADDR gmScale2_; + GM_ADDR gmOutput_; + GM_ADDR gmOutputRecvCount_; + GM_ADDR workspaceGM_; + GM_ADDR gmSmoothScales_; + GM_ADDR gmexpertScales_; + + uint32_t m_{0}; + uint32_t n_{0}; + uint32_t k_{0}; + uint32_t groupCount_{0}; + uint32_t n2_{0}; + uint32_t k2_{0}; + uint32_t globalRankId_{0}; + uint32_t winSizePerRank_{0}; + uint32_t blockDim_{0}; + uint32_t epRankSize_{0}; + uint32_t epRankId_{0}; + uint32_t moeExpertNum_{0}; + uint32_t moeExpertNumPerRank_{0}; + uint32_t sharedExpertNum_{0}; + uint32_t sharedExpertRankNum_{0}; + uint32_t quantMode_{0}; + uint32_t globalBs_{0}; + uint32_t bs_{0}; + uint32_t maxBs_{0}; + uint32_t topK_{0}; + + AscendC::TPipe *tpipe_{nullptr}; + __gm__ HcclOpResParam *winContext_{nullptr}; + const FusedDeepMoeTilingData *tilingData_; +}; + +template +__aicore__ inline void FusedDeepMoe::Init( + // input + GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, + // output + GM_ADDR output, GM_ADDR outputRecvCount, + // system + GM_ADDR workspaceGM, AscendC::TPipe *pipe, const FusedDeepMoeTilingData *tilingData) +{ + tpipe_ = pipe; + blockDim_ = AscendC::GetBlockNum(); + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + + gmSmoothScales_ = expert_smooth_scales; // 这里传入较大空间,开发时使用 + gmX_ = x; // dispatch的输入 + gmexpertIds_ = expert_ids; + gmPermuteWeight1_ = gmm1_permuted_weight; + gmPermuteScale1_ = gmm1_permuted_weight_scale; + gmWeight2_ = gmm2_weight; + gmScale2_ = gmm2_weight_scale; + gmOutput_ = output; + gmOutputRecvCount_ = outputRecvCount; + workspaceGM_ = workspaceGM; + gmexpertScales_ = expert_scales; + tilingData_ = tilingData; + epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + moeExpertNumPerRank_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + sharedExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum; + sharedExpertRankNum_ = 
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + quantMode_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.quantMode; + globalBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs; + bs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + topK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; + maxBs_ = globalBs_ / epRankSize_; + + bool isShareExpert = (epRankId_ < sharedExpertRankNum_); + if (isShareExpert) { + m_ = maxBs_ * epRankSize_ / sharedExpertRankNum_; + } else { + m_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_); + } + + n_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + k_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + n2_ = k_; + k2_ = n_ / 2; +} + +template +__aicore__ inline void FusedDeepMoe::Process() +{ +#ifdef ENABLE_GMM2_COMBINE + if (g_coreType == AscendC::AIV) { + ((FusedDeepMoeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = get_block_num(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + ((FusedDeepMoeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = get_block_num(); + } else { + ((FusedDeepMoeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = + get_block_num() * get_subblockdim(); + } + } +#endif + GemmCoord gmm1ProblemShape{m_, n_, k_}; + GemmCoord gmm2ProblemShape{m_, n2_, k2_}; + + layout::RowMajor layoutX1{m_, k_}; + layout::zN layoutWeight1 = layout::zN::template MakeLayout(k_, n_); + layout::VectorLayout layoutScale1{n_}; + layout::VectorLayout layoutPerTokenScale1{m_}; + layout::RowMajor layoutX2{m_, k2_}; + layout::nZ layoutWeight2 = layout::nZ::template MakeLayout(k2_, n2_); + layout::VectorLayout layoutScale2{n2_}; + layout::VectorLayout layoutPerTokenScale2{m_}; + layout::RowMajor layoutOutput{m_, n2_}; + + size_t workspaceOffset = 0; + constexpr int32_t resveredWorkSpaceSize = 256 * 1024; + GM_ADDR gmX2 = workspaceGM_; + workspaceOffset += RoundUp(static_cast(m_) * k2_ * sizeof(int8_t)); + GM_ADDR gmPerTokenScale2 = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * sizeof(float)); + GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset; + + GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(blockDim_) * (GMM1_L1M * GMM1_L1N) * + WORKSPACE_STAGES * sizeof(int32_t)); + GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * k2_ * sizeof(float)); + GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(groupCount_) * sizeof(int64_t)); + GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(bs_) * topK_ * sizeof(int32_t)); + GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(epRankSize_) * groupCount_ * sizeof(int32_t)); + GM_ADDR gmX1Token = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * k_ * sizeof(int8_t)); + GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * sizeof(float)); + GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * k_ * sizeof(ExpandXType)); + GM_ADDR gmResvered = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(resveredWorkSpaceSize); + + if constexpr (EXEC_FLAG == 0) { + if constexpr (g_coreType == 
AscendC::AIV) { + AscendC::TPipe tpipe; + MoeDistributeDispatchImpl::CamMoeDistributeDispatch + dispatcher; + dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1Token, gmX1Scale, gmExpandIdx, gmGroupList, + gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_); + dispatcher.Process(); + tpipe.Destroy(); + icache_preload(8); + } + + AscendC::PipeBarrier(); + Arch::CrossCoreFlag gmm1AivFinished{0}; + if constexpr (g_coreType == AscendC::AIV) { + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished); + } else { + Arch::CrossCoreWaitFlag(gmm1AivFinished); + } + } + GmmDeqSwigluQuant( + gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1, gmPermuteWeight1_, layoutWeight1, + gmPermuteScale1_, layoutScale1, gmX1Scale, layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2, + layoutPerTokenScale2, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered, + gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, + sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, k_); +#ifdef ENABLE_GMM2_COMBINE + AscendC::PipeBarrier(); + Arch::CrossCoreFlag gmm1AivFinished{0}; + if constexpr (g_coreType == AscendC::AIV) { + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished); + } else { + Arch::CrossCoreWaitFlag(gmm1AivFinished); + } + + MoeDistributeCombineImpl::CamMoeDistributeCombine combiner; + if (g_coreType == AscendC::AIV) { + combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_, + workspaceGM_, nullptr, tilingData_); + } + GmmDeq(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2, + gmScale2_, layoutScale2, gmPerTokenScale2, layoutPerTokenScale2, gmGmm2DepOut, + layoutOutput, gmWorkspace, &combiner); +#endif +} +#endif // FUSED_DEEP_MOE_H diff --git a/csrc/fused_deep_moe/op_kernel/fused_deep_moe_base.h b/csrc/fused_deep_moe/op_kernel/fused_deep_moe_base.h new file mode 100644 index 00000000000..d09d4894adb --- /dev/null +++ b/csrc/fused_deep_moe/op_kernel/fused_deep_moe_base.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Description: Definition of communication group related structures + * Author: WANG Qiankun + * Create: 2025-07-19 + * Note: + * History: 2025-07-19 Create a definition file for a distribution group related structure + */ +#ifndef FUSED_DEEP_MOE_BASE_H +#define FUSED_DEEP_MOE_BASE_H + +#include "moe_distribute_base.h" + +#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG +#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG + +#endif // FUSED_DEEP_MOE_BASE_H diff --git a/csrc/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h b/csrc/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h new file mode 100644 index 00000000000..a4899debceb --- /dev/null +++ b/csrc/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+ * Description: FusedDeepMoe tilingData definition file + * Author: WANG Qiankun + * Create: 2025-07-19 + * Note: + * History: 2025-07-19 create FusedDeepMoe tilingData definition file + */ + +#ifndef FUSED_DEEP_MOE_TILING_H +#define FUSED_DEEP_MOE_TILING_H + +#include "kernel_tiling/kernel_tiling.h" + +struct FusedDeepMoeInfo { + uint32_t epRankSize; // epRankSize + uint32_t epRankId; // epRankId + uint32_t moeExpertNum; // moe expert number + uint32_t moeExpertNumPerRank; // moe expert number per rank + uint32_t sharedExpertNum; // shared expert number + uint32_t sharedExpertRankNum; // shared expert rank number + uint32_t quantMode; // quant mode + uint32_t globalBs; // globalBs = BS * worldSize + uint32_t bs; // bs + uint32_t k; // k + uint32_t h; // h + uint32_t aicNum; // aivNum + uint32_t aivNum; // aivNum + uint64_t totalUbSize; + uint64_t totalWinSize; + uint64_t gmm1HLen; +}; + +struct FusedDeepMoeTilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling; + FusedDeepMoeInfo disGmmDeqSwigluQuantGmmDeqComInfo; +}; + +constexpr uint32_t GM_ALIGN_BYTE = 512; +constexpr uint32_t CUSTOM_PRELOAD_STAGES = 1; +constexpr uint32_t CUSTOM_L1_STAGES = 2; +constexpr uint32_t CUSTOM_L0A_STAGES = 2; +constexpr uint32_t CUSTOM_L0B_STAGES = 2; +constexpr uint32_t CUSTOM_L0C_STAGES = 1; +constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true; +constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true; + +constexpr uint32_t GMM1_L1M = 256; +constexpr uint32_t GMM1_L1N = 128; +constexpr uint32_t GMM1_L1K = 512; +constexpr uint32_t GMM1_L0K = 128; +constexpr uint32_t GMM1_EPIM = 64; +constexpr uint32_t GMM1_SWIZZLE_OFFSET = 3; +constexpr uint32_t GMM1_SWIZZLE_DIRECTION = 0; + +constexpr uint32_t GMM2_L1A_STAGES = 4; +constexpr uint32_t GMM2_L1B_STAGES = 2; +constexpr uint32_t GMM2_L0A_STAGES = 4; +constexpr uint32_t GMM2_L0B_STAGES = 2; +constexpr uint32_t GMM2_L1M = 128; +constexpr uint32_t GMM2_L1N = 256; +constexpr uint32_t GMM2_L1K = 512; +constexpr uint32_t GMM2_L0K = 128; +constexpr uint32_t GMM2_EPIM = 32; +constexpr uint32_t GMM2_SWIZZLE_OFFSET = 3; +constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0; + +constexpr uint32_t WORKSPACE_STAGES = 4; + +constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0); + +#endif // FUSED_DEEP_MOE_TILING_H diff --git a/csrc/utils/op_host/error_log.h b/csrc/utils/op_host/error_log.h new file mode 100644 index 00000000000..d809a922658 --- /dev/null +++ b/csrc/utils/op_host/error_log.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Description: create log implementation file + * Author: Han Jiahui + * Create: 2025-05-21 + * Note: + * History: 2025-05-21 create log implementation file + */ +#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ +#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ + +#include +#include "toolchain/slog.h" + +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) \ + printf("[WARN]" __VA_ARGS__); \ + printf("\n") +#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ + printf("[ERRORx]" __VA_ARGS__); \ + printf("\n") +#define OP_LOGE(opname, ...) \ + printf("[ERROR]" __VA_ARGS__); \ + printf("\n") +#define OP_LOGD(opname, ...) + +namespace optiling { + +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) 
\ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + } while (0) + +#define OP_TILING_CHECK(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) +} // namespace optiling + +#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h b/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h new file mode 100644 index 00000000000..b35e6140183 --- /dev/null +++ b/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h @@ -0,0 +1,812 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Description: add combine kernel implement + * Author: Chen Cheng + * Create: 2025-07-21 + * Note: + * History: 2025-07-21 add combine kernel implement + */ +#ifndef CAM_MOE_DISTRIBUTE_COMBINE_H +#define CAM_MOE_DISTRIBUTE_COMBINE_H +#define OPT_RANK_OFFSET 512 + +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "../../../../../../op_kernel/fused_deep_moe_base.h" +#include "../../../../../../op_kernel/fused_deep_moe_tiling.h" + +namespace MoeDistributeCombineImpl { +constexpr uint8_t BUFFER_NUM = 2; // multi-buf +constexpr uint32_t STATE_OFFSET = 512; +constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M +constexpr uint32_t RANK_SIZE_ON_WIN_512 = 512 * 1024; +constexpr uint32_t RANK_SIZE_ON_WIN_256 = 256 * 1024; +constexpr uint32_t TP_RANK_SIZE_ON_WIN = 0; +constexpr uint32_t UB_ALIGN = 32; +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint8_t EP_DOMAIN = 0; +constexpr uint8_t TP_DOMAIN = 1; +constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; +constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; +constexpr uint16_t SEND_SYNC_EVENT_ID = 9; +constexpr uint16_t RECV_SYNC_EVENT_ID = 10; + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +using namespace AscendC; + +struct CombineCalcInfo { + uint64_t expertPerSizeOnWin_; + uint32_t epRankId_; + uint32_t epWorldSize_; + uint32_t moeExpertPerRankNum_; + uint32_t sharedExpertRankNum_; + uint32_t axisH_; + uint32_t moeSendNum_; + bool isShardExpert_; + GM_ADDR epSendCount_; + __gm__ HcclOpResParam *epWinContext_; + uint64_t winDataSizeOffset_; +}; + +template +class CamMoeDistributeCombine +{ +public: + __aicore__ inline CamMoeDistributeCombine(){}; + __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, + GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, + const FusedDeepMoeTilingData *tilingData); + __aicore__ inline void Process(); + __aicore__ inline void AllToAllSend(); + __aicore__ inline void ReducePermute(); + + __aicore__ inline CombineCalcInfo &GetCalcInfo() + { + return calcInfo_; + } + + __aicore__ inline void TPipeSet(AscendC::TPipe *pipe) + { + tpipe_ = pipe; + } + +private: + __aicore__ inline void InitStatusTargetSum(); + __aicore__ inline void AlltoAllBuffInit(); + __aicore__ inline void ReduceScatterTrans(); + __aicore__ inline void SetWaitTpStatusAndDisPatch(); + __aicore__ inline void CustomAdd(LocalTensor &dst, LocalTensor &src0, + LocalTensor &src1, uint32_t dataCnt); + __aicore__ inline void ExpertAlltoAllDispatchInnerCopyAdd(uint32_t tokenNumLoop, uint32_t srcStartTokenIdx, + uint32_t ep, uint32_t expertIdx); + 
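+    // Combine pipeline (see Process(), AllToAllSend() and ReducePermute()): the expert outputs are copied
+    // into the peer EP ranks' HCCL windows (ExpertAlltoAllDispatchCopyAdd), per-rank flags are raised in
+    // SetStatus(), WaitDispatch() polls the local status window until every peer has written, and
+    // LocalWindowCopy() weights the received rows by the expert scales and accumulates them into the output.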
__aicore__ inline void ExpertAlltoAllDispatchCopyAdd(); + __aicore__ inline void LocalWindowCopy(); + __aicore__ inline void BuffInit(); + __aicore__ inline void SplitCoreCal(); + __aicore__ inline void SetStatus(); + __aicore__ inline void WaitDispatch(); + __aicore__ GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t domain, const uint8_t expertLocalId = 0U) + { + if (domain == EP_DOMAIN) { + return (GM_ADDR)((epRankId_ == rankId) + ? epWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + winDataSizeOffset_ + expertLocalId * expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; + } else { + return (GM_ADDR)((tpRankId_ == rankId) + ? tpWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + winDataSizeOffset_ + rankId * OPT_RANK_OFFSET; + } + } + + __aicore__ GM_ADDR GetWinStateAddrByRankId(const int32_t rankId, const uint8_t domain) + { + if (domain == EP_DOMAIN) { + return (GM_ADDR)((epRankId_ == rankId) + ? epWinContext_->localWindowsExp + : ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsExp) + + dataState_ * WIN_STATE_OFFSET; + } else { + return (GM_ADDR)((tpRankId_ == rankId) + ? tpWinContext_->localWindowsExp + : ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsExp) + + dataState_ * WIN_STATE_OFFSET; + } + } + + __aicore__ inline uint32_t MIN(uint32_t x, uint32_t y) + { + return (x < y) ? x : y; + } + + __aicore__ static void DoCombineRecv(void *ptr) + { + auto *combiner = (CamMoeDistributeCombine *)ptr; + combiner->ReducePermute(); + } + + TPipe *tpipe_{nullptr}; + GlobalTensor expandXGM_; + GlobalTensor expertIdsGM_; + GlobalTensor expandIdxGM_; + GlobalTensor epSendCountGM_; + GlobalTensor tpSendCountGM_; + GlobalTensor expandScalesGM_; + GlobalTensor expandOutGlobal_; + GlobalTensor rankWindow_; + GlobalTensor rankStates_; + GlobalTensor epStatusSpaceGlobalTensor_; + GlobalTensor tpStatusSpaceGlobalTensor_; + GlobalTensor tpRankWindow_; + GlobalTensor rowTmpGlobal_; + GM_ADDR workspaceGM_; + GM_ADDR epWindowGM_; + GM_ADDR epStatusSpaceGm_; + GM_ADDR tpWindowGM_; + GM_ADDR tpStatusSpaceGm_; + GM_ADDR stateGM_; + + LocalTensor winTpSendCountTensor_; + LocalTensor gmTpSendCountTensor_; + LocalTensor outTensor_; + LocalTensor winTpSendCountFloatTensor_; + LocalTensor gmTpSendCountFloatTensor_; + LocalTensor epSendCountLocal_; + + CombineCalcInfo calcInfo_; + uint32_t axisBS_{0}; + uint32_t axisMaxBs_{0}; + uint32_t axisH_{0}; + uint32_t axisK_{0}; + uint32_t aivNum_{0}; + uint32_t epWorldSize_{0}; + uint32_t tpWorldSize_{0}; + uint32_t epRankId_{0}; + uint32_t tpRankId_{0}; + uint32_t coreIdx_{0}; // aiv id + uint32_t sharedExpertRankNum_{0}; + uint32_t moeExpertNum_{0}; + uint32_t moeExpertPerRankNum_{0}; + uint32_t moeSendNum_{0}; // moeExpertPerRankNum_ * epWorldSize_ + uint32_t tpScatterNum_{0}; + uint32_t firstTpTokenEndIdx_{0}; + uint32_t firstTpTokenEndOffset_{0}; + uint32_t endTok_{0}; + __gm__ HcclOpResParam *epWinContext_{nullptr}; + __gm__ HcclOpResParam *tpWinContext_{nullptr}; + uint32_t epDataOffsetOnWin_{0}; + uint32_t tpDataOffsetOnWin_{0}; + uint32_t epStateOffsetOnWin_{0}; + uint32_t tpStateOffsetOnWin_{0}; + uint32_t axisHFloatSize_{0}; + uint32_t axisHExpandXTypeSize_{0}; + uint32_t bsKNum_{0}; + uint32_t startRankId_{0}; + uint32_t endRankId_{0}; + uint32_t sendRankNum_{0}; + uint32_t ubSize_{0}; + uint32_t dataState_{0}; + uint32_t 
stateOffset_{0}; + uint64_t winDataSizeOffset_{0}; + uint64_t expertPerSizeOnWin_{0}; + uint64_t totalWinSize_{0}; + TQueBind moeQueue_; + TQue moeSumQueue_; + TQueBind gmTpSendCountQueue_; + TQue gmTpSendCountInQueue_; + TQue winTpSendCountInQueue_; + TQue xOutQueue_; + TBuf<> readStateBuf_; + TBuf<> expertIdsBuf_; + TBuf<> expandScalesBuf_; + TBuf<> rowTmpFloatBuf_; + TBuf<> sumFloatBuf_; + TBuf<> mulBuf_; + TBuf<> sendCountBuf_; + TBuf<> indexCountsBuf_; + TBuf<> winTpSendCountFloatBuf_; + TBuf<> gmTpSendCountFloatBuf_; + TBuf<> tokenBuf_; + TBuf<> statusBuf_; + TBuf<> gatherMaskOutBuf_; // gather mask output buf + TBuf<> gatherTmpBuf_; + TBuf<> statusSumOutBuf_; + float sumTarget_{0.0}; + int32_t epStateValue_; + bool isShardExpert_{false}; +}; + +template +__aicore__ inline void CamMoeDistributeCombine::Init( + GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, + GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const FusedDeepMoeTilingData *tilingData) +{ + tpipe_ = pipe; + coreIdx_ = GetBlockIdx(); + epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + auto contextGM0 = AscendC::GetHcclContext(); + epWinContext_ = (__gm__ HcclOpResParam *)contextGM0; + GlobalTensor selfDataStatusTensor; + GM_ADDR statusDataSpaceGm = (GM_ADDR)epWinContext_->localWindowsExp; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); + dataState_ = selfDataStatusTensor(coreIdx_ * UB_ALIGN); + if (dataState_ == 0) { + selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 1; + } else { + selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); + pipe_barrier(PIPE_ALL); + + workspaceGM_ = workspaceGM; + expandXGM_.SetGlobalBuffer((__gm__ ExpandXType *)expandX); + expertIdsGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expertIds); + expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx); + epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount); + expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales); + expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut); + axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; + aivNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum; + ubSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalUbSize; + sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + moeExpertPerRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + epWorldSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + axisMaxBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs / epWorldSize_; + moeSendNum_ = epWorldSize_ * moeExpertPerRankNum_; + tpWorldSize_ = 1; + tpRankId_ = 0; + totalWinSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalWinSize; + stateOffset_ = (moeSendNum_ > 512) ? 
(STATE_OFFSET / 2) : STATE_OFFSET; + expertPerSizeOnWin_ = + static_cast(axisMaxBs_) * static_cast(axisH_) * static_cast(sizeof(ExpandXType)); + winDataSizeOffset_ = static_cast(dataState_) * static_cast(moeSendNum_) * expertPerSizeOnWin_; + epWindowGM_ = GetWinAddrByRankId(epRankId_, EP_DOMAIN); + epStatusSpaceGm_ = GetWinStateAddrByRankId(epRankId_, EP_DOMAIN); + epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)epStatusSpaceGm_); + epDataOffsetOnWin_ = epRankId_ * moeExpertPerRankNum_ * static_cast(expertPerSizeOnWin_); + epStateOffsetOnWin_ = epRankId_ * stateOffset_; + isShardExpert_ = (epRankId_ < sharedExpertRankNum_); + axisHFloatSize_ = axisH_ * sizeof(float); + axisHExpandXTypeSize_ = axisH_ * sizeof(ExpandXType); + bsKNum_ = axisBS_ * axisK_; + + if constexpr (IsNeedReduceScatter) { + tpSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)tpSendCount); + tpWindowGM_ = GetWinAddrByRankId(tpRankId_, TP_DOMAIN); + tpStatusSpaceGm_ = GetWinStateAddrByRankId(tpRankId_, TP_DOMAIN); + tpStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)tpStatusSpaceGm_); + tpDataOffsetOnWin_ = tpRankId_ * TP_RANK_SIZE_ON_WIN; + tpStateOffsetOnWin_ = tpRankId_ * stateOffset_; + uint32_t tpScatterRankWinOffset = (tpRankId_ == 0) ? TP_RANK_SIZE_ON_WIN : 0; + GM_ADDR rankGM = tpWindowGM_ + tpScatterRankWinOffset; + tpRankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM); + } + + InitStatusTargetSum(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + coreIdx_ = get_block_idx(); + } + SplitCoreCal(); + + calcInfo_.epRankId_ = epRankId_; + calcInfo_.epWorldSize_ = epWorldSize_; + calcInfo_.expertPerSizeOnWin_ = expertPerSizeOnWin_; + calcInfo_.moeExpertPerRankNum_ = moeExpertPerRankNum_; + calcInfo_.sharedExpertRankNum_ = sharedExpertRankNum_; + calcInfo_.axisH_ = axisH_; + calcInfo_.moeSendNum_ = moeSendNum_; + calcInfo_.isShardExpert_ = isShardExpert_; + calcInfo_.epSendCount_ = epSendCount; + calcInfo_.epWinContext_ = epWinContext_; + calcInfo_.winDataSizeOffset_ = winDataSizeOffset_; +} + +template +__aicore__ inline void CamMoeDistributeCombine::InitStatusTargetSum() +{ + // ep state + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(epStatusSpaceGm_ + SELF_STATE_OFFSET)); + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); + int32_t state = selfStatusTensor(coreIdx_ * UB_ALIGN); + if (state == 0) { + sumTarget_ = static_cast(1.0); + selfStatusTensor(coreIdx_ * UB_ALIGN) = 0x3F800000; // 1.0f + epStateValue_ = 0x3F800000; // 1.0f + } else { + sumTarget_ = static_cast(0.0); + selfStatusTensor(coreIdx_ * UB_ALIGN) = 0; + epStateValue_ = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); +} + +template +__aicore__ inline void CamMoeDistributeCombine::BuffInit() +{ + tpipe_->Reset(); + tpipe_->InitBuffer(readStateBuf_, UB_ALIGN); // 32 + uint32_t sendNumAlign = Ceil(moeSendNum_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; + tpipe_->InitBuffer(sendCountBuf_, sendNumAlign); // epWorldSize_ * moeExpertPerRankNum_ * 4 + if constexpr (IsNeedReduceScatter) { + tpipe_->InitBuffer(winTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + tpipe_->InitBuffer(gmTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + tpipe_->InitBuffer(xOutQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + if constexpr (AscendC::IsSameType::value) { + tpipe_->InitBuffer(winTpSendCountFloatBuf_, axisHFloatSize_); 
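+            // bf16 tokens are accumulated in fp32 (see CustomAdd: cast up, add, cast back),
+            // so two fp32 staging buffers of one token row each are reserved here.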
+ tpipe_->InitBuffer(gmTpSendCountFloatBuf_, axisHFloatSize_); + winTpSendCountFloatTensor_ = winTpSendCountFloatBuf_.Get(); + gmTpSendCountFloatTensor_ = gmTpSendCountFloatBuf_.Get(); + } + } else { + tpipe_->InitBuffer(gmTpSendCountQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + } + epSendCountLocal_ = sendCountBuf_.Get(); +} + +template +__aicore__ inline void CamMoeDistributeCombine::AlltoAllBuffInit() +{ + tpipe_->Reset(); + uint32_t bsMulTopkSizeAligned = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; // 防止UB不对齐 + tpipe_->InitBuffer(readStateBuf_, UB_ALIGN); + tpipe_->InitBuffer(statusBuf_, sendRankNum_ * UB_ALIGN); + tpipe_->InitBuffer(expertIdsBuf_, bsMulTopkSizeAligned); + tpipe_->InitBuffer(expandScalesBuf_, bsMulTopkSizeAligned); + tpipe_->InitBuffer(tokenBuf_, axisH_ * sizeof(ExpandXType)); + tpipe_->InitBuffer(rowTmpFloatBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(mulBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(sumFloatBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(indexCountsBuf_, bsMulTopkSizeAligned); + tpipe_->InitBuffer(moeSumQueue_, BUFFER_NUM, axisHExpandXTypeSize_); + tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float)); + tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t)); // 4 + tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float)); // 4 +} + +template +__aicore__ inline void CamMoeDistributeCombine::SplitCoreCal() +{ + sendRankNum_ = epWorldSize_ / aivNum_; + uint32_t remainderRankNum = epWorldSize_ % aivNum_; + startRankId_ = sendRankNum_ * coreIdx_; + if (coreIdx_ < remainderRankNum) { + sendRankNum_++; + startRankId_ += coreIdx_; + } else { + startRankId_ += remainderRankNum; + } + endRankId_ = startRankId_ + sendRankNum_; +} + +template +__aicore__ inline void CamMoeDistributeCombine::ReduceScatterTrans() +{ + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(tpSendCountGM_[tpRankId_]); + __asm__ __volatile__(""); + uint32_t offset = tpSendCountGM_.GetValue(tpRankId_) * axisH_; + GlobalTensor dataCopyInGM = expandXGM_[offset]; + GM_ADDR rankGM = GetWinAddrByRankId(1 - tpRankId_, TP_DOMAIN) + tpDataOffsetOnWin_; + rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM); + uint32_t copyStartIdx = 0; + if (startRankId_ > 0) { + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + epSendCountGM_[epWorldSize_ + startRankId_ - 1]); + __asm__ __volatile__(""); + copyStartIdx = epSendCountGM_.GetValue(epWorldSize_ + startRankId_ - 1); + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + epSendCountGM_[epWorldSize_ + endRankId_ - 1]); + __asm__ __volatile__(""); + uint32_t copyEndIdx = epSendCountGM_.GetValue(epWorldSize_ + endRankId_ - 1); + LocalTensor tmpUb; + for (uint32_t tokenNumIdx = copyStartIdx; tokenNumIdx < copyEndIdx; tokenNumIdx++) { + tmpUb = moeQueue_.AllocTensor(); + DataCopy(tmpUb, dataCopyInGM[tokenNumIdx * axisH_], axisH_); + moeQueue_.EnQue(tmpUb); + tmpUb = moeQueue_.DeQue(); + DataCopy(rankWindow_[tokenNumIdx * axisH_], tmpUb, axisH_); + moeQueue_.FreeTensor(tmpUb); + } +} + +// 46 -> gm -> ub syncall win->gm add -> alltoall +// 2 -> win wait syncall gm -> ub win ->gm add -> alltoall +template +__aicore__ inline void CamMoeDistributeCombine::SetWaitTpStatusAndDisPatch() +{ + pipe_barrier(PIPE_ALL); + if (startRankId_ >= epWorldSize_) { + return; + } + if constexpr (IsNeedReduceScatter) { + uint32_t tpToRankId = 1 - tpRankId_; + pipe_barrier(PIPE_ALL); + LocalTensor statusFlagUb = readStateBuf_.Get(); + statusFlagUb(0) = sumTarget_; + 
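+        // Publish the local flag to the peer TP rank's status window, then spin on the local
+        // status window until the peer's flag reaches sumTarget_ before the copy-add starts.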
SyncFunc(); + GlobalTensor tpWindowInstatusFp32Tensor_; + stateGM_ = GetWinStateAddrByRankId(tpToRankId, TP_DOMAIN) + coreIdx_ * stateOffset_; + tpWindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)stateGM_); + DataCopy(tpWindowInstatusFp32Tensor_, statusFlagUb, 8UL); + SyncFunc(); + LocalTensor statusFp32Tensor_ = readStateBuf_.Get(); + float sumOfFlag = static_cast(-1.0); + uint32_t statusRankOffset = coreIdx_ * stateOffset_ / sizeof(float); // tp = 2 + while (sumOfFlag != sumTarget_) { + DataCopy(statusFp32Tensor_, tpStatusSpaceGlobalTensor_[statusRankOffset], 8); + SyncFunc(); + sumOfFlag = statusFp32Tensor_.GetValue(0); + SyncFunc(); + } + } + // Copy win gm->ub add ->alltoall send + ExpertAlltoAllDispatchCopyAdd(); + SyncFunc(); +} + +template +__aicore__ inline void CamMoeDistributeCombine::ExpertAlltoAllDispatchCopyAdd() +{ + if (startRankId_ >= epWorldSize_) { + return; + } + uint32_t curRankExpertNum = 0; + DataCopyExtParams epSendCntParams; + if (isShardExpert_) { + curRankExpertNum = 1; + epSendCntParams = {1U, static_cast(epWorldSize_ * sizeof(uint32_t)), 0U, 0U, 0U}; + } else { + curRankExpertNum = moeExpertPerRankNum_; + epSendCntParams = {1U, static_cast(moeSendNum_ * sizeof(uint32_t)), 0U, 0U, 0U}; + } + DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + DataCopyPad(epSendCountLocal_, epSendCountGM_, epSendCntParams, copyPadParams); + SyncFunc(); + uint32_t preCount = 0; + uint32_t startTokenIdx = 0; + uint32_t curTokenNum = 0; + + for (uint32_t expertIdx = 0U; expertIdx < curRankExpertNum; expertIdx++) { + uint32_t sendEpCount = endRankId_ - startRankId_; + for (uint32_t i = 0; i < sendEpCount; ++i) { + uint32_t ep = startRankId_ + (i + epRankId_) % sendEpCount; + if ((ep > 0) || (expertIdx > 0U)) { + preCount = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep - 1); + } else { + preCount = 0; + } + curTokenNum = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep) - preCount; + if (curTokenNum == 0) { + continue; + } + startTokenIdx = preCount * axisH_; + ExpertAlltoAllDispatchInnerCopyAdd(curTokenNum, startTokenIdx, ep, expertIdx); + } + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::ExpertAlltoAllDispatchInnerCopyAdd( + uint32_t tokenNumLoop, uint32_t srcStartTokenIdx, uint32_t ep, uint32_t expertIdx) +{ + GM_ADDR rankGM = GetWinAddrByRankId(ep, EP_DOMAIN, expertIdx) + epDataOffsetOnWin_; + if ((isShardExpert_) && (ep < sharedExpertRankNum_)) { + rankGM = GetWinAddrByRankId(epRankId_, EP_DOMAIN, expertIdx) + ep * moeExpertPerRankNum_ * expertPerSizeOnWin_; + } + rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM); + uint32_t dataCnt = axisH_; + for (uint32_t loopIdx = 0; loopIdx < tokenNumLoop; loopIdx++) { + if constexpr (IsNeedReduceScatter) { + gmTpSendCountTensor_ = gmTpSendCountInQueue_.AllocTensor(); + DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt); + gmTpSendCountInQueue_.EnQue(gmTpSendCountTensor_); + + winTpSendCountTensor_ = winTpSendCountInQueue_.AllocTensor(); + DataCopy(winTpSendCountTensor_, tpRankWindow_[srcStartTokenIdx], dataCnt); + winTpSendCountInQueue_.EnQue(winTpSendCountTensor_); + + gmTpSendCountTensor_ = gmTpSendCountInQueue_.DeQue(); + winTpSendCountTensor_ = winTpSendCountInQueue_.DeQue(); + outTensor_ = xOutQueue_.AllocTensor(); + + CustomAdd(outTensor_, winTpSendCountTensor_, gmTpSendCountTensor_, dataCnt); + gmTpSendCountInQueue_.FreeTensor(gmTpSendCountTensor_); + winTpSendCountInQueue_.FreeTensor(winTpSendCountTensor_); + xOutQueue_.EnQue(outTensor_); + + outTensor_ 
= xOutQueue_.DeQue(); + DataCopy(rankWindow_[loopIdx * dataCnt], outTensor_, dataCnt); + xOutQueue_.FreeTensor(outTensor_); + } else { + gmTpSendCountTensor_ = gmTpSendCountQueue_.AllocTensor(); + DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt); + ExpandXType val = expandXGM_[srcStartTokenIdx].GetValue(0); + gmTpSendCountQueue_.EnQue(gmTpSendCountTensor_); + gmTpSendCountTensor_ = gmTpSendCountQueue_.DeQue(); + DataCopy(rankWindow_[loopIdx * dataCnt], gmTpSendCountTensor_, dataCnt); + gmTpSendCountQueue_.FreeTensor(gmTpSendCountTensor_); + } + srcStartTokenIdx += dataCnt; + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::CustomAdd(LocalTensor &dst, + LocalTensor &src0, + LocalTensor &src1, + uint32_t dataCnt) +{ + if constexpr (AscendC::IsSameType::value) { + Cast(winTpSendCountFloatTensor_, src0, RoundMode::CAST_NONE, dataCnt); + Cast(gmTpSendCountFloatTensor_, src1, RoundMode::CAST_NONE, dataCnt); + pipe_barrier(PIPE_V); + Add(winTpSendCountFloatTensor_, winTpSendCountFloatTensor_, gmTpSendCountFloatTensor_, dataCnt); + pipe_barrier(PIPE_V); + Cast(dst, winTpSendCountFloatTensor_, RoundMode::CAST_ROUND, dataCnt); + } else { + Add(dst, src0, src1, dataCnt); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::SetStatus() +{ + pipe_barrier(PIPE_ALL); + if (startRankId_ >= epWorldSize_) { + return; + } + + LocalTensor statusFlagUb = readStateBuf_.Get(); + statusFlagUb.SetValue(0, epStateValue_); + SyncFunc(); + + for (uint32_t epIdx = startRankId_; epIdx < endRankId_; epIdx++) { + stateGM_ = GetWinStateAddrByRankId(epIdx, EP_DOMAIN) + epStateOffsetOnWin_; + rankStates_.SetGlobalBuffer((__gm__ int32_t *)stateGM_); + DataCopy(rankStates_, statusFlagUb, 8); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::WaitDispatch() +{ + if (startRankId_ < epWorldSize_) { + LocalTensor statusTensor = statusBuf_.Get(); + LocalTensor gatherMaskOutTensor = gatherMaskOutBuf_.Get(); + LocalTensor gatherTmpTensor = gatherTmpBuf_.Get(); + LocalTensor statusSumOutTensor = statusSumOutBuf_.Get(); + PipeBarrier(); + + gatherTmpTensor.SetValue(0, 1); + uint32_t mask = 1; // gatherMask + sum + uint64_t rsvdCnt = 0; + DataCopyParams intriParams{static_cast(sendRankNum_), 1, + static_cast((moeSendNum_ > 512) ? 
7 : 15), 0}; // srcStride is 15 blocks + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5; + float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5; + SumParams sumParams{1, sendRankNum_, sendRankNum_}; + SyncFunc(); + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + DataCopy(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)], + intriParams); + SyncFunc(); + GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask, + {1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt); + PipeBarrier(); + Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams); + SyncFunc(); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + } else { + SyncAll(); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::LocalWindowCopy() +{ + uint32_t beginIndex = 0; + uint32_t endIndex = 0; + uint32_t processLen = 0; + uint32_t tokenOffset = 0; + if (axisBS_ < aivNum_) { + uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_ + if (coreIdx_ >= (axisBS_ * aivNumPerToken)) { + return; + } + uint32_t tokenIndex = coreIdx_ / aivNumPerToken; + processLen = ((axisH_ / UB_ALIGN) / aivNumPerToken) * UB_ALIGN; + tokenOffset = processLen * (coreIdx_ % aivNumPerToken); + if ((coreIdx_ % aivNumPerToken) == (aivNumPerToken - 1)) { + processLen = axisH_ - ((aivNumPerToken - 1) * processLen); + } + beginIndex = tokenIndex; + endIndex = beginIndex + 1U; + } else { + uint32_t tokenPerAivNum = axisBS_ / aivNum_; + uint32_t remainderToken = axisBS_ % aivNum_; + beginIndex = tokenPerAivNum * coreIdx_; + if (coreIdx_ < remainderToken) { + tokenPerAivNum++; + beginIndex = tokenPerAivNum * coreIdx_; + } else { + beginIndex += remainderToken; + } + endIndex = beginIndex + tokenPerAivNum; + processLen = axisH_; + } + LocalTensor expertIdsLocal = expertIdsBuf_.Get(); + LocalTensor expandScalesLocal = expandScalesBuf_.Get(); + + LocalTensor rowTmpFloatLocal = rowTmpFloatBuf_.Get(); + LocalTensor mulBufLocal = mulBuf_.Get(); + LocalTensor sumFloatBufLocal = sumFloatBuf_.Get(); + + LocalTensor indexCountsLocal = indexCountsBuf_.Get(); + const DataCopyExtParams bskParams = {1U, static_cast(bsKNum_ * sizeof(uint32_t)), 0U, 0U, 0U}; + const DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + const DataCopyPadExtParams copyPadFloatParams{false, 0U, 0U, 0U}; + + DataCopyPad(indexCountsLocal, expandIdxGM_, bskParams, copyPadParams); + DataCopyPad(expertIdsLocal, expertIdsGM_, bskParams, copyPadParams); + DataCopyPad(expandScalesLocal, expandScalesGM_, bskParams, copyPadFloatParams); + SyncFunc(); + + for (uint32_t tokenIndex = beginIndex; tokenIndex < endIndex; tokenIndex++) { + uint32_t index = tokenIndex * axisK_; + SyncFunc(); + Duplicate(sumFloatBufLocal, (float)0, axisH_); + for (uint32_t i = 0; i < axisK_; i++) { + int32_t moeExpert = expertIdsLocal.GetValue(index); + if (moeExpert < 0) { + index++; + continue; + } + float scaleVal = expandScalesLocal.GetValue(index); + GM_ADDR wAddr = (__gm__ uint8_t *)(epWindowGM_) + + expertPerSizeOnWin_ * moeExpertPerRankNum_ * sharedExpertRankNum_ + + expertPerSizeOnWin_ * moeExpert + indexCountsLocal.GetValue(index) * axisHExpandXTypeSize_ + + tokenOffset * sizeof(ExpandXType); + rowTmpGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)wAddr); + ExpandXType val = rowTmpGlobal_.GetValue(0); + LocalTensor tmpUb = 
moeSumQueue_.AllocTensor(); + DataCopy(tmpUb, rowTmpGlobal_, processLen); + moeSumQueue_.EnQue(tmpUb); + tmpUb = moeSumQueue_.DeQue(); + Cast(rowTmpFloatLocal, tmpUb, AscendC::RoundMode::CAST_NONE, processLen); + AscendC::PipeBarrier(); + AscendC::Muls(mulBufLocal, rowTmpFloatLocal, scaleVal, processLen); + AscendC::PipeBarrier(); + AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, mulBufLocal, processLen); + index++; + moeSumQueue_.FreeTensor(tmpUb); + } + LocalTensor rowTmpLocal = tokenBuf_.Get(); + if (sharedExpertRankNum_ > 0U) { + uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; + uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; + uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ - + epRankId_ * axisBS_ / sharedExpertRankNum_; + __gm__ ExpandXType *shareAddr = + (__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) + + (tokenIndex - preCnt) * axisH_ + tokenOffset; + GlobalTensor shareTokGlobal; + shareTokGlobal.SetGlobalBuffer((__gm__ ExpandXType *)(shareAddr)); + SyncFunc(); + DataCopy(rowTmpLocal, shareTokGlobal, processLen); + SyncFunc(); + Cast(rowTmpFloatLocal, rowTmpLocal, AscendC::RoundMode::CAST_NONE, processLen); + AscendC::PipeBarrier(); + AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, rowTmpFloatLocal, processLen); + } + // 结果搬出 + AscendC::PipeBarrier(); + LocalTensor sumBufLocal = tokenBuf_.Get(); + Cast(sumBufLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, processLen); + SyncFunc(); + DataCopy(expandOutGlobal_[tokenIndex * axisH_ + tokenOffset], sumBufLocal, processLen); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::Process() +{ + SyncAll(); + if constexpr (IsNeedReduceScatter) { + tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 7168 * 2 * 2 = 28672 + ReduceScatterTrans(); + } + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) { + BuffInit(); + SetWaitTpStatusAndDisPatch(); + } + AlltoAllBuffInit(); + SetStatus(); + WaitDispatch(); + LocalWindowCopy(); +} + +template +__aicore__ inline void CamMoeDistributeCombine::AllToAllSend() +{ + if constexpr (IsNeedReduceScatter) { + tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 7168 * 2 * 2 = 28672 + ReduceScatterTrans(); + } + BuffInit(); + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) { + SetWaitTpStatusAndDisPatch(); + AlltoAllBuffInit(); + } + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + } else { + SyncAll(); + } + SetStatus(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + } else { + SyncAll(); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::ReducePermute() +{ + AlltoAllBuffInit(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + } else { + SyncAll(); + } + + WaitDispatch(); + LocalWindowCopy(); + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + } +} +} // namespace MoeDistributeCombineImpl + +#endif // CAM_MOE_DISTRIBUTE_COMBINE_IMPL_H diff --git a/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h b/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h new file mode 100644 index 
00000000000..cf608bb4083
--- /dev/null
+++ b/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h
@@ -0,0 +1,1090 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ * Description: CamMoeDistributeDispatch operator kernel function header file, for a3
+ * Author: WANG Qiankun
+ * Create: 2025-05-29
+ * Note:
+ * History: 2025-05-29 create CamMoeDistributeDispatch operator kernel function header file, for a3
+ */
+
+#ifndef CAM_MOE_DISTRIBUTE_DISPATCH_H
+#define CAM_MOE_DISTRIBUTE_DISPATCH_H
+#define OPT_RANK_OFFSET 512
+
+#include "kernel_operator.h"
+#include "kernel_tiling/kernel_tiling.h"
+#include "../../../../../../op_kernel/fused_deep_moe_base.h"
+#include "../../../../../../op_kernel/fused_deep_moe_tiling.h"
+
+namespace MoeDistributeDispatchImpl {
+constexpr uint8_t BUFFER_NUM = 2;                 // double buffering
+constexpr uint32_t STATE_OFFSET = 512;            // offset of the status space
+constexpr uint32_t STATE_SIZE = 1024 * 1024;      // 1M
+constexpr uint32_t UB_ALIGN = 32;                 // UB is aligned to 32 bytes
+constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
+constexpr uint8_t COMM_NUM = 2;                   // number of communication domains
+constexpr uint8_t COMM_EP_IDX = 0;
+constexpr uint8_t COMM_TP_IDX = 1;
+constexpr uint32_t GATHER_NUM_PER_TIME = 6;
+// Hard-coded offset for now; with TP fixed at 2, reads/writes can start directly at the initial data offset.
+constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024;
+constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024;
+constexpr uint32_t TP_STATE_SIZE = 100 * 1024;
+constexpr int CAM_MAX_RANK_SIZE = 384;            // maximum number of NPUs supported by the Cam communication library
+constexpr int64_t IPC_DATA_OFFSET = 2 * 1024 * 1024;  // the first 2MB holds flags, the following 100MB holds data
+
+// Variables for the loop optimization
+using countType = uint8_t;  // data type used by the loop optimization
+constexpr uint32_t LOOP_OPT_MAX_BS = 64;
+constexpr uint32_t LOOP_OPT_MAX_MOE_RANK = 256;
+constexpr uint32_t TOPK_ELEM_COUNT_PER_BLOCK = UB_ALIGN / sizeof(int32_t);
+constexpr uint32_t TABLE_ELEM_COUNT_PER_BLOCK = UB_ALIGN / sizeof(countType);
+constexpr uint32_t INT32_NUM_PER_BLOCK = UB_ALIGN / sizeof(int32_t);
+
+template <AscendC::HardEvent event>
+__aicore__ inline void SyncFunc()
+{
+    int32_t eventID = static_cast<int32_t>(GetTPipePtr()->FetchEventID(event));
+    AscendC::SetFlag<event>(eventID);
+    AscendC::WaitFlag<event>(eventID);
+}
+
+#define TemplateDispatchTypeClass \
+    typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
+    bool IsNeedAllgater
+#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater
+
+using namespace AscendC;
+template <TemplateDispatchTypeClass>
+class CamMoeDistributeDispatch
+{
+public:
+    __aicore__ inline CamMoeDistributeDispatch(){};
+    __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut,
+                                GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut,
+                                GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut,
+                                GM_ADDR workspaceGM, TPipe *pipe, const FusedDeepMoeTilingData *tilingData);
+    __aicore__ inline void Process();
+
+private:
+    __aicore__ inline void SendToSharedExpert();
+    __aicore__ inline void SendToMoeExpert();
+    __aicore__ inline void AlltoAllDispatch();
+    __aicore__ inline void LocalWindowCopy();
+    __aicore__ inline void QuantProcess(uint32_t expertIndex);
+    __aicore__ inline void LocalSharedExpertCopyWindow(uint32_t rankIndex, uint32_t tokenOffset,
+                                                       uint32_t currendTokenIndex, uint32_t &dynamicScalesLocalIdx);
+    __aicore__ inline void SetStatus();
+    __aicore__ inline void WaitDispatch();
+    __aicore__ inline void GetCumSum(LocalTensor &inLocal, LocalTensor &outLocal, int32_t totalCount,
+                                     GM_ADDR gmOutputRecvCount);
+
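+    // Dispatch pipeline: AlltoAllDispatch() computes per-token destination offsets and copies (optionally
+    // quantized) tokens into the destination ranks' windows, SetStatus() publishes the per-expert token
+    // counts, WaitDispatch() polls the local status flags, and LocalWindowCopy() moves the received tokens
+    // from the local window into expandXOut.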
__aicore__ inline void CreateZeroTensor(LocalTensor &outTensor); + __aicore__ inline void AllGatherSetStatusAndWait(); + __aicore__ inline void ResetStatus(); + __aicore__ inline void QuantInit(GM_ADDR scales); + __aicore__ inline void AllgatherProcessOut(); + __aicore__ inline void UpdataMultiMoeTokenNumsOut(); + __aicore__ inline void UpdataTokenNumsOut(); + __aicore__ inline GM_ADDR GetWindAddrByRankId(uint8_t ctxIdx, const int32_t rankId) + { + uint32_t curRankId = ctxIdx == COMM_EP_IDX ? epRankId_ : tpRankId_; + if (curRankId == rankId) { + return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn) + winDataSizeOffset_ + rankId * OPT_RANK_OFFSET; + } + return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn) + + winDataSizeOffset_ + rankId * OPT_RANK_OFFSET; + } + + __aicore__ inline GM_ADDR GetWindStateAddrByRankId(uint8_t ctxIdx, const int32_t rankId) + { + uint32_t curRankId = ctxIdx == COMM_EP_IDX ? epRankId_ : tpRankId_; + if (curRankId == rankId) { + return (GM_ADDR)(winContext_[ctxIdx]->localWindowsExp) + dataState_ * WIN_STATE_OFFSET; + } + return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr)) + ->windowsExp) + + dataState_ * WIN_STATE_OFFSET; + } + + __aicore__ inline uint32_t MIN(uint32_t x, uint32_t y) + { + return (x < y) ? x : y; + } + TPipe *tpipe_{nullptr}; + GlobalTensor xGMTensor_; + GlobalTensor expertIdsGMTensor_; + GlobalTensor scalesGMTensor_; + GlobalTensor expandXOutGMTensor_; + GlobalTensor dynamicScalesOutGMTensor_; + GlobalTensor expertTokenNumsOutGMTensor_; + GlobalTensor windowInQuantTensor_; + GlobalTensor windowInstatusTensor_; + GlobalTensor windowInstatusFp32Tensor_; + GlobalTensor winTpGatherOutGMTensor_; + GlobalTensor fpWinTpGatherOutGMTensor_; + GlobalTensor winTpEpCntGMTensor_; + LocalTensor xTmpTensor_; + LocalTensor tpTmpTensor_; + LocalTensor xInTensor_; + LocalTensor xOutTensor_; + LocalTensor xOutFp32Tensor_; + LocalTensor expertCountTensor_; + LocalTensor expertIdsTensor_; + LocalTensor receivestatusTensor_; + LocalTensor rowMaxTensor_; + LocalTensor statusTensor_; + LocalTensor statusFp32Tensor_; + LocalTensor smoothScalesTensor_; + LocalTensor dynamicScalesTensor_; + TBuf<> dynamicScalesBuf_; + TBuf<> expertCountBuf_; + TBuf<> expertIdsBuf_; + TBuf<> statusBuf_; + TBuf<> gatherMaskOutBuf_; // gather mask输出buf + TBuf<> getTotalBuf_; // 计算totalCnt + TBuf<> scalarBuf_; // 辅助gather tensor定义 + TBuf<> rowMaxBuf_; + TBuf<> receiveDataCastFloatBuf_; + TBuf<> smoothScalesBuf_; + TQueBind xQueue_; // 非量化使用,量化场景接收也可使用 + TQue xInQueue_; // 量化使用,量化前的输入 + TQue xOutQueue_; // 量化使用,量化后的输出 + GM_ADDR expandXOutGM_; + GM_ADDR expandIdxOutGM_; + GM_ADDR expertTokenNumsOutGM_; // 这个输出没有使用 + GM_ADDR sendCountsOutGM_; + GM_ADDR outputRecvCountGM_; + GM_ADDR sendTpCountOutGM_; + GM_ADDR statusSpaceGm_; + GM_ADDR windowGM_; + GM_ADDR tpWindowGM_; + GM_ADDR tpStatusWindowGM_; + GM_ADDR tpLocalWindowGM_; + GM_ADDR tpLocalStatusWindowGM_; + GlobalTensor peerMemsAddrGm_; + // tiling侧已确保数据上限,相乘不会越界,因此统一采用uint32_t进行处理 + uint32_t axisBS_{0}; + uint32_t axisMaxBS_{0}; + uint32_t axisH_{0}; + uint32_t axisK_{0}; + uint32_t aivNum_{0}; + uint32_t sharedUsedAivNum_{0}; + uint32_t moeUsedAivNum_{0}; + uint32_t epWorldSize_{0}; + uint32_t tpWorldSize_{0}; + uint32_t epRankId_{0}; + uint32_t tpGatherRankId_{0}; // gather 对端ID + uint32_t tpRankId_{0}; // 本卡 ID + uint32_t aivId_{0}; // aiv id + uint32_t sharedExpertRankNum_{0}; // 共享专家卡数 + uint32_t moeExpertRankNum_{0}; // 
moe专家卡数,等于worldSize_ - 共享专家卡数 + uint32_t moeExpertNumPerRank_{0}; + uint32_t moeExpertNum_{0}; + uint32_t totalExpertNum_{0}; + uint32_t bufferSizePerRank_{0}; + uint32_t recvWinBlockNum_{0}; + uint32_t hSize_{0}; + uint32_t hOutSize_{0}; + uint32_t hCommuSize_{0}; + uint32_t scaleParamPad_{0}; + uint32_t axisHCommu_{0}; + uint32_t startExpertId_; + uint32_t endExpertId_; + uint32_t sendExpertNum_; + uint32_t localCopyCoreNum_; + uint32_t totalCnt_; + uint32_t lastCore_{0}; + uint32_t dataState_{0}; + uint32_t stateOffset_{0}; + uint64_t winDataSizeOffset_{0}; + uint64_t expertPerSizeOnWin_{0}; + uint64_t windyquantOffset_; + bool isShareExpertRank_ = false; + bool isQuant_ = false; + float sumTarget_; + uint64_t totalWinSize_{0}; + uint32_t gatherCount_{0}; + uint32_t expertTokenNumsType_{1}; + uint32_t preCnt_{0}; + __gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr}; + // 循环优化使用的变量 + TBuf<> sendTableIdsBuf_; + LocalTensor tableLocalTensor_; + LocalTensor sendCountLocalTensor_; + uint32_t moeExpertRankNumAligned_; + uint32_t moeExpertRankNumInt16Aligned_; + uint32_t tableElemCount_; + bool enableAivOpt_{false}; +}; + +template +__aicore__ inline void CamMoeDistributeDispatch::Init( + GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, + GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut, + GM_ADDR workspaceGM, TPipe *pipe, const FusedDeepMoeTilingData *tilingData) +{ + tpipe_ = pipe; + aivId_ = GetBlockIdx(); + epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + GlobalTensor selfDataStatusTensor; + GM_ADDR statusDataSpaceGm; + + winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + winContext_[COMM_TP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<1>(); // 没有相关公共宏 + + statusDataSpaceGm = (GM_ADDR)(winContext_[COMM_EP_IDX]->localWindowsExp); + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[aivId_ * UB_ALIGN]); + __asm__ __volatile__(""); + dataState_ = selfDataStatusTensor(aivId_ * UB_ALIGN); + if (dataState_ == 0) { + selfDataStatusTensor(aivId_ * UB_ALIGN) = 1; + } else { + selfDataStatusTensor(aivId_ * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[aivId_ * UB_ALIGN]); + __asm__ __volatile__(""); + pipe_barrier(PIPE_ALL); + axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + epWorldSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + // axisMaxBS_ = axisBS_; + axisMaxBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs / epWorldSize_; + moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + expertTokenNumsType_ = 0; + totalWinSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalWinSize; + moeExpertRankNum_ = epWorldSize_ - sharedExpertRankNum_; + moeExpertNumPerRank_ = moeExpertNum_ / moeExpertRankNum_; + expertPerSizeOnWin_ = axisMaxBS_ * axisH_ * sizeof(XType); + winDataSizeOffset_ = dataState_ * epWorldSize_ * expertPerSizeOnWin_ * moeExpertNumPerRank_; + tpRankId_ = 0; + windowGM_ = GetWindAddrByRankId(COMM_EP_IDX, epRankId_); + statusSpaceGm_ = GetWindStateAddrByRankId(COMM_EP_IDX, epRankId_); + tpGatherRankId_ = 
tpRankId_ == 0 ? 1 : 0; + axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; + aivNum_ = 48; + tpWorldSize_ = 1; + xGMTensor_.SetGlobalBuffer((__gm__ XType *)x); + expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds); + expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut); + dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut); + expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut); + windowInQuantTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)windowGM_); + windowInstatusTensor_.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_)); + windowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)(statusSpaceGm_)); + if constexpr (IsNeedAllgater) { + tpLocalWindowGM_ = GetWindAddrByRankId(COMM_TP_IDX, tpRankId_); + tpLocalStatusWindowGM_ = GetWindStateAddrByRankId(COMM_TP_IDX, tpRankId_); + tpWindowGM_ = GetWindAddrByRankId(COMM_TP_IDX, tpGatherRankId_); + tpStatusWindowGM_ = GetWindStateAddrByRankId(COMM_TP_IDX, tpGatherRankId_); + winTpGatherOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)tpWindowGM_); + fpWinTpGatherOutGMTensor_.SetGlobalBuffer((__gm__ float *)tpWindowGM_); + winTpEpCntGMTensor_.SetGlobalBuffer((__gm__ int32_t *)(tpStatusWindowGM_ + TP_STATE_SIZE)); + } + expandXOutGM_ = expandXOut; + expandIdxOutGM_ = expandIdxOut; // 无GlobalTensor + sendCountsOutGM_ = sendCountsOut; // 无GlobalTensor + outputRecvCountGM_ = outputRecvCount; + sendTpCountOutGM_ = tpSendCountsOut; + isQuant_ = StaticQuant | DynamicQuant; + hSize_ = axisH_ * sizeof(XType); + hOutSize_ = axisH_ * sizeof(ExpandXOutType); // 如有量化,需要量化后通信 + scaleParamPad_ = (isQuant_ ? 128 : 0); // 预留128B给量化参数,实际只使用了4B(fp32) + hCommuSize_ = hOutSize_ + scaleParamPad_; + axisHCommu_ = hCommuSize_ / sizeof(ExpandXOutType); + if (sharedExpertRankNum_ != 0) { // 后面的卡才需要发给共享专家发数据 + sharedUsedAivNum_ = aivNum_ / (axisK_ + 1); // 均等分,取整 + if (sharedUsedAivNum_ == 0) { + sharedUsedAivNum_ = 1; + } + } + moeUsedAivNum_ = aivNum_ - sharedUsedAivNum_; + bufferSizePerRank_ = 32 * hSize_; + recvWinBlockNum_ = epWorldSize_ * moeExpertNumPerRank_; + isShareExpertRank_ = (epRankId_ < sharedExpertRankNum_) ? true : false; + windyquantOffset_ = epWorldSize_ * axisMaxBS_ * hOutSize_; + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET)); + DataCacheCleanAndInvalid( + selfStatusTensor[aivId_ * UB_ALIGN]); + int32_t state = selfStatusTensor(aivId_ * UB_ALIGN); + stateOffset_ = (recvWinBlockNum_ > 512) ? 
(STATE_OFFSET / 2) : STATE_OFFSET;
+    tpipe_->InitBuffer(statusBuf_, recvWinBlockNum_ * UB_ALIGN);  // expertNum * 32B
+    statusTensor_ = statusBuf_.Get();  // holds the per-expert send counts and flags; also used to compute the offsets inside the windows
+    Duplicate(statusTensor_, 0, recvWinBlockNum_ * 8);  // 8 = UB_ALIGN / sizeof(int32_t)
+    if (state == 0) {
+        sumTarget_ = (float)1.0;
+        selfStatusTensor(aivId_ * UB_ALIGN) = 0x3F800000;
+        uint64_t mask[2] = {0x101010101010101, 0};  // operate on 256 bytes (64 int32_t) at a time; set the first of every 8 values to 0x3F800000
+        Duplicate(statusTensor_, 0x3F800000, mask, recvWinBlockNum_ / 8, 1, 8);  // 0x3F800000 is 1.0f
+    } else {
+        sumTarget_ = 0.0;
+        selfStatusTensor(aivId_ * UB_ALIGN) = 0;
+    }
+    DataCacheCleanAndInvalid(selfStatusTensor[aivId_ * UB_ALIGN]);
+    tpipe_->InitBuffer(xQueue_, BUFFER_NUM, hCommuSize_);  // 14k *2
+    if (isQuant_) {
+        QuantInit(scales);
+    }
+    uint32_t expertIdsSize = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;  // keep 32-byte alignment
+    tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize);  // BS * K * 4
+    expertIdsTensor_ = expertIdsBuf_.Get();
+    tpipe_->InitBuffer(expertCountBuf_, expertIdsSize);  // BS * K * 4
+    expertCountTensor_ = expertCountBuf_.Get();
+
+    tpipe_->InitBuffer(gatherMaskOutBuf_, recvWinBlockNum_ * sizeof(float));  // worldSize * 4B
+    tpipe_->InitBuffer(getTotalBuf_,
+                       epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t));  // worldSize * expertsPerRank * 4B
+    tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2);  // 64B
+
+    moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK;
+    if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK &&
+        axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) {
+        // UB capacity limits BS to at most 64 and the routed expert count to at most 256; alignment requires axisK_ to be a multiple of 8
+        enableAivOpt_ = true;
+        moeExpertRankNumInt16Aligned_ = moeExpertRankNumAligned_ / 2;  // each int16_t packs two uint8_t
+        tableElemCount_ = (axisBS_ + 1) * moeExpertRankNumAligned_;  // one extra row (the first row is all zeros)
+
+        tpipe_->InitBuffer(sendTableIdsBuf_, tableElemCount_ * sizeof(countType));
+        tableLocalTensor_ = sendTableIdsBuf_.Get();
+        sendCountLocalTensor_ = tableLocalTensor_[axisBS_ * moeExpertRankNumAligned_];  // after the computation, the last row holds the counts
+    }
+}
+
+template <TemplateDispatchTypeClass>
+__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::QuantInit(GM_ADDR scales)
+{
+    tpipe_->InitBuffer(xInQueue_, BUFFER_NUM, hSize_);  // 14K *2
+    tpipe_->InitBuffer(xOutQueue_, BUFFER_NUM, hCommuSize_);  // 7K *2
+    scalesGMTensor_.SetGlobalBuffer((__gm__ float *)scales);
+    uint32_t hFp32Size = axisH_ * sizeof(float);
+    if constexpr (DynamicQuant) {
+        tpipe_->InitBuffer(rowMaxBuf_, UB_ALIGN);  // 32B
+    }
+    tpipe_->InitBuffer(receiveDataCastFloatBuf_, 1 * hFp32Size);  // 28KB
+    tpipe_->InitBuffer(smoothScalesBuf_, axisH_ * sizeof(float));  // 28KB
+    smoothScalesTensor_ = smoothScalesBuf_.Get();
+    tpipe_->InitBuffer(dynamicScalesBuf_, axisBS_ * sizeof(float));  // 32 * 4
+    dynamicScalesTensor_ = dynamicScalesBuf_.Get();
+}
+
+template <TemplateDispatchTypeClass>
+__aicore__ inline void CamMoeDistributeDispatch<TemplateDispatchTypeFunc>::SendToSharedExpert()
+{
+    uint32_t sendTokenNum = axisBS_ / sharedUsedAivNum_;       // number of tokens each AIV has to send
+    uint32_t remainderTokenNum = axisBS_ % sharedUsedAivNum_;  // remainder
+    uint32_t newAivId = aivId_ - moeUsedAivNum_;  // the trailing cores handle the shared-expert sends, so the core index is re-based
+    uint32_t startTokenId = sendTokenNum * newAivId;  // first token this AIV sends
+    if (newAivId < remainderTokenNum) {  // the leading AIVs each handle one extra token
+        sendTokenNum += 1;
+        startTokenId += newAivId;
+    } else {
+        startTokenId += remainderTokenNum;
+    }
+    if (startTokenId >= axisBS_) {
+        return;
+    }
+    uint32_t endTokenId = startTokenId + sendTokenNum;
+    for (uint32_t tokenShuffleIndex = 0;
tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) { + uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum); + uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; + uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; // dst + uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ - + epRankId_ * axisBS_ / sharedExpertRankNum_; // 发给该共享专家已经有多少token数据 + GlobalTensor dstWinGMTensor; + dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) + + expertPerSizeOnWin_ * epRankId_)); + if constexpr (DynamicQuant || StaticQuant) { + xInTensor_ = xInQueue_.AllocTensor(); + DataCopy(xInTensor_, xGMTensor_[tokenIndex * axisH_], axisH_); // 约束对齐 + xInQueue_.EnQue(xInTensor_); + xInTensor_ = xInQueue_.DeQue(); + xOutTensor_ = xOutQueue_.AllocTensor(); + QuantProcess(0); + xOutQueue_.EnQue(xOutTensor_); + + xOutTensor_ = xOutQueue_.DeQue(); + if (isShareExpertRank_) { + xOutFp32Tensor_ = xOutTensor_.template ReinterpretCast(); + DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + DataCopyPad(dynamicScalesOutGMTensor_[tokenIndex], xOutFp32Tensor_[axisH_ / sizeof(float)], + dataCopyParamsFloat); + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[tokenIndex * axisHCommu_], xOutTensor_, axisHCommu_); // 约束对齐 + } + DataCopy(expandXOutGMTensor_[tokenIndex * axisH_], xOutTensor_, axisH_); // 约束对齐 + } else { + DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu_], xOutTensor_, axisHCommu_); // 约束对齐 + } + xOutQueue_.FreeTensor(xOutTensor_); + } else { + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, xGMTensor_[tokenIndex * axisH_], axisH_); // 约束对齐 + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + if (isShareExpertRank_) { + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[tokenIndex * axisHCommu_], xTmpTensor_, axisHCommu_); + } + DataCopy(expandXOutGMTensor_[tokenIndex * axisHCommu_], xTmpTensor_, axisHCommu_); + } else { + DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu_], xTmpTensor_, axisHCommu_); // 约束对齐 + } + xQueue_.FreeTensor(xTmpTensor_); + } + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::SendToMoeExpert() +{ + uint32_t expertIdsCnt = axisBS_ * axisK_; + uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_; // 每个aiv需要发送的token数 + uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_; // 余数 + uint32_t startTokenId = sendTokenNum * aivId_; // 每个aiv发送时的起始rankid + if (aivId_ < remainderTokenNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendTokenNum += 1; + startTokenId += aivId_; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + GlobalTensor dstWinGMTensor; + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + int32_t dstExpertId = expertIdsTensor_(tokenIndex); + if (dstExpertId < 0) { + continue; + } + uint32_t tempRankId = dstExpertId / moeExpertNumPerRank_ + sharedExpertRankNum_; + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindAddrByRankId(COMM_EP_IDX, tempRankId) + + (expertPerSizeOnWin_ * + (epRankId_ * moeExpertNumPerRank_ + dstExpertId % moeExpertNumPerRank_)) + + hCommuSize_ * expertCountTensor_(tokenIndex)); // 计算地址偏移 + dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)rankGM); + if constexpr (DynamicQuant || StaticQuant) { + xInTensor_ = xInQueue_.AllocTensor(); + DataCopy(xInTensor_, xGMTensor_[tokenIndex / 
axisK_ * axisH_], axisH_); // 约束对齐 + xInQueue_.EnQue(xInTensor_); + xInTensor_ = xInQueue_.DeQue(); + xOutTensor_ = xOutQueue_.AllocTensor(); + uint32_t expertIndex = sharedExpertRankNum_ != 0 ? (dstExpertId + 1) : dstExpertId; + QuantProcess(expertIndex); + xOutQueue_.EnQue(xOutTensor_); + + xOutTensor_ = xOutQueue_.DeQue(); + DataCopy(dstWinGMTensor, xOutTensor_, axisHCommu_); // 约束对齐 + xOutQueue_.FreeTensor(xOutTensor_); + } else { + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, xGMTensor_[tokenIndex / axisK_ * axisH_], axisH_); // 约束对齐 + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + DataCopy(dstWinGMTensor, xTmpTensor_, axisHCommu_); // 约束对齐 + xQueue_.FreeTensor(xTmpTensor_); + } + } + if (aivId_ == (moeUsedAivNum_ - 1) && (!enableAivOpt_)) { + // 不启用循环优化时,这里才需要写出结果 + GlobalTensor expandIdxGMTensor; + expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t *)expandIdxOutGM_); + DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U}; + DataCopyPad(expandIdxGMTensor, expertCountTensor_, expertIdsCntParams); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::AlltoAllDispatch() +{ + uint32_t expertIdsCnt = axisBS_ * axisK_; + DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U}; + DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); + AscendC::TQueSync expertCntLocalSync; + expertCntLocalSync.SetFlag(0); + expertCntLocalSync.WaitFlag(0); + if (enableAivOpt_) { + LocalTensor tableInt16LocalTensor_ = tableLocalTensor_.template ReinterpretCast(); + Duplicate(tableInt16LocalTensor_, (int16_t)0, tableElemCount_ / 2); // 清零 + SyncFunc(); + for (int tokenIndex = 0; tokenIndex < expertIdsCnt; ++tokenIndex) { // 填表。默认为0,发送置1 + int expertId = expertIdsTensor_(tokenIndex); + if (expertId < 0) { + continue; + } + tableLocalTensor_((tokenIndex / axisK_ + 1) * moeExpertRankNumAligned_ + expertId) = 1; + } + pipe_barrier(PIPE_ALL); + + // 分核,确定每个核要处理的token + uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_; + uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_; + uint32_t startTokenId = sendTokenNum * aivId_; + if (aivId_ < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += aivId_; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + uint32_t startTokenRow = startTokenId / axisK_; + uint32_t endTokenRow = (endTokenId + axisK_ - 1) / axisK_; + + for (int row = 1; row <= axisBS_; ++row) { + Add(tableInt16LocalTensor_[row * moeExpertRankNumInt16Aligned_], + tableInt16LocalTensor_[row * moeExpertRankNumInt16Aligned_], + tableInt16LocalTensor_[(row - 1) * moeExpertRankNumInt16Aligned_], moeExpertRankNumInt16Aligned_); + pipe_barrier(PIPE_V); + } + + // 计算完成后,下标为的i的行为下标为i+1的token在远端的偏移,最后一行为总count + GlobalTensor expandIdxGMTensor; + if (aivId_ < moeUsedAivNum_) { + SyncFunc(); + for (int row = startTokenRow; row < endTokenRow; ++row) { + for (int expertIndex = 0; expertIndex < axisK_; ++expertIndex) { + int32_t expertId = expertIdsTensor_(row * axisK_ + expertIndex); + if (expertId < 0) { + continue; + } + expertCountTensor_(row * axisK_ + expertIndex) = + (int32_t)tableLocalTensor_(row * moeExpertRankNumAligned_ + expertId); + } + SyncFunc(); + expandIdxGMTensor.SetGlobalBuffer( + (__gm__ int32_t *)(expandIdxOutGM_ + row * axisK_ * sizeof(uint32_t))); + DataCopy(expandIdxGMTensor, expertCountTensor_[row * 
axisK_], axisK_); + } + } + + // 分核,确定每个核要set status的rank + uint32_t preTotalExpertNum = sharedExpertRankNum_ + moeExpertNum_; + uint32_t preSendExpertNum = preTotalExpertNum / aivNum_; + uint32_t preRemainderRankNum = preTotalExpertNum % aivNum_; + uint32_t preStartExpertId = preSendExpertNum * aivId_; + if (aivId_ < preRemainderRankNum) { + preSendExpertNum += 1; + preStartExpertId += aivId_; + } else { + preStartExpertId += preRemainderRankNum; + } + uint32_t preEndExpertId = preStartExpertId + preSendExpertNum; + preStartExpertId = preStartExpertId >= sharedExpertRankNum_ ? preStartExpertId : sharedExpertRankNum_; + + SyncFunc(); + for (int32_t tmpExpertId = preStartExpertId; tmpExpertId < preEndExpertId; ++tmpExpertId) { + statusTensor_(tmpExpertId * INT32_NUM_PER_BLOCK + 1) = + (int32_t)sendCountLocalTensor_(tmpExpertId - sharedExpertRankNum_); + } + } else { + for (uint32_t tokenIndex = 0; tokenIndex < expertIdsCnt; ++tokenIndex) { + // 防止越界,越界判断(expertId >= epWorldSize_) || (expertId < sharedExpertRankNum_) + int32_t expertId = expertIdsTensor_(tokenIndex) + sharedExpertRankNum_; + if (expertId < 0) { + continue; + } + expertCountTensor_(tokenIndex) = statusTensor_(expertId * INT32_NUM_PER_BLOCK + 1); + statusTensor_(expertId * INT32_NUM_PER_BLOCK + 1)++; + } + } + if (!isShareExpertRank_) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum_; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId_) * axisBS_ / sharedExpertRankNum_ - + (curSatatusExpId + epRankId_) * axisBS_ / sharedExpertRankNum_; + statusTensor_((curSatatusExpId)*INT32_NUM_PER_BLOCK + 1) = curExpertCnt; + } + } + if ((sharedExpertRankNum_ != 0) && (aivId_ >= moeUsedAivNum_)) { // 后面的核进行发给共享专家 + SendToSharedExpert(); + return; + } + SendToMoeExpert(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::SetStatus() +{ + pipe_barrier(PIPE_ALL); + SyncAll(); + totalExpertNum_ = sharedExpertRankNum_ + moeExpertNum_; + sendExpertNum_ = totalExpertNum_ / aivNum_; // 每个aiv需要处理的专家数 + uint32_t remainderRankNum = totalExpertNum_ % aivNum_; + startExpertId_ = sendExpertNum_ * aivId_; // + sharedExpertRankNum_, 每个aiv发送的起始rankid + if (aivId_ < remainderRankNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendExpertNum_ += 1; + startExpertId_ += aivId_; + } else { + startExpertId_ += remainderRankNum; + } + endExpertId_ = startExpertId_ + sendExpertNum_; + if (startExpertId_ >= totalExpertNum_) { // 多余的核return + return; + } + GlobalTensor rankGMTensor; + uint32_t offset = stateOffset_ * epRankId_; + for (uint32_t rankIndex = startExpertId_; rankIndex < endExpertId_; ++rankIndex) { + uint32_t dstRankId = rankIndex; + if (moeExpertNumPerRank_ > 1 && (rankIndex >= sharedExpertRankNum_)) { + dstRankId = ((rankIndex - sharedExpertRankNum_) / moeExpertNumPerRank_ + sharedExpertRankNum_); + offset = + (epRankId_ + (rankIndex - sharedExpertRankNum_) % moeExpertNumPerRank_ * epWorldSize_) * stateOffset_; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_EP_IDX, dstRankId) + offset); // 计算地址偏移 + rankGMTensor.SetGlobalBuffer((__gm__ int32_t *)rankGM); + DataCopy(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); // 8时数据大小,按32对齐拷贝 + } + SyncFunc(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::QuantProcess(uint32_t expertIndex) +{ + float dynamicScale = 0.0; + LocalTensor floatLocalTemp; + floatLocalTemp = receiveDataCastFloatBuf_.Get(); + Cast(floatLocalTemp, xInTensor_, RoundMode::CAST_NONE, axisH_); + xInQueue_.FreeTensor(xInTensor_); + 
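+    // Per-token quantization: optionally multiply by the smooth scale, take the row abs-max,
+    // scale by 127/max (dynamic quant), round to int8, and store the reciprocal scale right
+    // after the int8 payload so the receiving side can dequantize.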
pipe_barrier(PIPE_V); + if constexpr (IsSmoothScaleExist) { + if constexpr (DynamicQuant) { + SyncFunc(); // UB is reused; synchronize between loop iterations + } + DataCopy(smoothScalesTensor_, scalesGMTensor_[expertIndex * axisH_], axisH_); + SyncFunc(); + Mul(floatLocalTemp, floatLocalTemp, smoothScalesTensor_, axisH_); + pipe_barrier(PIPE_V); + } + if constexpr (DynamicQuant) { + LocalTensor floatLocalAbsTemp = smoothScalesBuf_.Get(); + rowMaxTensor_ = rowMaxBuf_.Get(); + Abs(floatLocalAbsTemp, floatLocalTemp, axisH_); + pipe_barrier(PIPE_V); + ReduceMax(rowMaxTensor_, floatLocalAbsTemp, floatLocalAbsTemp, axisH_, false); + SyncFunc(); + dynamicScale = float(127.0) / rowMaxTensor_.GetValue(0); + SyncFunc(); + Muls(floatLocalTemp, floatLocalTemp, dynamicScale, axisH_); + pipe_barrier(PIPE_V); + } + LocalTensor halfLocalTemp = floatLocalTemp.ReinterpretCast(); + LocalTensor int32LocalTemp = floatLocalTemp.ReinterpretCast(); + Cast(int32LocalTemp, floatLocalTemp, RoundMode::CAST_RINT, axisH_); + pipe_barrier(PIPE_V); + SetDeqScale((half)1.000000e+00f); + PipeBarrier(); + Cast(halfLocalTemp, int32LocalTemp, RoundMode::CAST_ROUND, axisH_); + pipe_barrier(PIPE_V); + Cast(xOutTensor_, halfLocalTemp, RoundMode::CAST_TRUNC, axisH_); + floatLocalTemp = xOutTensor_.template ReinterpretCast(); + floatLocalTemp.SetValue(axisH_ / sizeof(float), float(1.0) / dynamicScale); // int8->float32 +} + +template +__aicore__ inline void CamMoeDistributeDispatch::LocalSharedExpertCopyWindow( + uint32_t rankIndex, uint32_t tokenOffset, uint32_t currendTokenIndex, uint32_t &dynamicScalesLocalIdx) +{ + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, + windowInQuantTensor_[rankIndex * (expertPerSizeOnWin_ / sizeof(ExpandXOutType)) + + currendTokenIndex * axisHCommu_], + axisHCommu_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + if constexpr (DynamicQuant || StaticQuant) { + pipe_barrier(PIPE_ALL); + xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + dynamicScalesTensor_.SetValue(dynamicScalesLocalIdx++, xOutFp32Tensor_.GetValue(axisH_ / sizeof(float))); + pipe_barrier(PIPE_ALL); + } + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[tokenOffset * axisH_], xTmpTensor_, axisH_); + } + DataCopy(expandXOutGMTensor_[tokenOffset * axisH_], xTmpTensor_, axisH_); + xQueue_.FreeTensor(xTmpTensor_); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::WaitDispatch() +{ + uint32_t rscvStatusNum = isShareExpertRank_ ?
epWorldSize_ : recvWinBlockNum_; + uint32_t recStatusNumPerCore = rscvStatusNum / aivNum_; // number of experts each aiv handles + uint32_t remainderRankNum = rscvStatusNum % aivNum_; + uint32_t startStatusIndex = recStatusNumPerCore * aivId_; // + sharedExpertRankNum_, the starting rank id each aiv sends to + if (aivId_ < remainderRankNum) { // the first remainderRankNum aivs each send one extra rank's data + recStatusNumPerCore += 1; + startStatusIndex += aivId_; + } else { + startStatusIndex += remainderRankNum; + } + if (startStatusIndex >= rscvStatusNum) { + SyncAll(); + return; + } + LocalTensor gatherMaskOutTensor = gatherMaskOutBuf_.Get(); + LocalTensor gatherTmpTensor = scalarBuf_.GetWithOffset(UB_ALIGN / sizeof(uint32_t), 0); + gatherTmpTensor.SetValue(0, 1); + LocalTensor statusSumOutTensor = scalarBuf_.GetWithOffset(UB_ALIGN / sizeof(float), UB_ALIGN); + statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + uint32_t mask = 1; // parameters for GatherMask + Sum + uint64_t rsvdCnt = 0; + SumParams sumParams{1, recStatusNumPerCore, recStatusNumPerCore}; + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget_ * recStatusNumPerCore) - (float)0.5; + float maxTarget = (sumTarget_ * recStatusNumPerCore) + (float)0.5; + DataCopyParams intriParams{static_cast(recStatusNumPerCore), 1, + static_cast((recvWinBlockNum_ > 512) ? 7 : 15), 0}; // srcStride is 15 blocks (7 when recvWinBlockNum_ > 512) + SyncFunc(); + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset_ / sizeof(float)], + intriParams); + SyncFunc(); + GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + pipe_barrier(PIPE_V); + Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams); + SyncFunc(); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + SyncAll(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::GetCumSum(LocalTensor &inLocal, + LocalTensor &outLocal, + int32_t totalCount, + GM_ADDR gmOutputRecvCount) +{ + statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + DataCopyParams intriParams{static_cast(recvWinBlockNum_), 1, + static_cast((recvWinBlockNum_ > 512) ? 7 : 15), 0}; // srcStride is 15 blocks (7 when recvWinBlockNum_ > 512) + DataCopy(statusTensor_, windowInstatusTensor_, intriParams); + SyncFunc(); + if (isShareExpertRank_) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum_; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId_) * axisBS_ / sharedExpertRankNum_ - + (curSatatusExpId + epRankId_) * axisBS_ / sharedExpertRankNum_; + statusTensor_((curSatatusExpId)*INT32_NUM_PER_BLOCK + 1) = curExpertCnt; + } + } + outLocal = gatherMaskOutBuf_.Get(); // reuse this buffer + LocalTensor getTotalLocal = getTotalBuf_.Get(); + // gather the per-rank counts together with GatherMask + TBuf<> gatherTmpBuf; + TBuf<> workLocalBuf; + tpipe_->InitBuffer(gatherTmpBuf, sizeof(uint32_t) * recvWinBlockNum_ / 4); + LocalTensor gatherTmpTensor = gatherTmpBuf.Get(); + Duplicate(gatherTmpTensor, (uint32_t)33686018, recvWinBlockNum_ / 4); // 0000 0010 0000 0010 0000 0010 0000 0010 + PipeBarrier(); + uint32_t mask = recvWinBlockNum_ * 8; // 512 / 32 + uint64_t rsvdCnt = 0; + GatherMask(outLocal, inLocal, gatherTmpTensor, true, mask, {1, 1, 0, 0}, rsvdCnt); + AscendC::GlobalTensor recvCountTensor; + recvCountTensor.SetGlobalBuffer((__gm__ int32_t *)gmOutputRecvCount); + uint32_t localExpertNum = isShareExpertRank_ ?
1 : moeExpertNumPerRank_; + AscendC::DataCopyExtParams dataCopyParams = { + 1U, static_cast(localExpertNum * epWorldSize_ * sizeof(int32_t)), 0U, 0U, 0U}; + SyncFunc(); + AscendC::DataCopyPad(recvCountTensor, outLocal.ReinterpretCast(), dataCopyParams); + SyncFunc(); + // then accumulate with a cumulative sum, adding column by column + int typeSize = sizeof(int32_t); + int32_t elementsPerBlock = 32 / typeSize; + int32_t elementsPerRepeat = 256 / typeSize; + int32_t firstMaxRepeat = epWorldSize_; + int32_t iter1OutputCount = firstMaxRepeat; + int32_t iter1AlignEnd = ((iter1OutputCount + elementsPerBlock - 1) / elementsPerBlock) * elementsPerBlock; + int32_t finalWorkLocalNeedSize = iter1AlignEnd; + tpipe_->InitBuffer(workLocalBuf, finalWorkLocalNeedSize * sizeof(int32_t)); + LocalTensor workLocalTensor = workLocalBuf.Get(); + LocalTensor tmpFp32 = outLocal.ReinterpretCast(); + PipeBarrier(); + ReduceSum(getTotalLocal, tmpFp32, workLocalTensor, epWorldSize_); + totalCnt_ = getTotalLocal.ReinterpretCast().GetValue(0); + PipeBarrier(); + ReduceSum(tmpFp32, tmpFp32, workLocalTensor, totalCount); + PipeBarrier(); +} + +template +__aicore__ inline void +CamMoeDistributeDispatch::CreateZeroTensor(LocalTensor &outLocal) +{ + TBuf<> outBuf; + tpipe_->InitBuffer(outBuf, UB_ALIGN); + outLocal = outBuf.Get(); + for (uint32_t i = 0; i < 2; i++) { + outLocal.SetValue(i, 0); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::LocalWindowCopy() +{ + uint32_t totalMoeExpert = 0; + LocalTensor outCountLocal; + if (isShareExpertRank_) { + totalMoeExpert = epWorldSize_; + } else { + totalMoeExpert = epWorldSize_ * moeExpertNumPerRank_; + } + sendExpertNum_ = totalMoeExpert / aivNum_; // number of experts each aiv handles + uint32_t remainderRankNum = totalMoeExpert % aivNum_; + startExpertId_ = sendExpertNum_ * aivId_; // + sharedExpertRankNum_, the starting rank id each aiv sends to + if (aivId_ < remainderRankNum) { // the first remainderRankNum aivs each send one extra rank's data + sendExpertNum_ += 1; + startExpertId_ += aivId_; + } else { + startExpertId_ += remainderRankNum; + } + endExpertId_ = startExpertId_ + sendExpertNum_; + if (startExpertId_ >= totalMoeExpert) { // extra cores return + return; + } + GetCumSum(statusTensor_, outCountLocal, startExpertId_ + 1, outputRecvCountGM_); + uint32_t index = 0; + uint32_t beginIdx = 0; + DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + for (uint32_t index = startExpertId_; index < endExpertId_; index++) { + uint32_t i = index - startExpertId_; + if (i > 0) { + outCountLocal.SetValue(i, outCountLocal.GetValue(i - 1) + outCountLocal.GetValue(index)); + } + uint32_t count = statusTensor_.GetValue(index * INT32_NUM_PER_BLOCK + 1); + beginIdx = outCountLocal.GetValue(i) - count; + if constexpr (IsNeedAllgater) { + gatherCount_ += count; + } + if (i == 0) { + preCnt_ = beginIdx; + } + if (isShareExpertRank_) { + if (index < sharedExpertRankNum_) { // on shared-expert ranks the local data is laid out first; only epRecvCnt needs counting, no copy-out is required + beginIdx += count; + continue; + } + } + uint32_t winOffset = index; + if (!isShareExpertRank_) { + if (moeExpertNumPerRank_ > 1) { + winOffset = + index % epWorldSize_ * moeExpertNumPerRank_ + index / epWorldSize_; // convert to the layout offset within the data region + } + } + GM_ADDR wAddr = (__gm__ uint8_t *)(windowGM_) + winOffset * expertPerSizeOnWin_; + GlobalTensor tokGlobal; + GlobalTensor expandXOutGlobal; + for (uint32_t j = 0; j < count; j++) { + tokGlobal.SetGlobalBuffer((__gm__ ExpandXOutType *)(wAddr + j * hCommuSize_)); + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, tokGlobal, axisHCommu_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + if constexpr
(DynamicQuant || StaticQuant) { + pipe_barrier(PIPE_ALL); + xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + DataCopyPad(dynamicScalesOutGMTensor_[beginIdx + j], xOutFp32Tensor_[axisH_ / sizeof(float)], + dataCopyParamsFloat); + pipe_barrier(PIPE_ALL); + } + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[(beginIdx + j) * axisHCommu_], xTmpTensor_, axisHCommu_); + } + expandXOutGlobal.SetGlobalBuffer((__gm__ ExpandXOutType *)(expandXOutGM_) + (beginIdx + j) * axisH_, + axisH_); + DataCopy(expandXOutGlobal, xTmpTensor_, axisH_); + xQueue_.FreeTensor(xTmpTensor_); + } + beginIdx += count; + } + if constexpr (!IsNeedAllgater) { + totalCnt_ = beginIdx; + } + lastCore_ = MIN(totalMoeExpert, aivNum_) - 1; + if constexpr (IsNeedAllgater) { + DataCopyExtParams dataCopyOutParams = {1U, static_cast(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U}; + DataCopyPad(winTpEpCntGMTensor_[startExpertId_], outCountLocal, dataCopyOutParams); + } + DataCopyExtParams dataCopyOutParams = {1U, static_cast(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U}; + GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_)); + DataCopyPad(sendCountsGlobal[startExpertId_], outCountLocal, dataCopyOutParams); + PipeBarrier(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::AllGatherSetStatusAndWait() +{ + pipe_barrier(PIPE_ALL); + if (startExpertId_ >= totalExpertNum_) { + return; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_TP_IDX, tpGatherRankId_) + stateOffset_ * aivId_); + GlobalTensor tpwindowInstatusFp32Tensor_; + tpwindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)(rankGM)); + statusTensor_(aivId_ * INT32_NUM_PER_BLOCK + 1) = gatherCount_; + statusTensor_(aivId_ * INT32_NUM_PER_BLOCK + 2) = preCnt_; + LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + statusFp32Tensor_(aivId_ * 8) = sumTarget_; + SyncFunc(); + DataCopy(tpwindowInstatusFp32Tensor_, statusFp32Tensor_[aivId_ * 8], + UB_ALIGN); // 12 bytes of payload, copied with 32-byte alignment + SyncFunc(); + float sumOfFlag = static_cast(-1.0); + rankGM = + (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_TP_IDX, tpRankId_) + stateOffset_ * aivId_); // compute the address offset + tpwindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)(rankGM)); + while (sumOfFlag != sumTarget_) { + DataCopy(statusFp32Tensor_, tpwindowInstatusFp32Tensor_, UB_ALIGN); + SyncFunc(); + sumOfFlag = statusFp32Tensor_.GetValue(0); + SyncFunc(); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::AllgatherProcessOut() +{ + if (startExpertId_ >= totalExpertNum_) { + return; + } + // get the number of tokens that need to be allgathered + GlobalTensor tpwindowInstatusFp32Tensor_; + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_TP_IDX, tpRankId_) + stateOffset_ * aivId_); + tpwindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)rankGM); + LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + DataCopy(statusFp32Tensor_, tpwindowInstatusFp32Tensor_, UB_ALIGN); + SyncFunc(); + uint32_t coreGatherCount = statusFp32Tensor_.ReinterpretCast().GetValue(1); + uint32_t preCount = statusFp32Tensor_.ReinterpretCast().GetValue(2); + gatherCount_ = coreGatherCount; + preCnt_ = preCount; + GlobalTensor sendCountsGlobal; + GlobalTensor tpGlobal; + // copy the epRcvCnt sent from the other rank in the TP domain + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_)); + tpGlobal.SetGlobalBuffer((__gm__ int32_t *)(tpLocalStatusWindowGM_ + TP_STATE_SIZE)); + DataCopyExtParams dataCopyParams =
{1U, static_cast(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U}; + DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + tpTmpTensor_ = xQueue_.AllocTensor(); + DataCopyPad(tpTmpTensor_, tpGlobal[startExpertId_], dataCopyParams, copyPadParams); + xQueue_.EnQue(tpTmpTensor_); + tpTmpTensor_ = xQueue_.DeQue(); + DataCopyPad(sendCountsGlobal[epWorldSize_ + startExpertId_], tpTmpTensor_, dataCopyParams); + xQueue_.FreeTensor(tpTmpTensor_); + if (coreGatherCount == 0) { + return; + } + // the output start is offset by this rank's own data + GlobalTensor tokGlobal; + GlobalTensor expandXOutGlobal; + DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + for (uint32_t i = 0; i < coreGatherCount; i++) { + tokGlobal.SetGlobalBuffer((__gm__ ExpandXOutType *)(tpLocalWindowGM_ + (preCount + i) * hCommuSize_)); + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, tokGlobal, axisHCommu_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + expandXOutGlobal.SetGlobalBuffer( + (__gm__ ExpandXOutType *)(expandXOutGM_ + (preCount + totalCnt_ + i) * hOutSize_)); + DataCopy(expandXOutGlobal, xTmpTensor_, axisH_); + if constexpr (StaticQuant || DynamicQuant) { + xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + DataCopyPad(dynamicScalesOutGMTensor_[preCount + totalCnt_ + i], xOutFp32Tensor_[axisH_ / sizeof(float)], + dataCopyParamsFloat); + } + xQueue_.FreeTensor(xTmpTensor_); + } +} + +// update the tokenNumsOut tensor on ranks hosting multiple MoE experts +template +__aicore__ inline void CamMoeDistributeDispatch::UpdataMultiMoeTokenNumsOut() +{ + uint32_t tokenSums = 0; + GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_)); + for (uint32_t localMoeIndex = 0; localMoeIndex < moeExpertNumPerRank_; ++localMoeIndex) { + if (localMoeIndex == 0) { + DataCacheCleanAndInvalid( + sendCountsGlobal[epWorldSize_ - 1]); + uint32_t firstMoeCnt = sendCountsGlobal.GetValue(epWorldSize_ - 1); + tokenSums = firstMoeCnt + gatherCount_; + expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums); + DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_[localMoeIndex]); + } else { + uint32_t preIndex = epWorldSize_ * (localMoeIndex - 1) + epWorldSize_ - 1; + uint32_t curIndex = epWorldSize_ * localMoeIndex + epWorldSize_ - 1; + DataCacheCleanAndInvalid( + sendCountsGlobal[preIndex]); + DataCacheCleanAndInvalid( + sendCountsGlobal[curIndex]); + uint32_t preMoeIndexCnt = sendCountsGlobal.GetValue(preIndex); + uint32_t curMoeIndexCnt = sendCountsGlobal.GetValue(curIndex); + tokenSums = + ((expertTokenNumsType_ == 0) ?
tokenSums : 0) + (curMoeIndexCnt - preMoeIndexCnt) + gatherCount_; + expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums); + DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_[localMoeIndex]); + } + } +} + +// update the tokenNumsOut tensor +template +__aicore__ inline void CamMoeDistributeDispatch::UpdataTokenNumsOut() +{ + // the last core performs the update; for MoE experts only the last core has computed the complete sendCountsGlobal + if (!isShareExpertRank_ && moeExpertNumPerRank_ > 1) { + SyncAll(); + if (aivId_ != lastCore_) return; + SyncFunc(); + UpdataMultiMoeTokenNumsOut(); + } else { + if (aivId_ != lastCore_) return; + uint32_t tokenNum = 0; + // the total MoE expert token count was computed inside GetCumSum + tokenNum = totalCnt_; + if constexpr (IsNeedAllgater) { + tokenNum += preCnt_; + tokenNum += gatherCount_; + } + expertTokenNumsOutGMTensor_.SetValue(0, tokenNum); + DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_); + } + // total tokens = tokens copied in from other experts + tokens obtained from the other rank via allgather + if constexpr (IsNeedAllgater) { + GlobalTensor sendTpCountsGlobal; + sendTpCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendTpCountOutGM_)); + sendTpCountsGlobal.SetValue(tpRankId_, totalCnt_); + sendTpCountsGlobal.SetValue(tpGatherRankId_, gatherCount_ + preCnt_); + DataCacheCleanAndInvalid( + sendTpCountsGlobal); // the current tpId can only be 0 or 1, so the cache only needs to be flushed once + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::Process() +{ + if ASCEND_IS_AIV { // handled entirely by aiv cores + AlltoAllDispatch(); + SetStatus(); + WaitDispatch(); + LocalWindowCopy(); + if constexpr (IsNeedAllgater) { + AllGatherSetStatusAndWait(); + AllgatherProcessOut(); + } + UpdataTokenNumsOut(); + } +} + +} // namespace MoeDistributeDispatchImpl +#endif // CAM_MOE_DISTRIBUTE_DISPATCH_H diff --git a/csrc/utils/op_kernel/operator/catlass/act/act.hpp b/csrc/utils/op_kernel/operator/catlass/act/act.hpp new file mode 100644 index 00000000000..2e5fab8bac2 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/act.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_ACT_HPP +#define ACT_ACT_HPP + +#include + +#include "../act/detail/alignment.hpp" +#include "../act/detail/dependent_false.hpp" +#include "../act/detail/macros.hpp" + +namespace Act { + +constexpr uint32_t BYTE_PER_C0 = 32; +constexpr uint32_t C0_NUM_PER_FRACTAL = 16; +constexpr uint32_t BYTE_PER_FRACTAL = BYTE_PER_C0 * C0_NUM_PER_FRACTAL; + +constexpr uint32_t BYTE_PER_BLK = 32; +constexpr uint32_t BLK_NUM_PER_VECTOR_FRACTAL = 8; +constexpr uint32_t BYTE_PER_VECTOR_FRACTAL = BYTE_PER_BLK * BLK_NUM_PER_VECTOR_FRACTAL; + +constexpr uint64_t L2_OFFSET = 0; +constexpr uint32_t STRIDE_LIMIT = 65536; + +} // namespace Act + +#endif // ACT_ACT_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/arch/arch.hpp b/csrc/utils/op_kernel/operator/catlass/act/arch/arch.hpp new file mode 100644 index 00000000000..f1bb8727acb --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/arch/arch.hpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_ARCH_ARCH_HPP +#define ACT_ARCH_ARCH_HPP + +namespace Act::Arch { + +struct AtlasA2 { + static constexpr uint32_t BIAS_SIZE = 1024; + static constexpr uint32_t FIXBUF_SIZE = 7 * 1024; + static constexpr uint32_t UB_SIZE = 192 * 1024; + static constexpr uint32_t L1_SIZE = 512 * 1024; + static constexpr uint32_t L0A_SIZE = 64 * 1024; + static constexpr uint32_t L0B_SIZE = 64 * 1024; + static constexpr uint32_t L0C_SIZE = 128 * 1024; +}; + +struct PositionGM { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::GM; +}; + +struct PositionL1 { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A1; +}; + +struct PositionL0A { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A2; +}; + +struct PositionL0B { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::B2; +}; + +struct PositionL0C { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::CO1; +}; + +struct PositionUB { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::VECCALC; +}; + +} // namespace Act::Arch + +#endif // ACT_ARCH_ARCH_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/arch/cross_core_sync.hpp b/csrc/utils/op_kernel/operator/catlass/act/arch/cross_core_sync.hpp new file mode 100644 index 00000000000..72099c4e481 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/arch/cross_core_sync.hpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_ARCH_CROSS_CORE_SYNC_HPP +#define ACT_ARCH_CROSS_CORE_SYNC_HPP + +#include "../../act/act.hpp" + +namespace Act::Arch { + +constexpr uint32_t MAX_REVERSE_DEPTH = 16; + +using FlagID = uint16_t; +constexpr FlagID AIV_INTER_BLOCK_BARRIER = 8; +constexpr FlagID AIC_INTER_BLOCK_BARRIER = 9; +constexpr FlagID AIV_INTER_SUBBLOCK_BARRIER = 10; +constexpr FlagID FFTS_MAX_FLAG = 7; + +struct CrossCoreFlag { + ACT_DEVICE + CrossCoreFlag() : id(0) {} + + ACT_DEVICE + CrossCoreFlag(FlagID id) : id(id) {} + + FlagID id; +}; + +template +struct CrossCoreFlagWithReverse { + ACT_DEVICE + CrossCoreFlagWithReverse() : id(0), reverseId(0) {} + + ACT_DEVICE + CrossCoreFlagWithReverse(FlagID id, FlagID reverseId) : id(id), reverseId(reverseId) {} + + FlagID id; + FlagID reverseId; + uint32_t count{0}; +}; + +template +struct BarrierFlag { + static_assert(MODE != MODE, + "Unsupporteded cross core barrier flag, can not " + "find the specialization."); +}; + +template <> +struct BarrierFlag<0x0, AscendC::AIV> { + static constexpr FlagID ID = AIV_INTER_BLOCK_BARRIER; +}; + +template <> +struct BarrierFlag<0x0, AscendC::AIC> { + static constexpr FlagID ID = AIC_INTER_BLOCK_BARRIER; +}; + +template <> +struct BarrierFlag<0x1, AscendC::AIV> { + static constexpr FlagID ID = AIV_INTER_SUBBLOCK_BARRIER; +}; + +template +ACT_DEVICE void CrossCoreBarrier() +{ + constexpr FlagID flagId = BarrierFlag::ID; + AscendC::CrossCoreSetFlag(flagId); + AscendC::CrossCoreWaitFlag(flagId); +} + +template +ACT_DEVICE void CrossCoreSetFlag(CrossCoreFlag &flag) +{ + AscendC::CrossCoreSetFlag(flag.id); +} + +ACT_DEVICE +void CrossCoreWaitFlag(CrossCoreFlag &flag) +{ + AscendC::CrossCoreWaitFlag(flag.id); +} + +template +ACT_DEVICE void CrossCoreSetFlagWithReverse(CrossCoreFlagWithReverse &flag) +{ + AscendC::CrossCoreSetFlag(flag.id); + if (++flag.count >= REVERSE_DEPTH) { + AscendC::CrossCoreWaitFlag(flag.reverseId); + flag.count = 0; + } +} + +template +ACT_DEVICE void CrossCoreWaitFlagWithReverse(CrossCoreFlagWithReverse &flag) +{ + AscendC::CrossCoreWaitFlag(flag.id); + if (++flag.count >= REVERSE_DEPTH) { + AscendC::CrossCoreSetFlag(flag.reverseId); + flag.count = 0; + } +} + +} // namespace Act::Arch + +#endif // ACT_ARCH_CROSS_CORE_SYNC_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/arch/local_tensor_buffer.hpp b/csrc/utils/op_kernel/operator/catlass/act/arch/local_tensor_buffer.hpp new file mode 100644 index 00000000000..5208153f5e6 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/arch/local_tensor_buffer.hpp @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef INCLUDE_ACT_ARCH_MEMORY_H +#define INCLUDE_ACT_ARCH_MEMORY_H + +#include "../../act/act.hpp" +#include "../../act/arch/arch.hpp" + +namespace Act::Arch { + +struct LocalTensorBufferBase { +public: + template + ACT_DEVICE AscendC::LocalTensor GetBufferByByte(const uint32_t offset) const + { + return tensor[offset].template ReinterpretCast(); + } + +protected: + ACT_DEVICE + LocalTensorBufferBase() = default; + + AscendC::LocalTensor tensor; +}; + +template +struct LocalTensorBuffer { + static_assert(DEPENDENT_FALSE, "Unsupporteded local tensor buffer, can not find the specialization."); +}; + +/// Partial specialization for TPosition::A1 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::A1; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufA1; + GetTPipePtr()->InitBuffer(tbufA1, ArchTag::L1_SIZE); + tensor = tbufA1.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::A2 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::A2; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufA2; + GetTPipePtr()->InitBuffer(tbufA2, ArchTag::L0A_SIZE); + tensor = tbufA2.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::B1 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::B1; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufB1; + GetTPipePtr()->InitBuffer(tbufB1, ArchTag::L1_SIZE); + tensor = tbufB1.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for AtlasA2, TPosition::B2 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::B2; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufB2; + GetTPipePtr()->InitBuffer(tbufB2, ArchTag::L0B_SIZE); + tensor = tbufB2.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for AtlasA2, TPosition::C1 +template <> +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + using ArchTag = Arch::AtlasA2; + static constexpr AscendC::TPosition Position = AscendC::TPosition::C1; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufC1; + GetTPipePtr()->InitBuffer(tbufC1, ArchTag::L1_SIZE); + tensor = tbufC1.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for AtlasA2, TPosition::C2 +template <> +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + using ArchTag = Arch::AtlasA2; + static constexpr AscendC::TPosition Position = AscendC::TPosition::C2; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufC2; + GetTPipePtr()->InitBuffer(tbufC2, ArchTag::BIAS_SIZE); + tensor = tbufC2.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::CO1 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::CO1; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufCO1; + GetTPipePtr()->InitBuffer(tbufCO1, ArchTag::L0C_SIZE); + tensor = tbufCO1.Get(); + } +}; + 
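// Usage sketch (illustrative only; the offsets and types below are hypothetical): each
// LocalTensorBuffer specialization claims the whole physical buffer of its TPosition once, and
// callers then carve typed views out of it by byte offset via GetBufferByByte. Offsets should
// typically stay 32-byte aligned and within the ArchTag buffer sizes (e.g. AtlasA2::UB_SIZE).
//
//   Act::Arch::LocalTensorBuffer<Act::Arch::AtlasA2, AscendC::TPosition::VECIN> ubBuf;
//   AscendC::LocalTensor<half> tileA = ubBuf.GetBufferByByte<half>(0);
//   AscendC::LocalTensor<float> tileB = ubBuf.GetBufferByByte<float>(64 * 1024);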
+/////////////////////////////////////////////////////////// + +/// Partial specialization for AtlasA2, TPosition::C2PIPE2GM +template <> +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + using ArchTag = Arch::AtlasA2; + static constexpr AscendC::TPosition Position = AscendC::TPosition::C2PIPE2GM; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufC2PIPE2GM; + GetTPipePtr()->InitBuffer(tbufC2PIPE2GM, ArchTag::FIXBUF_SIZE); + tensor = tbufC2PIPE2GM.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::VECIN +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::VECIN; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufVECIN; + GetTPipePtr()->InitBuffer(tbufVECIN, ArchTag::UB_SIZE); + tensor = tbufVECIN.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::VECOUT +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::VECOUT; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufVECOUT; + GetTPipePtr()->InitBuffer(tbufVECOUT, ArchTag::UB_SIZE); + tensor = tbufVECOUT.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::VECCALC +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::VECCALC; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufVECCALC; + GetTPipePtr()->InitBuffer(tbufVECCALC, ArchTag::UB_SIZE); + tensor = tbufVECCALC.Get(); + } +}; + +} // namespace Act::Arch + +#endif // INCLUDE_ACT_ARCH_MEMORY_H diff --git a/csrc/utils/op_kernel/operator/catlass/act/arch/resource.hpp b/csrc/utils/op_kernel/operator/catlass/act/arch/resource.hpp new file mode 100644 index 00000000000..713679810b8 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/arch/resource.hpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef INCLUDE_ACT_ARCH_RESOURCE_HPP +#define INCLUDE_ACT_ARCH_RESOURCE_HPP + +#include "../../act/act.hpp" +#include "../../act/arch/local_tensor_buffer.hpp" + +namespace Act::Arch { + +template +struct Resource { +public: + AscendC::TPipe pipe; + + LocalTensorBuffer l1Buf; + LocalTensorBuffer l0ABuf; + LocalTensorBuffer l0BBuf; + LocalTensorBuffer l0CBuf; + LocalTensorBuffer ubBuf; + + ACT_DEVICE + Resource() + { + // The initialization of AscendC::Tpipe will insert some synchronization + // interfaces, which may conflict with the usage by users. Therefore, the + // "destroy" interface is used for releasing. 
+ pipe.Destroy(); + } +}; + +} // namespace Act::Arch + +#endif // INCLUDE_ACT_ARCH_RESOURCE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/coord.hpp b/csrc/utils/op_kernel/operator/catlass/act/coord.hpp new file mode 100644 index 00000000000..5faf5be6640 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/coord.hpp @@ -0,0 +1,311 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_COORD_HPP +#define ACT_COORD_HPP + +#include "../act/act.hpp" + +namespace Act { + +/// Statically-sized array specifying Coords within a tensor +template +struct Coord { +public: + // Number of elements in Coord + static const int RANK = RANK_; + + // Index typen used to store elements + using Index = Index_; + + // Type used to represent linear offsets + using LongIndex = LongIndex_; + + // Default ctor initializes uniformly + ACT_HOST_DEVICE constexpr explicit Coord(Index value = Index(0)) + { + for (int i = 0; i < RANK; ++i) { + idx[i] = value; + } + } + + // Constructs from an array of integers + ACT_HOST_DEVICE constexpr Coord(Index const (&idx_)[RANK]) + { + for (int i = 0; i < RANK; ++i) { + idx[i] = idx_[i]; + } + } + + // Constructs from an array of integers + ACT_HOST_DEVICE + int Argmin() const + { + int i = 0; + for (int j = 1; j < RANK; ++j) { + if (idx[j] < idx[i]) { + i = j; + } + } + return i; + } + + // Returns the index of the dimension with greatest value + ACT_HOST_DEVICE + int Argmax() const + { + int i = 0; + for (int j = 1; j < RANK; ++j) { + if (idx[j] > idx[i]) { + i = j; + } + } + return i; + } + + // Returns true if Coord is non-zero + ACT_HOST_DEVICE + explicit operator bool() const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i]) { + return true; + } + } + return false; + } + + // Return true if Coord is uniformly zero. 
+ ACT_HOST_DEVICE + bool operator!() const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i]) { + return false; + } + } + return true; + } + + // Element-wise addition + ACT_HOST_DEVICE + Coord operator+(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; ++i) { + c.idx[i] = idx[i] + b.idx[i]; + } + return c; + } + + // Add a scalar to each element + ACT_HOST_DEVICE + Coord operator+(const Index val) const + { + Coord c; + for (int i = 0; i < RANK; ++i) { + c.idx[i] = idx[i] + val; + } + return c; + } + + // Element-wise subtraction + ACT_HOST_DEVICE + Coord operator-(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] - b.idx[i]; + } + return c; + } + + // Subtract a scalar from each element + ACT_HOST_DEVICE + Coord operator-(Index const val) const + { + Coord c; + for (int i = 0; i < RANK; ++i) { + c.idx[i] = idx[i] - val; + } + return c; + } + + // Element-wise multiply + ACT_HOST_DEVICE + Coord operator*(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] * b.idx[i]; + } + return c; + } + + // Element-wise division + ACT_HOST_DEVICE + Coord operator/(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] / b.idx[i]; + } + return c; + } + + // Element-wise mod + ACT_HOST_DEVICE + Coord operator%(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] % b.idx[i]; + } + return c; + } + + // In-place addition + ACT_HOST_DEVICE + Coord &operator+=(Coord const &b) + { + for (int i = 0; i < RANK; ++i) { + idx[i] += b.idx[i]; + } + return *this; + } + + // In-place equal + ACT_HOST_DEVICE + bool operator==(Coord const &b) const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i] != b.idx[i]) { + return false; + } + } + return true; + } + + // In-place equal + ACT_HOST_DEVICE + bool operator==(Index const val) const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i] != val) { + return false; + } + } + return true; + } + + // Member access operator + ACT_HOST_DEVICE + Index &operator[](int dim) + { + return idx[dim]; + } + + // Member access operator + ACT_HOST_DEVICE + Index const &operator[](int dim) const + { + return idx[dim]; + } + + // Gets the index of a given Coord element + template + ACT_HOST_DEVICE Index &At() + { + return idx[DIM]; + } + + // Access via index; may limit unrolling potential + ACT_HOST_DEVICE + Index &At(int dim) + { + return idx[dim]; + } + + // Gets the index of a given Coord element + template + ACT_HOST_DEVICE Index const &At() const + { + return idx[DIM]; + } + + // Access via index; may limit unrolling potential + ACT_HOST_DEVICE + Index const &At(int dim) const + { + return idx[dim]; + } + + template + ACT_HOST_DEVICE auto GetCoordByAxis() const + { + Index idx_[sizeof...(Is)]{idx[Is]...}; + return Coord{idx_}; + } + + ACT_HOST_DEVICE + static Coord Min(Coord const &a, Coord const &b) + { + Coord res; + for (int i = 0; i < RANK; ++i) { + res[i] = a[i] < b[i] ? 
a[i] : b[i]; + } + return res; + } + +private: + // Indices + Index idx[RANK]; +}; + +// Helper to make a 1-element coordinate +template +ACT_HOST_DEVICE constexpr Coord<1, T> MakeCoord(T dim0) +{ + T values[1] = {dim0}; + return Coord<1, T>(values); +} + +/// Helper to make a 2-element coordinate +template +ACT_HOST_DEVICE constexpr Coord<2, T> MakeCoord(T dim0, T dim1) +{ + T values[2] = {dim0, dim1}; + return Coord<2, T>(values); +} + +/// Helper to make a 3-element coordinate +template +ACT_HOST_DEVICE constexpr Coord<3, T> MakeCoord(T dim0, T dim1, T dim2) +{ + T values[3] = {dim0, dim1, dim2}; + return Coord<3, T>(values); +} + +/// Helper to make a 4-element coordinate +template +ACT_HOST_DEVICE constexpr Coord<4, T> MakeCoord(T dim0, T dim1, T dim2, T dim3) +{ + T values[4] = {dim0, dim1, dim2, dim3}; + return Coord<4, T>(values); +} + +} // namespace Act + +#endif // ACT_COORD_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/alignment.hpp b/csrc/utils/op_kernel/operator/catlass/act/detail/alignment.hpp new file mode 100644 index 00000000000..db40e7ba760 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/detail/alignment.hpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_ALIGNMENT_HPP +#define ACT_ALIGNMENT_HPP + +#include "../../act/detail/macros.hpp" + +template +ACT_HOST_DEVICE constexpr T RoundUp(const T &val) +{ + static_assert(ALIGN != 0, "ALIGN must not be 0"); + return (val + ALIGN - 1) / ALIGN * ALIGN; +} + +template +ACT_HOST_DEVICE constexpr T RoundUp(const T &val, const T align) +{ + return (val + align - 1) / align * align; +} + +template +ACT_HOST_DEVICE constexpr T RoundDown(const T val) +{ + static_assert(ALIGN != 0, "ALIGN must not be 0"); + return val / ALIGN * ALIGN; +} + +template +ACT_HOST_DEVICE constexpr T RoundDown(const T val, const T align) +{ + return val / align * align; +} + +template +ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend) +{ + static_assert(DIVISOP != 0, "DIVISOP must not be 0"); + return (dividend + DIVISOP - 1) / DIVISOP; +} + +template +ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend, const T divisor) +{ + return (dividend + divisor - 1) / divisor; +} + +#endif // ACT_ALIGNMENT_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/callback.hpp b/csrc/utils/op_kernel/operator/catlass/act/detail/callback.hpp new file mode 100644 index 00000000000..7475213c547 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/detail/callback.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_DETAIL_CALLBACK_HPP +#define ACT_DETAIL_CALLBACK_HPP + +#include "../../act/detail/macros.hpp" + +/// @brief Callback is an alternative to std::function, providing a +/// general carrier of callable structure with no parameters and no return +/// value. Compared with function pointers of type void (*)(), Callback can +/// carry lambda expressions with captures, and does not need to pay attention +/// to the captured content. It should be noted that Callback itself does not +/// store the callable structure it carries like std::function, so +/// it is necessary to ensure that it is used within the life cycle of the +/// callable structure. +struct Callback { + void const *func{nullptr}; + void (*caller)(void const *){nullptr}; + + Callback() = default; + + ACT_DEVICE + void operator()() const + { + if (func) { + caller(func); + } + } + + ACT_DEVICE + operator bool() const + { + return func != nullptr; + } +}; + +template +ACT_DEVICE static void FuncWrapper(void const *func) +{ + (*static_cast(func))(); +} + +// Use this to make a callback +template +ACT_DEVICE Callback MakeCallback(Func *func) +{ + Callback callback; + callback.func = func; + callback.caller = &FuncWrapper; + return callback; +} + +#endif // ACT_DETAIL_CALLBACK_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/dependent_false.hpp b/csrc/utils/op_kernel/operator/catlass/act/detail/dependent_false.hpp new file mode 100644 index 00000000000..c9985a05d60 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/detail/dependent_false.hpp @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_DETAIL_DEPENDENT_FALSE_HPP +#define ACT_DETAIL_DEPENDENT_FALSE_HPP + +template +constexpr bool DEPENDENT_BOOL_VALUE = VALUE; + +template +constexpr bool DEPENDENT_FALSE = DEPENDENT_BOOL_VALUE; + +#endif // ACT_DETAIL_DEPENDENT_FALSE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/macros.hpp b/csrc/utils/op_kernel/operator/catlass/act/detail/macros.hpp new file mode 100644 index 00000000000..a2825344653 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/detail/macros.hpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_DETAIL_MACROS_HPP +#define ACT_DETAIL_MACROS_HPP + +#define ACT_DEVICE __forceinline__[aicore] +#define ACT_HOST_DEVICE __forceinline__[host, aicore] +#define ACT_GLOBAL __global__[aicore] + +#endif // ACT_DETAIL_MACROS_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/tag_to_layout.hpp b/csrc/utils/op_kernel/operator/catlass/act/detail/tag_to_layout.hpp new file mode 100644 index 00000000000..033a4ee4872 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/detail/tag_to_layout.hpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_DETAIL_TAG_TO_LAYOUT_HPP +#define ACT_DETAIL_TAG_TO_LAYOUT_HPP + +#include "../../act/layout/layout.hpp" +#include "../../tla/layout.hpp" + +using namespace tla; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace Act::detail { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For each Act::layout, provides its corresponding tla layout types +template +struct TagToLayout { + using type = LayoutTag; +}; + +template +struct TagToLayout { + using type = Layout, Stride>, Shape>; +}; + +template +struct TagToLayout { + using type = Layout, Stride, int64_t>, Shape>; +}; + +template +struct TagToLayout { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + using type = Layout, uint32_t>, Shape, uint32_t>>, + Stride, Int>, Stride, int64_t>>, + Shape>; +}; + +template +struct TagToLayout { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + using type = Layout, uint32_t>, Shape, uint32_t>>, + Stride, int64_t>, Stride, Int>>, + Shape>; +}; + +template +struct TagToLayout { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + using type = Layout, uint32_t>, Shape, uint32_t>>, + Stride, int64_t>, Stride, Int>>, + Shape>; +}; + +// Convenience aliases +template +using TagToLayout_t = typename TagToLayout::type; + +constexpr uint32_t ELE_NUM_PER_FRACTAL_L0C = 256; +using LayoutL0C = Layout, uint32_t>, Shape, uint32_t>>, + Stride, Int>, Stride, int64_t>>, + Shape>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Act::detail + +#endif // ACT_DETAIL_TAG_TO_LAYOUT_HPP diff --git 
a/csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue.hpp new file mode 100644 index 00000000000..bb7a6ac68b5 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP +#define ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Block { + +template +class BlockEpilogue +{ + static_assert(DEPENDENT_FALSE, "Could not find an epilogue specialization"); +}; + +} // namespace Act::Epilogue::Block + +#include "../../../act/epilogue/block/block_epilogue_per_token_dequant.hpp" +#endif // ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue_per_token_dequant.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue_per_token_dequant.hpp new file mode 100644 index 00000000000..b2a41ca66e7 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue_per_token_dequant.hpp @@ -0,0 +1,763 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP +#define ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP + +#include "../../../../cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h" +#include "../../../act/act.hpp" +#include "../../../act/arch/resource.hpp" +#include "../../../act/detail/callback.hpp" +#include "../../../act/epilogue/dispatch_policy.hpp" +#include "../../../act/gemm_coord.hpp" +#include "../../../act/layout/layout.hpp" +#include "../../../act/matrix_coord.hpp" + +#define ENABLE_EP_SEND_COUNT_HASH 0 + +namespace Act::Epilogue::Block { + +template +class BlockEpilogue, CType_, ScaleType_, PerTokenScaleType_, + DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, + EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = typename ScaleType_::Element; + using LayoutScale = typename ScaleType_::Layout; + using ElementPerTokenScale = typename PerTokenScaleType_::Element; + using LayoutPerTokenScale = typename PerTokenScaleType_::Layout; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v) && + std::is_same_v && std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + ACT_DEVICE + Params() {}; + + ACT_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + 
layoutD(layoutD_) + {} + }; + + ACT_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(float); + ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(float); + ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubMul; + } + + ACT_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + ACT_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + ACT_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; + } + callback(); + + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < 
tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32); + tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32); + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? 
(ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32; + AscendC::LocalTensor ubScaleFp32; + AscendC::LocalTensor ubMul; + AscendC::LocalTensor ubPerTokenScaleFp32; + AscendC::LocalTensor ubPerTokenScaleFp32Brcb; + AscendC::LocalTensor ubPerTokenMul; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +template +class BlockEpilogue, CType_, Gemm::GemmType, + Gemm::GemmType, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, + TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= + ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + 
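+        // ptrD/layoutD describe the dequantized output in GM. When the epilogue
+        // is instantiated with EXEC_FLAG_DEEP_FUSE, result tiles are sent to the
+        // peer ranks' combine windows in DoCombineSend() instead of being copied
+        // to ptrD.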
LayoutD layoutD{}; + + ACT_DEVICE + Params() {}; + + ACT_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + ACT_DEVICE void AlignUbOffset() + { + size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1); + if (ubMask != 0) { + ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask; + } + } + + ACT_DEVICE + BlockEpilogue(Arch::Resource &resource, MoeDistributeCombineImpl::CombineCalcInfo &calcInfo, + Params const ¶ms = Params{}) + : resource(resource), calcInfo(calcInfo), params(params) + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleBrcb = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubCFp32; + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AlignUbOffset(); + epSendCountLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t); + AlignUbOffset(); + AscendC::GlobalTensor epSendCountGM; + epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_); + uint32_t epSendCountSize = calcInfo.isShardExpert_ ? 
calcInfo.epWorldSize_ : calcInfo.moeSendNum_; + AscendC::DataCopyExtParams epSendCntParams = {1U, static_cast(epSendCountSize * sizeof(uint32_t)), + 0U, 0U, 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams); + AscendC::SetFlag(eventMTE2S); + AscendC::WaitFlag(eventMTE2S); +#if ENABLE_EP_SEND_COUNT_HASH + tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); + uint32_t maxGroupSendCount = 0; + uint32_t groupSendCount = 0; + for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) { + uint32_t prevGroupSendCount = groupSendCount; + groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1); + if (maxGroupSendCount < groupSendCount - prevGroupSendCount) { + maxGroupSendCount = groupSendCount - prevGroupSendCount; + } + } + ubOffset += maxGroupSendCount * sizeof(int32_t); + AlignUbOffset(); + // assert: ubOffset <= AscendC::TOTAL_UB_SIZE or + // AscendC::TOTAL_VEC_LOCAL_SIZE +#endif + } + } + + ACT_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + ACT_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + ACT_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t expertLocalId = 0U) + { + return (GM_ADDR)((calcInfo.epRankId_ == rankId) + ? calcInfo.epWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(calcInfo.epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; + } +#if ENABLE_EP_SEND_COUNT_HASH + ACT_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen) + { + constexpr uint32_t DUPLICATE_MASK_COUNT = 8; + uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1)); + if (hashOffsetMask != 0) { + uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask; + if (copyLen < remainMaskCount) { + remainMaskCount = copyLen; + } + uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask; + AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, ©Mask, 1, 1, + DUPLICATE_MASK_COUNT); + hashOffset += remainMaskCount; + copyLen -= remainMaskCount; + } + if (copyLen > 0) { + AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen); + hashOffset += copyLen; + } + } +#endif + + ACT_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank) + { + if ((calcInfo.isShardExpert_) && (epRank < calcInfo.sharedExpertRankNum_)) { + remoteEpRank = calcInfo.epRankId_; + localEpRank = epRank; + } else { + remoteEpRank = epRank; + localEpRank = calcInfo.epRankId_; + } + } + + ACT_DEVICE void DoCombineSend(AscendC::LocalTensor &ubD, layout::RowMajor &layoutGmTileD, + LayoutD &layoutUbD, int64_t groupOffsetD, uint32_t expertIdx, uint32_t tileOffsetD) + { + const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD); + const uint32_t copyTokenSrcStride = + (layoutUbD.stride(0) - layoutUbD.shape(1)) / (BYTE_PER_C0 / sizeof(ElementD)); + const uint32_t copyTokenDstStride = (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD); + + int64_t offsetD = groupOffsetD + 
tileOffsetD; + uint32_t startToken = offsetD / calcInfo.axisH_; + uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; + uint32_t itToken = startToken; + uint32_t endToken = startToken + layoutGmTileD.shape(0); +#if ENABLE_EP_SEND_COUNT_HASH + uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken); +#else + constexpr uint32_t epRankStart = 0; +#endif + uint32_t sendCount = + expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); + for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + if (prevSendCount <= itToken && itToken < sendCount) { + uint32_t copyTokenCount = (sendCount < endToken ? sendCount : endToken) - itToken; + AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, copyTokenSrcStride, + copyTokenDstStride, 0); + uint32_t remoteEpRank; + uint32_t localEpRank; + SetCombineSendEpRank(epRank, remoteEpRank, localEpRank); + GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) + + localEpRank * calcInfo.moeExpertPerRankNum_ * calcInfo.expertPerSizeOnWin_; + AscendC::GlobalTensor rankWindow; + rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM); + AscendC::DataCopyPad(rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + tokenOffset], + ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams); + itToken += copyTokenCount; + } + } + } + + ACT_DEVICE + void operator()(int64_t groupOffsetD, uint32_t expertIdx, GemmCoord const &blockShapeMNK, + GemmCoord const &blockCoordMNK, GemmCoord const &actualBlockShapeMNK, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutBlockC, + Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + expertOffset = expertIdx * calcInfo.epWorldSize_; +#if ENABLE_EP_SEND_COUNT_HASH + if (currentExpertIdx_ != expertIdx) { + uint32_t hashOffset = 0; + uint32_t sendCount = expertIdx == 0 ? 
0 : epSendCountLocal_.GetValue(expertOffset - 1); + for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount); + } + AscendC::SetFlag(eventVS); + AscendC::WaitFlag(eventVS); + currentExpertIdx_ = expertIdx; + } +#endif + } + + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); 
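+            // The dequant below computes, per tile:
+            //   ubMul[i][j]         = float(C[i][j]) * scale[j]         (per-channel scale, broadcast along rows)
+            //   ubPerTokenMul[i][j] = ubMul[i][j]   * perTokenScale[i]  (per-token scale, broadcast along columns)
+            // The result is then cast back to ElementD and either written to GM
+            // or sent to the peer rank's combine window in DoCombineSend().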
+ + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + tileRowBroadcastMul(ubMul, ubCFp32, ubScale); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + tileBroadcastOneBlk(ubPerTokenScaleBrcb, ubPerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleBrcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto tileOffsetD = params.layoutD.GetOffset(tileOffset); + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, tileOffsetD); + } else { + auto gmTileD = gmD[tileOffsetD]; + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + } + + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + Arch::Resource &resource; + MoeDistributeCombineImpl::CombineCalcInfo calcInfo; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + AscendC::LocalTensor epSendCountLocal_; +#if ENABLE_EP_SEND_COUNT_HASH + AscendC::LocalTensor tokenToEpRankHashLocal_; + uint32_t currentExpertIdx_{static_cast(-1)}; +#endif + + size_t ubOffset{0}; + int32_t eventVMTE2{0}; + int32_t eventMTE2V{0}; + int32_t eventMTE3V{0}; + int32_t eventVMTE3{0}; + int32_t eventVS{0}; + int32_t eventMTE2S{0}; + + uint32_t expertOffset; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32; + AscendC::LocalTensor ubMul; + AscendC::LocalTensor ubPerTokenScaleBrcb; + AscendC::LocalTensor ubPerTokenMul; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Act::Epilogue::Block + +#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/dispatch_policy.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/dispatch_policy.hpp new file mode 100644 index 00000000000..8d93192ddc0 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/dispatch_policy.hpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_DISPATCH_POLICY_HPP +#define ACT_EPILOGUE_DISPATCH_POLICY_HPP + +#include "../../act/arch/arch.hpp" + +namespace Act::Epilogue { + +// For AtlasA2, an element wise epilogue of the form D = C + X, where X is an +// additional source +struct EpilogueAtlasA2ElemWiseOneSource { + using ArchTag = Arch::AtlasA2; + // Number of operands. Including C, X, and D 3 operands + static constexpr uint32_t OPERANDS_NUM = 3; +}; + +// For AtlasA2, FA Softmax +struct EpilogueAtlasA2FASoftmax { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, FA RescaleO +struct EpilogueAtlasA2FARescaleO { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, MLA Softmax +struct EpilogueAtlasA2MLASoftmax { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, MLA RescaleO +struct EpilogueAtlasA2MLARescaleO { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, MLA FD RescaleO +template +struct EpilogueAtlasA2MLAFDRescaleO { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t KV_SPLIT_MAX = 64; + static constexpr uint32_t HEADS_PROCESS_MAX = 16; + static constexpr uint32_t COMPUTE_ELE_NUM = COMPUTE_ELE_NUM_; +}; + +// For AtlasA2, MLA TP1 Softmax +struct EpilogueAtlasA2MLATP1Softmax { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, MLA TP1 RescaleO +struct EpilogueAtlasA2MLATP1RescaleO { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, per token dequant +template +struct EpilogueAtlasA2PerTokenDequant { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; +}; +} // namespace Act::Epilogue + +#endif // ACT_EPILOGUE_DISPATCH_POLICY_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_gm_to_ub.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_gm_to_ub.hpp new file mode 100644 index 00000000000..1a9d3b4048c --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_gm_to_ub.hpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_EPILOGUE_TILE_TILE_COPY_GM_TO_UB_HPP +#define ACT_EPILOGUE_TILE_TILE_COPY_GM_TO_UB_HPP + +#include "../../../act/act.hpp" +#include "../../../act/gemm/gemm_type.hpp" +#include "../../../act/layout/layout.hpp" + +namespace Act::Epilogue::Tile { + +template +struct CopyGm2Ub { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy gm to ub, can not find the specialization."); +}; + +template +struct CopyGm2Ub> { + using LayoutSrc = layout::RowMajor; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + ACT_DEVICE + CopyGm2Ub() = default; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(layoutSrc.shape(0), layoutSrc.shape(1) * sizeof(Element), + (layoutSrc.stride(0) - layoutSrc.shape(1)) * sizeof(Element), + (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); + }; +}; + +template +struct CopyGm2Ub> { + using LayoutSrc = layout::VectorLayout; + using LayoutDst = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + ACT_DEVICE + CopyGm2Ub() = default; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(1, layoutSrc.shape(0) * sizeof(Element), 0, 0, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); + }; +}; + +/// @brief This copy instruction used to copy per token scale from GM to UB. +/// Copy the scale of shape (m,1) on GM to the first column of shape (m,n) on +/// UB, and pad the first block of each row (i.e. pad to shape (m,8) when +/// element type is float). +/// @tparam ArchTag: Architecture tag. +/// @tparam GmType: Type of data on GM. 
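+/// For example, assuming Element = float and a 16-row tile: the copy below
+/// issues blockCount = 16 transfers of blockLen = 4 bytes (one scale value per
+/// token), and isPad = true pads the remainder of each 32-byte destination
+/// block, so the UB tile effectively holds the scales at shape (16, 8) with
+/// one full block per row.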
+template +struct CopyPerTokenScale2Ub { + static_assert(std::is_same_v, + "Unsupporteded layout for CopyPerTokenScale2Ub."); + + using Element = typename GmType::Element; + using LayoutSrc = typename GmType::Layout; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + ACT_DEVICE + CopyPerTokenScale2Ub() = default; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams; + AscendC::DataCopyPadExtParams padParams; + + dataCopyParams.blockCount = layoutSrc.shape(0); + dataCopyParams.blockLen = layoutSrc.shape(1) * sizeof(Element); // per token scale has only one column + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + // Pad the data to the complete block + padParams.isPad = true; + padParams.leftPadding = 0; + padParams.rightPadding = 0; + + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); + } +}; + +template +struct CopyGm2UbAligned { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy gm to ub aligned, can not find the specialization."); +}; + +template +struct CopyGm2UbAligned> { + using LayoutSrc = layout::RowMajor; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + static constexpr uint32_t STRIDE_LIMIT = 65536; + + ACT_DEVICE + CopyGm2UbAligned() = default; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + uint32_t rows = layoutSrc.shape(0); + uint32_t cols = layoutSrc.shape(1); + uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + uint32_t rActual = (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); + } + } + }; +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_ub_to_gm.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_ub_to_gm.hpp new file mode 100644 index 00000000000..651f4342d4c --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_ub_to_gm.hpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. 
You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_COPY_UB_TO_GM_HPP +#define ACT_EPILOGUE_TILE_TILE_COPY_UB_TO_GM_HPP + +#include "../../../act/act.hpp" +#include "../../../act/gemm/gemm_type.hpp" +#include "../../../act/layout/layout.hpp" + +namespace Act::Epilogue::Tile { + +template +struct CopyUb2Gm { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy ub to gm, can not find the specialization."); +}; + +template +struct CopyUb2Gm> { + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyUb2Gm() = default; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(layoutDst.shape(0), layoutDst.shape(1) * sizeof(Element), + (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_C0, + (layoutDst.stride(0) - layoutDst.shape(1)) * sizeof(Element), 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); + } +}; + +// new add vectorlayout version +template +struct CopyUb2Gm> { + using LayoutSrc = layout::VectorLayout; + using LayoutDst = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + ACT_DEVICE + CopyUb2Gm() = default; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(1, layoutDst.shape(0) * sizeof(Element), 0, 0, 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); + }; +}; + +template +struct CopyUb2GmAligned { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy ub to gm aligned, can not find the specialization."); +}; + +template +struct CopyUb2GmAligned> { + using LayoutSrc = layout::RowMajor; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + static constexpr uint32_t STRIDE_LIMIT = 65536; + + ACT_DEVICE + CopyUb2GmAligned() = default; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + uint32_t rows = layoutDst.shape(0); + uint32_t cols = layoutDst.shape(1); + uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + uint32_t rActual = (i < rLoops - 1) ? 
MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); + } + } + }; +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp new file mode 100644 index 00000000000..a4a9d8d696b --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_COLUMN_HPP +#define ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_COLUMN_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + class TileShape_> +struct TileBroadcastInplaceByColumn { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileBroadcastInplaceByColumn() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubInOut) + { + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + constexpr uint32_t blkNumPerRow = TileShape::COLUMN / eleNumPerBlk; + + constexpr uint64_t defaultMask = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + constexpr uint64_t tailMask = (TileShape::ROW % BLK_NUM_PER_VECTOR_FRACTAL) * eleNumPerBlk; + + constexpr uint8_t repeatTimes = 1; + + AscendC::CopyRepeatParams repeatParams; + repeatParams.dstStride = blkNumPerRow; + repeatParams.srcStride = blkNumPerRow; + repeatParams.dstRepeatSize = 1; + repeatParams.srcRepeatSize = 1; + + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += BLK_NUM_PER_VECTOR_FRACTAL) { + uint64_t mask = ((TileShape::ROW - rowOffset) >= BLK_NUM_PER_VECTOR_FRACTAL) ? 
defaultMask : tailMask; + for (uint32_t colOffset = eleNumPerBlk; colOffset < TileShape::COLUMN; colOffset += eleNumPerBlk) { + AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN + colOffset], + ubInOut[rowOffset * TileShape::COLUMN], mask, 1, repeatParams); + } + } + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp new file mode 100644 index 00000000000..7ea15659aca --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_ROW_HPP +#define ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_ROW_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + class TileShape_> +struct TileBroadcastInplaceByRow { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileBroadcastInplaceByRow() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubInOut) + { + constexpr uint32_t eleNumPerVectorFractal = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + + constexpr uint64_t mask = eleNumPerVectorFractal; + constexpr uint8_t repeatTimes = TileShape::COLUMN / eleNumPerVectorFractal; + + AscendC::CopyRepeatParams repeatParams; + repeatParams.dstStride = 1; + repeatParams.srcStride = 1; + repeatParams.dstRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; + repeatParams.srcRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; + + for (uint32_t rowOffset = 1; rowOffset < TileShape::ROW; ++rowOffset) { + AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN], ubInOut, mask, repeatTimes, repeatParams); + } + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp new file mode 100644 index 00000000000..93b6125f772 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. 
See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_MUL_HPP +#define ACT_EPILOGUE_TILE_TILE_BROADCAST_MUL_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +/// BroadcastMul computes the elementwise multiplication of a tensor of shape +/// (m, n) and a tensor of shape (m, n) after broadcasting. There are two +/// broadcast modes: row-broadcast and column-broadcast. + +/// @brief Computes the elementwise multiplication of a tensor with shape (m, n) +/// and a tensor with original shape (1, n) broadcast to (m, n). +/// @tparam ArchTag_ is the architecture tag. +/// @tparam ComputeType_ includes the element type and layout information. +/// @tparam TileShape_ is the shape (m, n). +template +struct TileRowBroadcastMul { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileRowBroadcastMul() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + constexpr uint32_t maxRepeatTimes = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.src0BlkStride = 1; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = blkNumPerColumn; + repeatParams.src0RepStride = blkNumPerColumn; + repeatParams.src1RepStride = 0; + + constexpr uint32_t rowNumPerCompute = maxRepeatTimes; + constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint8_t repeatTimes = static_cast((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM); + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > colNumPerCompute) ? colNumPerCompute : residueN; + AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], + ubIn0[rowOffset * TileShape::COLUMN + colOffset], ubIn1[colOffset], mask, repeatTimes, + repeatParams); + } + } + } +}; + +/// @brief Compute the elementwise multiplication of a tensor of shape (m, n) +/// and a tensor of shape (m, eleNumPerBlk), which is broadcast from a tensor of +/// shape (m, 1), broadcast to (m, n). +/// @tparam ArchTag_ is the architecture tag. +/// @tparam ComputeType_ includes the element type and layout information. +/// @tparam TileShape_ is the shape (m, n). 
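+/// Note on the access pattern implied by the repeat parameters below: ubIn1 is
+/// expected to hold one 32-byte block per row (e.g. the Brcb output of
+/// TileBroadcastOneBlk), so row i contributes eleNumPerBlk identical scale
+/// values. Within one Mul repeat the block stride of ubOut/ubIn0 equals the
+/// number of blocks per row, so the repeat walks a single block-column down
+/// consecutive rows while ubIn1 steps through the matching per-row blocks;
+/// across repeats ubOut/ubIn0 advance to the next block-column while ubIn1
+/// stays put (src1RepStride == 0), which realizes the (m, 1) -> (m, n)
+/// broadcast without materializing a full (m, n) scale tile.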
+template +struct TileOneBlkColumnBroadcastMul { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileOneBlkColumnBroadcastMul() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + constexpr uint32_t maxRepeatNum = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.dstBlkStride = blkNumPerColumn; + repeatParams.src0BlkStride = blkNumPerColumn; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 1; + repeatParams.src0RepStride = 1; + repeatParams.src1RepStride = 0; + + constexpr uint32_t rowNumPerCompute = BLK_NUM_PER_VECTOR_FRACTAL; + constexpr uint32_t colNumPerCompute = eleNumPerBlk * maxRepeatNum; + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint64_t mask = ((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM) * eleNumPerBlk; + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint8_t repeatTimes = + static_cast(((residueN > colNumPerCompute) ? colNumPerCompute : residueN) / eleNumPerBlk); + AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], + ubIn0[rowOffset * TileShape::COLUMN + colOffset], ubIn1[rowOffset * eleNumPerBlk], mask, + repeatTimes, repeatParams); + } + } + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp new file mode 100644 index 00000000000..d8f7d79d939 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_ONE_BLK_HPP +#define ACT_EPILOGUE_TILE_TILE_BROADCAST_ONE_BLK_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template +struct TileBroadcastOneBlk { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + ACT_DEVICE + TileBroadcastOneBlk() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, AscendC::LocalTensor const &ubIn) + { + constexpr uint32_t maxRepeatNum = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + AscendC::BrcbRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.dstRepStride = BLK_NUM_PER_VECTOR_FRACTAL; + + constexpr uint32_t eleNumPerCompute = RoundDown(maxRepeatNum * BLK_NUM_PER_VECTOR_FRACTAL); + for (uint32_t offset = 0; offset < COMPUTE_LENGTH; offset += eleNumPerCompute) { + uint32_t residueM = COMPUTE_LENGTH - offset; + uint32_t computeM = (residueM > eleNumPerCompute) ? eleNumPerCompute : residueM; + uint8_t repeatTimes = static_cast(CeilDiv(computeM)); + AscendC::Brcb(ubOut[offset * eleNumPerBlk], ubIn[offset], repeatTimes, repeatParams); + } + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_cast.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_cast.hpp new file mode 100644 index 00000000000..5016251660a --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_cast.hpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_CAST_HPP +#define ACT_EPILOGUE_TILE_TILE_CAST_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class DstType_, class SrcType_, + /// Length of the compute buffer + class TileShape_> +struct TileCast { + using ArchTag = ArchTag_; + using ElementDst = typename DstType_::Element; + using ElementSrc = typename SrcType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileCast() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, AscendC::LocalTensor const &ubIn) + { + AscendC::Cast(ubOut, ubIn, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_copy.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_copy.hpp new file mode 100644 index 00000000000..2ed7c9c7a5e --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_copy.hpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. 
You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_COPY_HPP +#define ACT_EPILOGUE_TILE_TILE_COPY_HPP + +#include "../../../act/epilogue/tile/copy_gm_to_ub.hpp" +#include "../../../act/epilogue/tile/copy_ub_to_gm.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag, class... Args> +struct TileCopy { + static_assert(DEPENDENT_FALSE, "Unsupporteded tile copy, can not find the specialization."); +}; + +template +struct TileCopy { + using ElementC = typename CType::Element; + using ElementX = typename XType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub; + using CopyUbToGmD = CopyUb2Gm; + using CopyGmToUbY = CopyGm2Ub; + using CopyGmToUbTemp = CopyGm2Ub; + using CopyUbToGmZ = CopyUb2Gm; +}; + +template +struct TileCopy { + using ElementC = typename CType::Element; + using ElementX = typename XType::Element; + using ElementY = typename YType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub; + using CopyGmToUbY = CopyGm2Ub; + using CopyUbToGmD = CopyUb2Gm; +}; + +template +struct TileCopyBf16 { + using ElementC = typename CType::Element; + using ElementX = bfloat16_t; + using ElementY = bfloat16_t; + using ElementD = bfloat16_t; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub>; + using CopyGmToUbY = CopyGm2Ub>; + using CopyUbToGmD = CopyUb2Gm>; +}; + +template +struct TileCopyPerTokenDequant { + using ElementC = typename CType::Element; + using ElementScale = typename ScaleType::Element; + using ElementPerTokenScale = typename PerTokenScaleType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbScale = CopyGm2Ub; + using CopyGmToUbPerTokenScale = CopyPerTokenScale2Ub; + using CopyUbToGmD = CopyUb2Gm; +}; + +template +struct TileCopyPerTokenDequantGemm { + using ElementX = typename XType::Element; + using ElementScale = typename ScaleType::Element; + using ElementPerTokenScale = typename PerTokenScaleType::Element; + using ElementBias = typename BiasType::Element; + using ElementC = typename CType::Element; + + using CopyGmToUbX = CopyGm2Ub; + using CopyGmToUbScale = CopyGm2Ub; + using CopyGmToUbPerTokenScale = CopyGm2Ub; + using CopyGmToUbBias = CopyGm2Ub; + using CopyUbToGmC = CopyUb2Gm; +}; + +} // namespace Act::Epilogue::Tile + +#endif // ACT_EPILOGUE_TILE_TILE_COPY_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_add.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_add.hpp new file mode 100644 index 00000000000..8edcc1f9ba7 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_add.hpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_ELEMWISE_ADD_HPP +#define ACT_EPILOGUE_TILE_TILE_ELEMWISE_ADD_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + uint32_t COMPUTE_LENGTH_> +struct TileElemWiseAdd { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + ACT_DEVICE + TileElemWiseAdd() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + // Do the calculation + AscendC::Add(ubOut, ubIn0, ubIn1, COMPUTE_LENGTH); + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_mul.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_mul.hpp new file mode 100644 index 00000000000..cfc45739023 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_mul.hpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_ELEMWISE_MUL_HPP +#define ACT_EPILOGUE_TILE_TILE_ELEMWISE_MUL_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + class TileShape_> +struct TileElemwiseMul { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileElemwiseMul() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + // Do the calculation + AscendC::Mul(ubOut, ubIn0, ubIn1, TileShape::COUNT); + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_muls.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_muls.hpp new file mode 100644 index 00000000000..9bf10fa9973 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_muls.hpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP
+#define ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP
+
+#include "../../../act/gemm/helper.hpp"
+
+namespace Act::Epilogue::Tile {
+template <class ArchTag_, class ComputeType_, uint32_t COMPUTE_LENGTH_>
+struct TileElemWiseMuls {
+    using ArchTag = ArchTag_;
+    using ElementCompute = typename ComputeType_::Element;
+
+    static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_;
+
+    ACT_DEVICE
+    TileElemWiseMuls() {}
+
+    ACT_DEVICE
+    void operator()(AscendC::LocalTensor<ElementCompute> dstLocal, AscendC::LocalTensor<ElementCompute> srcTensor,
+                    ElementCompute scalar)
+    {
+        AscendC::Muls(dstLocal, srcTensor, scalar, COMPUTE_LENGTH);
+    }
+};
+} // namespace Act::Epilogue::Tile
+
+#endif // ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp new file mode 100644 index 00000000000..490a2a5a483 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp @@ -0,0 +1,92 @@ +/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */ + +#ifndef ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP +#define ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP + +#include "../../../act/act.hpp" +#include "../../../act/detail/alignment.hpp" +#include "../../../act/matrix_coord.hpp" + +namespace Act::Epilogue::Tile { + +struct EpilogueIdentityTileSwizzle { + MatrixCoord blockShape; + MatrixCoord tileShape; + MatrixCoord loopsMN; + + ACT_DEVICE + EpilogueIdentityTileSwizzle() = default; + + ACT_DEVICE + EpilogueIdentityTileSwizzle(MatrixCoord const &blockShape, MatrixCoord const &tileShape) + : blockShape(blockShape), tileShape(tileShape) + { + loopsMN = CeilDiv(blockShape, tileShape); + } + + ACT_DEVICE + uint32_t GetLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + ACT_DEVICE + MatrixCoord GetTileCoord(uint32_t loopIdx) const + { + return MatrixCoord{loopIdx / loopsMN.column(), loopIdx % loopsMN.column()}; + } + + ACT_DEVICE + MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const + { + return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); + } +}; + +struct EpilogueHorizontalTileSwizzle { + MatrixCoord blockShape; + MatrixCoord tileShape; + MatrixCoord loopsMN; + + ACT_DEVICE + EpilogueHorizontalTileSwizzle() = default; + + ACT_DEVICE + EpilogueHorizontalTileSwizzle(MatrixCoord const &blockShape, MatrixCoord const &tileShape) + : blockShape(blockShape), tileShape(tileShape) + { + loopsMN = CeilDiv(blockShape, tileShape); + } + + ACT_DEVICE + uint32_t GetLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + ACT_DEVICE + MatrixCoord GetTileCoord(uint32_t loopIdx) const + { + return MatrixCoord{loopIdx % loopsMN.row(), loopIdx / loopsMN.row()}; + } + + ACT_DEVICE + MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const + { + return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); + } +}; + +} // namespace Act::Epilogue::Tile + +#endif // ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad.hpp new file mode 100644 index 00000000000..8da81c80f52 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad.hpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */
+
+#ifndef ACT_GEMM_BLOCK_BLOCK_MMAD_HPP
+#define ACT_GEMM_BLOCK_BLOCK_MMAD_HPP
+
+#include "../../../act/act.hpp"
+#include "../../../act/gemm/tile/tile_copy.hpp"
+#include "../../../act/gemm/tile/tile_mmad.hpp"
+
+namespace Act::Gemm::Block {
+
+template ,
+          class TileMmad = Gemm::Tile::TileMmad>
+struct BlockMmad {
+    static_assert(DEPENDENT_FALSE, "BlockMmad is not implemented for this DispatchPolicy");
+};
+
+template ,
+          class TileMmad = Gemm::Tile::TileMmadTla>
+struct BlockMmadTla {
+    static_assert(DEPENDENT_FALSE, "BlockMmadTla is not implemented for this DispatchPolicy");
+};
+
+/// BlockGemm is added because this kernel reuses the same dispatch policy as the
+/// optimized matmul; a separately named class avoids a specialization conflict
+/// with BlockMmad.
+template ,
+          class TileMmad = Gemm::Tile::TileMmad>
+struct BlockGemm {
+    static_assert(DEPENDENT_FALSE, "BlockGemm is not implemented for this DispatchPolicy");
+};
+
+} // namespace Act::Gemm::Block
+
+#include "../../../act/gemm/block/block_mmad_preload_async_with_callback.hpp"
+
+#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad_preload_async_with_callback.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad_preload_async_with_callback.hpp new file mode 100644 index 00000000000..324f97998bc --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad_preload_async_with_callback.hpp @@ -0,0 +1,410 @@ +/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */ + +#ifndef ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP +#define ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP + +#include "../../../act/act.hpp" +#include "../../../act/arch/resource.hpp" +#include "../../../act/coord.hpp" +#include "../../../act/detail/callback.hpp" +#include "../../../act/gemm/dispatch_policy.hpp" +#include "../../../act/gemm/helper.hpp" +#include "../../../act/gemm_coord.hpp" + +namespace Act::Gemm::Block { + +template +struct BlockMmad, + L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2PreloadAsyncWithCallback; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE, + "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation 
where the basic blocks of L1 and L0 differ on " + "the m and n axes is not supported yet"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + + ACT_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); + } + + ACT_DEVICE + ~BlockMmad() + { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } + } + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + GemmCoord const &actualShape, Callback const &callbackBeforeFixpipe, + Callback const &callbackAfterFixpipe) + { + uint32_t kTileCount = CeilDiv(actualShape.k()); + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx) + : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = + (kTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emission load instruction from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + // If the number of preload instructions reaches the upper limit, perform + // an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) + ? 
(l1TileMmadParamsId + preloadCount) + : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1ListId = l1ListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe; + l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + } + l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0; + } + } + + ACT_DEVICE + void SynchronizeBlock() + { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + --preloadCount; + } + } + +private: + struct L1TileMmadParams { + uint32_t l1ListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + LayoutC layoutCInGm; + Callback callbackBeforeFixpipe; + Callback callbackAfterFixpipe; + + ACT_DEVICE + L1TileMmadParams() = default; + }; + + ACT_DEVICE + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES; + for (uint32_t i = 0; i < L1_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l1AEventList[i] = i; + l1BEventList[i] = i + L1_STAGES; + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + } + } + + ACT_DEVICE + void InitL0A(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + } + + ACT_DEVICE + void InitL0B(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + } + + ACT_DEVICE + void InitL0C(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + } + + ACT_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) + { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1ListId]; + auto &l1BTensor = l1BTensorList[params.l1ListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); + } + } + + for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { + 
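+            // For each m/k sub-tile, the A fragment is staged from L1 into the current L0A buffer once;
+            // then, for every n sub-tile, the matching B fragment is staged into L0B and one mmad is issued
+            // into L0C, accumulating over the k parts (initC below resets the accumulator on the very first
+            // k fragment of the block).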
uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[params.l1ListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[params.l1ListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[params.l1ListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[params.l1ListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the + // accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the + // calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && + (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; + } + } + + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; + + params.callbackBeforeFixpipe(); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? 
(l0CListId + 1) : 0; + + params.callbackAfterFixpipe(); + } + } + + AscendC::LocalTensor l1ATensorList[L1_STAGES]; + AscendC::LocalTensor l1BTensorList[L1_STAGES]; + int32_t l1AEventList[L1_STAGES]; + int32_t l1BEventList[L1_STAGES]; + uint32_t l1ListId{0}; + + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + uint32_t l0AListId{0}; + + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + uint32_t l0BListId{0}; + + AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; + int32_t l0CEventList[L0C_STAGES_]; + uint32_t l0CListId{0}; + + L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; + uint32_t l1TileMmadParamsId{0}; + uint32_t preloadCount{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Act::Gemm::Block + +#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp new file mode 100644 index 00000000000..36662d2a9d6 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP +#define ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP + +#include "../../../act/act.hpp" +#include "../../../act/detail/alignment.hpp" +#include "../../../act/gemm_coord.hpp" +#include "../../../act/matrix_coord.hpp" + +namespace Act::Gemm::Block { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Block swizzling function for Gemms +template +struct GemmIdentityBlockSwizzle { + /// Data members + + GemmCoord problemShape; + MatrixCoord tileMN; + MatrixCoord loopsMN; + + /// Methods + + ACT_DEVICE + GemmIdentityBlockSwizzle() {} + + ACT_DEVICE + GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) + : problemShape(problemShape_), tileMN(tileMN_) + { + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + } + + ACT_DEVICE + GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_) + : problemShape(problemShape_), tileMN(tileMN_), loopsMN(loopsMN_) + {} + + ACT_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) + { + problemShape = problemShape_; + tileMN = tileMN_; + + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + } + + ACT_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_) + { + problemShape = problemShape_; + tileMN = tileMN_; + loopsMN = loopsMN_; + } + + ACT_DEVICE + uint32_t GetCoreLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + ACT_DEVICE + uint32_t GetBatchIdx(uint32_t taskIdx) + { + return taskIdx / (GetCoreLoops()); + } + + ACT_DEVICE + GemmCoord GetBlockCoord(uint32_t taskIdx) + { + uint32_t innerIdx = taskIdx % GetCoreLoops(); + if constexpr (SwizzleDirection == 0) { // Zn + uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.column()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMN.row() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMN.column() - nIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } else if constexpr (SwizzleDirection == 1) { // Nz + uint32_t tileBlockLoop = CeilDiv(loopsMN.column(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.row()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.row()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMN.column() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMN.row() - mIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } + } + + ACT_DEVICE + GemmCoord GetActualBlockShape(GemmCoord blockCoord) + { + uint32_t mActual = + (blockCoord.m() == (loopsMN.row() - 1)) ? (problemShape.m() - blockCoord.m() * tileMN.row()) : tileMN.row(); + uint32_t nActual = (blockCoord.n() == (loopsMN.column() - 1)) + ? 
(problemShape.n() - blockCoord.n() * tileMN.column()) + : tileMN.column(); + uint32_t kActual = problemShape.k(); + return GemmCoord{mActual, nActual, kActual}; + } +}; + +/// Block swizzling function for Splitk Gemms +template +struct SplitkGemmIdentityBlockSwizzle { + /// Data members + + GemmCoord problemShape; + GemmCoord tileShape; + GemmCoord loopsMNK; + uint32_t splitkFactor = 1; // split k dim into virtual cores + + /// Methods + + ACT_DEVICE + SplitkGemmIdentityBlockSwizzle() {} + + ACT_DEVICE + SplitkGemmIdentityBlockSwizzle(GemmCoord const &problemShape_, GemmCoord const &tileShape_, + uint32_t splitkFactor_ = 1) + : problemShape(problemShape_), tileShape(tileShape_), splitkFactor(splitkFactor_) + { + loopsMNK = CeilDiv(problemShape, tileShape); + } + + ACT_DEVICE + uint32_t GetKIdxBySplitkSliceIdx(uint32_t splitkSliceIdx) const + { + if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { + return (loopsMNK.k() / splitkFactor + 1) * splitkSliceIdx; + } else { + return splitkSliceIdx * (loopsMNK.k() / splitkFactor) + loopsMNK.k() % splitkFactor; + } + } + + ACT_DEVICE + uint32_t GetSplitkSliceIdx(uint32_t taskIdx) const + { + uint32_t mnLoops = loopsMNK.m() * loopsMNK.n(); + return taskIdx % GetCoreLoops() / mnLoops; + } + + ACT_DEVICE + uint32_t GetCoreLoops() const + { + return loopsMNK.m() * loopsMNK.n() * splitkFactor; + } + + ACT_DEVICE + uint32_t GetBatchIdx(uint32_t taskIdx) + { + return taskIdx / GetCoreLoops(); + } + + ACT_DEVICE + GemmCoord GetBlockCoord(uint32_t taskIdx) + { + uint32_t splitkSliceIdx = GetSplitkSliceIdx(taskIdx); + uint32_t kIdx = GetKIdxBySplitkSliceIdx(splitkSliceIdx); + + uint32_t innerIdx = taskIdx % (loopsMNK.m() * loopsMNK.n()); + if constexpr (SwizzleDirection == 0) { // Zn + uint32_t tileBlockLoop = CeilDiv(loopsMNK.m(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.n()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.n()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMNK.m() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMNK.n() - nIdx - 1; + } + return GemmCoord{mIdx, nIdx, kIdx}; + } else if constexpr (SwizzleDirection == 1) { // Nz + uint32_t tileBlockLoop = CeilDiv(loopsMNK.n(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.m()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.m()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMNK.n() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMNK.m() - mIdx - 1; + } + return GemmCoord{mIdx, nIdx, kIdx}; + } + } + + ACT_DEVICE + GemmCoord GetActualBlockShape(GemmCoord blockCoord, uint32_t splitkSliceIdx) + { + uint32_t splitkSliceLen; + if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { + splitkSliceLen = (loopsMNK.k() / splitkFactor + 1) * tileShape.k(); + } else { + splitkSliceLen = (loopsMNK.k() / splitkFactor) * tileShape.k(); + } + uint32_t mActual = (blockCoord.m() == (loopsMNK.m() - 1)) ? (problemShape.m() - blockCoord.m() * tileShape.m()) + : tileShape.m(); + uint32_t nActual = (blockCoord.n() == (loopsMNK.n() - 1)) ? 
(problemShape.n() - blockCoord.n() * tileShape.n()) + : tileShape.n(); + uint32_t kActual = (splitkSliceIdx == (splitkFactor - 1)) ? (problemShape.k() - blockCoord.k() * tileShape.k()) + : splitkSliceLen; + return GemmCoord{mActual, nActual, kActual}; + } +}; + +} // namespace Act::Gemm::Block + +#endif // ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/dispatch_policy.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/dispatch_policy.hpp new file mode 100644 index 00000000000..4ec7433f05c --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/dispatch_policy.hpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_DISPATCH_POLICY_HPP +#define ACT_GEMM_DISPATCH_POLICY_HPP + +#include "../../act/act.hpp" + +namespace Act::Gemm { + +// Block Mmad Policies + +template +struct MmadAtlasA2Base { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t ASYNC = ASYNC_; +}; + +using MmadAtlasA2 = MmadAtlasA2Base; +using MmadAtlasA2Async = MmadAtlasA2Base; + +// Now ENABLE_UNIT_FLAG_ must be false when input element is int8 +template +struct MmadAtlasA2Pingpong : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2Preload : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; + +struct MmadAtlasA2FAQK : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2FAPV : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAQK : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAPV : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAQKTp1Spec : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAPVTp1Spec : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +template +struct MmadAtlasA2PreloadAsync : public MmadAtlasA2Async { + static constexpr uint32_t PRELOAD_STAGES = PRELOAD_STAGES_; // Stages of emitting load instruction in advance + static constexpr uint32_t L1_STAGES = L1_STAGES_; + static constexpr uint32_t L0A_STAGES = L0A_STAGES_; + static constexpr uint32_t L0B_STAGES = L0B_STAGES_; + static constexpr uint32_t L0C_STAGES = L0C_STAGES_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; + +template +struct MmadAtlasA2PreloadAsyncWithCallback + : public MmadAtlasA2PreloadAsync {}; +} // namespace Act::Gemm + +#endif // ACT_GEMM_DISPATCH_POLICY_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp new file mode 100644 index 00000000000..145c3964650 --- 
/dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_GEMM_TYPE_HPP +#define ACT_GEMM_GEMM_TYPE_HPP + +namespace Act::Gemm { + +//////////////////////////////////////////////////////////////////// + +template +struct GemmType { + using Element = Element_; + using Layout = Layout_; + static constexpr AscendC::TPosition POSITION = POSITION_; +}; + +} // namespace Act::Gemm + +#endif // ACT_GEMM_GEMM_TYPE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/helper.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/helper.hpp new file mode 100644 index 00000000000..bb634f9bc96 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/helper.hpp @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_GEMM_HELPER_HPP +#define ACT_GEMM_HELPER_HPP + +#include "../../act/act.hpp" +#include "../../act/layout/layout.hpp" +#include "../../tla/layout.hpp" + +namespace Act::Gemm::helper { + +template +struct L1AlignHelper { + static_assert(DEPENDENT_FALSE, "Unsupporteded align helper, can not find the specialization."); +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +}; + +template +struct ElementAccumulatorSelector { + static_assert(DEPENDENT_FALSE, + "Unsupporteded element accumulator selector, can not find the " + "specialization."); +}; + +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = float; +}; + +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = float; +}; + +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = int32_t; +}; + +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = float; +}; + +template +struct L1ATypeSelector { + static_assert(DEPENDENT_FALSE, "Unsupporteded layout selector, can not find the specialization."); +}; + +template +struct L1ATypeSelector> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1ATypeSelector> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1ATypeSelector> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1ATypeSelector> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector { + static_assert(DEPENDENT_FALSE, "Unsupporteded layout selector, can not find the specialization."); +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { 
+ using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1AlignHelperTla { + static_assert(DEPENDENT_FALSE, "Unsupporteded align helper tla, can not find the specialization."); +}; + +template +struct L1AlignHelperTla::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +template +struct L1AlignHelperTla::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +}; + +/////////////////////////////////////// +// new add +template +struct L1ATypeSelectorGemm { + static_assert(DEPENDENT_FALSE, "Unsupporteded layout selector, can not find the specialization."); +}; + +template +struct L1ATypeSelectorGemm> { + using L1AType = Gemm::GemmType; +}; + +template <> +struct L1ATypeSelectorGemm> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1ATypeSelectorGemm> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1BTypeSelectorGemm { + static_assert(DEPENDENT_FALSE, "Unsupporteded layout selector, can not find the specialization."); +}; + +template +struct L1BTypeSelectorGemm> { + using L1BType = Gemm::GemmType; +}; + +template <> +struct L1BTypeSelectorGemm> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelectorGemm> { + using L1BType = Gemm::GemmType; +}; + +template +struct L0ATypeSelector {}; + +template +struct L0ATypeSelector> { + using L0AType = Gemm::GemmType; +}; + +template +struct L0ATypeSelector> { + using L0AType = Gemm::GemmType; +}; + +template <> +struct L0ATypeSelector> { + using L0AType = Gemm::GemmType; +}; + +template +struct L0BTypeSelectorGemm {}; + +template +struct L0BTypeSelectorGemm> { + using L0BType = Gemm::GemmType; +}; + +template <> +struct L0BTypeSelectorGemm> { + using L0BType = Gemm::GemmType; +}; + +template +struct L0BTypeSelectorGemm> { + using L0BType = Gemm::GemmType; +}; + +template +struct L0BTypeSelectorGemv {}; + +template +struct L0BTypeSelectorGemv> { + using L0BType = Gemm::GemmType; +}; + +template +struct L0BTypeSelectorGemv> { + using L0BType = Gemm::GemmType; +}; + +template <> +struct L0BTypeSelectorGemv> { + using L0BType = Gemm::GemmType; +}; +} // namespace Act::Gemm::helper + +#endif // ACT_GEMM_HELPER_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp new file mode 100644 index 00000000000..d18dc276af2 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp @@ -0,0 +1,358 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP +#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP + +#include "../../../../cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h" +#include "../../../act/act.hpp" +#include "../../../act/arch/cross_core_sync.hpp" +#include "../../../act/arch/resource.hpp" +#include "../../../act/coord.hpp" +#include "../../../act/detail/callback.hpp" +#include "../../../act/gemm_coord.hpp" +#include "../../../act/matrix_coord.hpp" + +namespace Act::Gemm::Kernel { + +template +class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + void *combiner; + + // Methods + ACT_DEVICE + Params() {} + + ACT_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_, + void *combiner_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + 
layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_), + combiner(combiner_) + {} + }; + + // Methods + ACT_DEVICE + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + ACT_DEVICE void operator()(Params const ¶ms); + + template <> + ACT_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current + // groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? 
(coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + } + + template <> + ACT_DEVICE void operator()(Params const ¶ms) + { + auto *combiner = (MoeDistributeCombineImpl::CamMoeDistributeCombine *)params.combiner; + { + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID); + } + } + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo()); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? 
groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, + layoutD}; + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, + layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? 
(stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + icache_preload(4); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->AllToAllSend(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->ReducePermute(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->Process(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; + + ACT_DEVICE + AicWaitFunc() = default; + + ACT_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; + + ACT_DEVICE + AicSetFunc() = default; + + ACT_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Act::Gemm::Kernel + +#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_l1.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_l1.hpp new file mode 100644 index 00000000000..5100d46f9bc --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_l1.hpp @@ -0,0 +1,798 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_TILE_COPY_GM_TO_L1_HPP +#define ACT_GEMM_TILE_COPY_GM_TO_L1_HPP + +#include "../../../act/act.hpp" +#include "../../../act/gemm/gemm_type.hpp" +#include "../../../act/layout/layout.hpp" +#include "../../../tla/tensor.hpp" + +using namespace tla; + +namespace Act::Gemm::Tile { + +template +struct CopyGmToL1 { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to l1, can not find the specialization."); +}; + +/// Partial specialization for AtlasA2, half, RowMajor in and zN out. 
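+/// Performs an ND-to-NZ (zN) conversion while copying: the RowMajor source tile in GM is written
+/// into L1 through AscendC::DataCopy with Nd2NzParams; when the source row stride is too large for
+/// a single transfer (>= STRIDE_LIMIT), the copy falls back to issuing one DataCopy per row.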
+/// Matrix A confirm +template +struct CopyGmToL1, Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(0) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); + } + } + } +}; + +template +struct CopyGmToL1, Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(0); + uint32_t ndNum = layoutSrc.shape(0) / C0_NUM_PER_FRACTAL; + uint32_t remains = layoutSrc.shape(0) % C0_NUM_PER_FRACTAL; + if (srcNdStride < STRIDE_LIMIT) { + if (ndNum) { + intriParams.ndNum = ndNum; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = srcNdStride; + intriParams.srcDValue = layoutSrc.stride(0); + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + + intriParams.dstNzMatrixStride = layoutDst.stride(1); + + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } + + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(1); + tailParams.srcNdMatrixStride = srcNdStride; + tailParams.srcDValue = layoutSrc.stride(0); + + tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; //` + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else if (layoutSrc.stride(0) < STRIDE_LIMIT) { + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = layoutSrc.stride(0); + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[i * layoutDst.stride(1)], srcTensor[i * srcNdStride], intriParams); + } + if (remains) { + AscendC::Nd2NzParams 
tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(1); + tailParams.srcNdMatrixStride = 0; + tailParams.srcDValue = layoutSrc.stride(0); + + tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else { + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; + uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = 0; + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = 0; + intriParams.dstNzMatrixStride = 0; + + uint32_t offsetDst = i * idxR0 * layoutDst.stride(1) + idxInR0 * ELE_NUM_PER_C0; + uint32_t offsetSrc = i * layoutSrc.stride(0); + AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], intriParams); + } + } + } +}; + +template +struct CopyGmToL1, Gemm::GemmType> { + using LayoutDst = layout::nN; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(1); + uint32_t ndNum = layoutSrc.shape(1) / C0_NUM_PER_FRACTAL; + uint32_t remains = layoutSrc.shape(1) % C0_NUM_PER_FRACTAL; + if (srcNdStride < STRIDE_LIMIT) { + if (ndNum) { + intriParams.ndNum = ndNum; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = srcNdStride; + intriParams.srcDValue = layoutSrc.stride(1); + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + + intriParams.dstNzMatrixStride = layoutDst.stride(3); + + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } + + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); + tailParams.srcNdMatrixStride = srcNdStride; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else if (layoutSrc.stride(1) < STRIDE_LIMIT) { + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = layoutSrc.stride(1); + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[i * layoutDst.stride(3)], srcTensor[i * srcNdStride], intriParams); + } + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 
1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); + tailParams.srcNdMatrixStride = 0; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else { + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; + uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = 0; + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = 0; + intriParams.dstNzMatrixStride = 0; + + uint32_t offsetDst = i * idxR0 * layoutDst.stride(3) + idxInR0 * ELE_NUM_PER_C0; + uint32_t offsetSrc = i * layoutSrc.stride(1); + AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], intriParams); + } + } + } +}; + +template +struct CopyGmToL1, Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(1) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(1); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(1)], intriParams); + } + } + } +}; + +/// Partial specialization for AtlasA2, RowMajor in and zN out. 
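+/// One nd2nz DataCopy covers the whole tile when the source row stride is
+/// below STRIDE_LIMIT; otherwise the copy falls back to one row per call.
+/// A second operator() overload copies ndNum identically laid out ND matrices.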
+template +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(0) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); + } + } + } + + // layoutSrc must be the layout of one of the src matrices + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc, uint32_t ndNum, uint32_t srcNdMatrixStride, + uint32_t dstNzNStride, uint32_t dstNzMatrixStride, uint32_t dstNzC0Stride) + { + AscendC::Nd2NzParams intriParams; + + intriParams.nValue = layoutSrc.shape(0); + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = dstNzNStride; + intriParams.dstNzC0Stride = dstNzC0Stride; + if (srcNdMatrixStride < STRIDE_LIMIT) { + intriParams.ndNum = ndNum; + intriParams.srcNdMatrixStride = srcNdMatrixStride; + intriParams.dstNzMatrixStride = dstNzMatrixStride; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.ndNum = 1; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzMatrixStride = 0; + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * srcNdMatrixStride], intriParams); + } + } + } +}; + +/// Partial specialization for AtlasA2, ColumnMajor in and nZ out. 
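+/// Column-major counterpart of the RowMajor/zN copy: a single nd2nz DataCopy
+/// when the source column stride is below STRIDE_LIMIT, else column by column.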
+template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(1) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(1); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(1)], intriParams); + } + } + } +}; + +/// Partial specialization for zN in and zN out. +template +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t blockCount = CeilDiv(layoutSrc.orgShape(1)); + uint32_t blockLen = RoundUp(layoutSrc.orgShape(0)); + + AscendC::DataCopyParams repeatParams; + + if (layoutSrc.stride(3) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_C0 - blockLen; + repeatParams.dstStride = layoutDst.stride(3) / ELE_NUM_PER_C0 - blockLen; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } else { + repeatParams.blockCount = 1; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + for (uint32_t i = 0; i < blockCount; i++) { + uint64_t dstOffset = i * layoutDst.stride(3); + uint64_t srcOffset = i * layoutSrc.stride(3); + AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], repeatParams); + } + } + } +}; + +/// Partial specialization for nZ in and nZ out. 
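+/// Source and destination already share the fractal format, so this is a
+/// strided block copy: one DataCopy when the source stride fits STRIDE_LIMIT,
+/// otherwise one DataCopy per fractal block.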
+template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t blockCount = CeilDiv(layoutSrc.orgShape(0)); + uint32_t blockLen = RoundUp(layoutSrc.orgShape(1)); + + AscendC::DataCopyParams repeatParams; + + if (layoutSrc.stride(1) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_C0 - blockLen; + repeatParams.dstStride = layoutDst.stride(1) / ELE_NUM_PER_C0 - blockLen; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } else { + repeatParams.blockCount = 1; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + for (uint32_t i = 0; i < blockCount; i++) { + uint64_t dstOffset = i * layoutDst.stride(1); + uint64_t srcOffset = i * layoutSrc.stride(1); + AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], repeatParams); + } + } + } +}; + +/// Partial specialization for AtlasA2, PaddingRowMajor in and zN out. +template +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::PaddingRowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.orgShape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = layoutSrc.orgShape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } +}; + +/// Partial specialization for AtlasA2, ColumnMajor in and nZ out. +template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::PaddingColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.orgShape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = layoutSrc.orgShape(1); + intriParams.srcDValue = layoutSrc.stride(2); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } +}; + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
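+/// Keeps the row-major layout in L1: one linear DataCopy when both sides are
+/// contiguous, a batched copy of up to MAX_REPEAT rows per call when the
+/// strides fit the hardware limits, and a row-by-row fallback otherwise.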
+template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t rows = layoutSrc.shape(0); + uint32_t cols = layoutSrc.shape(1); + uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + uint32_t rActual = (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); + } + } + } +}; + +///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// +/// Partial specialization for CopyGmToL1, AtlasA2, RowMajor in and zN out. +template +struct TileCopyTla< + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::A1>, + std::enable_if_t::value && tla::detail::iszN::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t nValue = get<0>(srcTensor.shape()); + const uint32_t dValue = get<1>(srcTensor.shape()); + const uint32_t srcDValue = get<0>(srcTensor.stride()); + const uint32_t dstInnerStrideRow = get<0, 0>(dstTensor.stride()); + const uint32_t dstOuterStrideCol = get<1, 1>(dstTensor.stride()); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = dValue; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (srcDValue < STRIDE_LIMIT) { + intriParams.nValue = nValue; + intriParams.srcDValue = srcDValue; + intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < nValue; i++) { + AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], srcTensor.data()[i * srcDValue], intriParams); + } + } + } +}; + +/// Partial specialization for CopyGmToL1, AtlasA2, ColumnMajor in and nZ out. 
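+/// tla::Tensor variant: the n/d extents and strides are read from the tensor
+/// shapes, then the same nd2nz DataCopy (with a per-row fallback above
+/// STRIDE_LIMIT) is issued.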
+template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::A1>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t nValue = get<1>(srcTensor.shape()); + const uint32_t dValue = get<0>(srcTensor.shape()); + const uint32_t srcDValue = get<1>(srcTensor.stride()); + const uint32_t dstInnerStrideRow = get<1, 0>(dstTensor.stride()); + const uint32_t dstOuterStrideCol = get<0, 1>(dstTensor.stride()); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = dValue; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (srcDValue < STRIDE_LIMIT) { + intriParams.nValue = nValue; + intriParams.srcDValue = srcDValue; + intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < nValue; i++) { + AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], srcTensor.data()[i * srcDValue], intriParams); + } + } + } +}; + +/// Partial specialization for CopyGmToL1, AtlasA2, PaddingRowMajor in and zN +/// out. +template +struct TileCopyTlaExt, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::A1>, + layout::PaddingRowMajor, layout::zN> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTlaExt() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = get<1>(srcTensor.orgShape()); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = get<1, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = get<0>(srcTensor.orgShape()); + intriParams.srcDValue = get<0, 0>(srcTensor.stride()); + intriParams.dstNzNStride = get<0, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } +}; + +/// Partial specialization for TileCopyTlaExt, CopyGmToL1, AtlasA2, +/// PaddingColumnMajor in and nZ out. 
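+/// Padding variant: copies the original (unpadded) extents with a single
+/// nd2nz DataCopy; no stride-limit fallback path is provided here.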
+template +struct TileCopyTlaExt, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::A1>, + layout::PaddingColumnMajor, layout::nZ> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTlaExt() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = get<0>(srcTensor.orgShape()); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = get<0, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = get<1>(srcTensor.orgShape()); + intriParams.srcDValue = get<1, 0>(srcTensor.stride()); + intriParams.dstNzNStride = get<1, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_GM_TO_L1_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_ub.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_ub.hpp new file mode 100644 index 00000000000..d50650056df --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_ub.hpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_TILE_COPY_GM_TO_UB_HPP +#define ACT_GEMM_TILE_COPY_GM_TO_UB_HPP + +#include "../../../act/act.hpp" +#include "../../../tla/tensor.hpp" + +namespace Act::Gemm::Tile { + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
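+/// GM-to-UB copy using DataCopyPad: the row length and source gap are given
+/// in bytes via DataCopyExtParams, and padding is disabled.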
+template +struct TileCopyTla< + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::VECCALC>, + std::enable_if_t::value && tla::detail::isRowMajor::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::VECCALC>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::DataCopyExtParams dataCopyParams( + get<0>(srcTensor.shape()), get<1>(srcTensor.shape()) * sizeof(ElementSrc), + (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) * sizeof(ElementSrc), + (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) / ELE_NUM_PER_BLK, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams, padParams); + }; +}; + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_GM_TO_UB_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l0c_to_gm.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l0c_to_gm.hpp new file mode 100644 index 00000000000..b25e28b0f8d --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l0c_to_gm.hpp @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP +#define ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP + +#include "../../../act/gemm/gemm_type.hpp" + +namespace Act::Gemm::Tile { + +enum class ScaleGranularity { UNDEFINED = -1, NO_QUANT = 0, PER_TENSOR, PER_CHANNEL, PER_GROUP }; + +template +struct CopyL0CToGmQuantMode { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy l0c to gm, can not find the specialization."); +}; + +// CopyL0CToGm cast fp32 to fp16 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::F322F16; +}; + +// CopyL0CToGm cast fp32 to bf16 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::F322BF16; +}; + +// CopyL0CToGm output fp32 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::NoQuant; +}; + +// CopyL0CToGm output int32 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::NoQuant; +}; + +// CopyL0CToGm cast int32_t to fp16 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::DEQF16; +}; + +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::VDEQF16; +}; + +template +struct CopyL0CToGm { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy l0c to gm, can not find the specialization."); +}; + +template +struct CopyL0CToGm, + ScaleGranularity::NO_QUANT, ReluEnable_> { + using ArchTag = Act::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Act::layout::zN; + using LayoutDst = Act::layout::RowMajor; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = dstLayout.shape(1); + intriParams.mSize = dstLayout.shape(0); + intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0); + intriParams.dstStride = dstLayout.stride(0); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dst, src, intriParams); + } +}; + +template +struct CopyL0CToGm, + ScaleGranularity::NO_QUANT, ReluEnable_> { + using ArchTag = Act::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Act::layout::zN; + using LayoutDst = Act::layout::ColumnMajor; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementDst); + + ACT_DEVICE + CopyL0CToGm() {} + + ACT_DEVICE + void operator()(AscendC::GlobalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::DataCopyCO12DstParams params; + + params.nSize = dstLayout.shape(0); + params.mSize = dstLayout.shape(1); + params.dstStride = dstLayout.stride(1); + params.srcStride = srcLayout.shape(2) * srcLayout.shape(3); + params.quantPre = quantPre; + params.reluPre = 0; + params.channelSplit = false; + params.nz2ndEn = true; + AscendC::DataCopy(dstTensor, srcTensor, params); + } +}; + +template +struct CopyL0CToGm, + ScaleGranularity::NO_QUANT, ReluEnable_> { + 
using ArchTag = Act::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Act::layout::zN; + using LayoutDst = Act::layout::zN; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = dstLayout.shape(2) * dstLayout.shape(3); + intriParams.mSize = dstLayout.shape(0) * dstLayout.shape(1); + intriParams.srcStride = srcLayout.stride(3) / srcLayout.shape(2); + intriParams.dstStride = dstLayout.stride(3) / (BYTE_PER_C0 / sizeof(ElementDst)); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dst, src, intriParams); + } +}; + +///////////////////////////////////////////CopyL0CToGmTla///////////////////////////////////////////////// +template +struct CopyL0CToGmTla { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy l0c to gm, can not find the specialization."); +}; + +template +struct CopyL0CToGmTla< + Act::Arch::AtlasA2, TensorSrc_, Tensor, LayoutDst_, AscendC::TPosition::GM>, + ScaleGranularity::NO_QUANT, ReluEnable_, std::enable_if_t::value>> { + using ArchTag = Act::Arch::AtlasA2; + using TensorDst = Tensor, LayoutDst_, AscendC::TPosition::GM>; + using ElementDst = ElementDst_; + using TensorSrc = TensorSrc_; + using ElementSrc = typename TensorSrc::Element; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = get<1>(dstTensor.shape()); + intriParams.mSize = get<0>(dstTensor.shape()); + intriParams.srcStride = get<1, 1>(srcTensor.stride()) / get<0, 0>(srcTensor.stride()); + intriParams.dstStride = get<0>(dstTensor.stride()); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dstTensor.data(), srcTensor.data(), + intriParams); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0a.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0a.hpp new file mode 100644 index 00000000000..14639773d46 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0a.hpp @@ -0,0 +1,392 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. 
See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP +#define ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP + +#include "../../../act/act.hpp" +#include "../../../act/gemm/gemm_type.hpp" +#include "../../../act/layout/layout.hpp" +#include "../../../tla/tensor.hpp" + +using namespace tla; + +namespace Act::Gemm::Tile { + +template +struct CopyL1ToL0A { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy l1 to l0, can not find the specialization."); +}; + +//////////////////////////////// +/// new add gemm +template +struct CopyL1ToL0A, Act::Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutSrc.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A, Act::Gemm::GemmType> { + using Element = float; + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1) / 2); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = static_cast(layoutSrc.shape(1) / 2) - 1; + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(3)], srcTensor[i * layoutSrc.stride(3)], + loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A, Act::Gemm::GemmType> { + using Element = int8_t; + using LayoutDst = layout::zN; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + 
+ ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + uint32_t MRound = layoutSrc.shape(0) * layoutSrc.shape(1); + uint32_t KRound = layoutSrc.shape(2) * layoutSrc.shape(3); + uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; + uint32_t KLoops = CeilDiv(KRound, KL0Alignment); + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(MRound / ELE_NUM_PER_C0); + loadDataParams.srcStride = static_cast(KRound / KL0Alignment); + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < KLoops; i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * MRound * KL0Alignment], + srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); + } + } +}; +////////////////////////////////////////// + +/// Partial specialization for zN in and zZ out. +template +struct CopyL1ToL0A> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0A() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +/// Partial specialization for int8_t, nZ in and zZ out. 
(Transpose A) +template +struct CopyL1ToL0A> { + using Element = int8_t; + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0A() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = CeilDiv(layoutDst.orgShape(1)) - 1; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1) * 2], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } +}; + +///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// + +/// Partial specialization for CopyL1ToL0A, AtlasA2, zN in and zZ out. +template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::A2>, + std::enable_if_t::value && + tla::detail::iszN::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0A, AtlasA2, nZ in and zZ out. 
+/// (Transpose A) +template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::A2>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0A, AtlasA2, int8_t, nZ in and zZ out. +/// (Transpose A) +template +struct TileCopyTla< + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::A2>, + std::enable_if_t::value && tla::detail::isnZ::value>> { + using Element = int8_t; + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterShapeRow = get<0, 1>(srcTensor.shape()); + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = dstOuterShapeCol - 1; + + for (uint32_t i = 0; i < srcOuterShapeRow; i++) { + AscendC::LoadDataWithTranspose(dstTensor.data()[i * dstOuterStrideRow * 2], + srcTensor.data()[i * srcOuterStrideRow], loadDataParams); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0b.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0b.hpp new file mode 100644 index 00000000000..6f1ced1d9f7 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0b.hpp @@ -0,0 +1,537 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP +#define ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP + +#include "../../../act/act.hpp" +#include "../../../act/gemm/gemm_type.hpp" +#include "../../../act/layout/layout.hpp" +#include "../../../tla/tensor.hpp" + +using namespace tla; + +namespace Act::Gemm::Tile { + +template +struct CopyL1ToL0B { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy l1 to l0, can not find the specialization."); +}; + +//////////////////////////////////////// +/// new add gemm +template +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N + AscendC::LoadData(dstTensor[i * layoutSrc.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using Element = float; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3) / 2); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = static_cast(layoutSrc.shape(3) / 2) - 1; + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N + AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(1)], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using Element = int8_t; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + uint32_t NRound = layoutSrc.shape(2) * layoutSrc.shape(3); + uint32_t KRound = layoutSrc.shape(0) * layoutSrc.shape(1); + uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; + uint32_t KLoops = CeilDiv(KRound, KL0Alignment); + AscendC::LoadData2dTransposeParams 
loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(NRound / ELE_NUM_PER_C0); + loadDataParams.srcStride = static_cast(KRound / KL0Alignment); + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < KLoops; i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * NRound * KL0Alignment], + srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::nN; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); + loadDataParams.srcStride = layoutSrc.shape(3); + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutSrc.shape(3); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = layoutDst.shape(1) * layoutDst.shape(3); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + }; +}; + +template +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + using Element = float; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / 
sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(0))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = CeilDiv(layoutDst.orgShape(0)) - 1; + + for (uint32_t i = 0; i < CeilDiv<2 * ELE_NUM_PER_C0>(layoutDst.orgShape(1)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3) * 2], srcTensor[i * layoutSrc.stride(3)], + loadDataParams); + } + }; +}; + +template +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nZ; + using Element = int8_t; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(0))); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(1)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3) * 2], + loadDataParams); + } + } +}; +//////////////////////////////////////////// + +/// Partial specialization for int8_t, zN in and nZ out. +template +struct CopyL1ToL0B> { + using Element = int8_t; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1) * 2], + loadDataParams); + } + } +}; + +/// Partial specialization for zN in and nZ out. 
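+/// Uses LoadData2D with ifTranspose enabled, iterating over the destination
+/// row blocks so zN fractals from L1 land in nZ order in L0B.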
+template +struct CopyL1ToL0B> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +/// Partial specialization for nZ in and nZ out. (Transpose B) +template +struct CopyL1ToL0B> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + if (layoutSrc.shape(3) == layoutDst.shape(3)) { + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1) * layoutDst.shape(3)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + } else { + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } + } +}; + +///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// +/// Partial specialization for CopyL1ToL0B, AtlasA2, zN in and nZ out. 
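+/// tla::Tensor variant of the zN-to-nZ load: repeat count and strides come
+/// from the outer fractal dimensions, again with ifTranspose enabled.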
+template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::B2>, + std::enable_if_t::value && + tla::detail::iszN::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::B2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0B, AtlasA2, nZ in and nZ out. +/// (Transpose B) +template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::B2>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::B2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0B, AtlasA2, int8_t, zN in and nZ out. 
+template +struct TileCopyTla< + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::B2>, + std::enable_if_t::value && tla::detail::iszN::value>> { + using Element = int8_t; + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::B2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterShapeCol = get<1, 1>(srcTensor.shape()); + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = srcOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadDataWithTranspose(dstTensor.data()[i * dstOuterStrideRow], + srcTensor.data()[i * srcOuterStrideRow * 2], loadDataParams); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_ub_to_gm.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_ub_to_gm.hpp new file mode 100644 index 00000000000..87d86e3bd40 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_ub_to_gm.hpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_TILE_COPY_UB_TO_GM_HPP +#define ACT_GEMM_TILE_COPY_UB_TO_GM_HPP + +#include "../../../act/act.hpp" +#include "../../../tla/tensor.hpp" + +namespace Act::Gemm::Tile { + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
+template +struct TileCopyTla< + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::VECCALC>, + Tensor, LayoutDst_, AscendC::TPosition::GM>, + std::enable_if_t::value && tla::detail::isRowMajor::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::GM>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::VECCALC>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::DataCopyExtParams dataCopyParams( + get<0>(dstTensor.shape()), get<1>(dstTensor.shape()) * sizeof(ElementSrc), + (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / ELE_NUM_PER_C0, + (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) * sizeof(ElementSrc), 0); + AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); + }; +}; + +/// Partial specialization for AtlasA2, RowMajor in and PaddingRowMajor out. +template +struct TileCopyTlaExt, LayoutSrc_, AscendC::TPosition::VECCALC>, + Tensor, LayoutDst_, AscendC::TPosition::GM>, layout::RowMajor, + layout::PaddingRowMajor> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::GM>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::VECCALC>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTlaExt() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::DataCopyExtParams dataCopyParams( + get<1, 1>(dstTensor.shape()), get<1, 0>(dstTensor.shape()) * sizeof(ElementSrc), + (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / ELE_NUM_PER_C0, + (get<1, 1>(dstTensor.stride()) - get<1, 0>(dstTensor.shape())) * sizeof(ElementSrc), 0); + AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); + }; +}; + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_UB_TO_GM_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_copy.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_copy.hpp new file mode 100644 index 00000000000..c7135709586 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_copy.hpp @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_GEMM_TILE_TILE_COPY_HPP +#define ACT_GEMM_TILE_TILE_COPY_HPP + +#include "../../../act/act.hpp" +#include "../../../act/detail/tag_to_layout.hpp" + +namespace Act::Gemm::Tile { + +template +struct TileCopyTla { + static_assert(DEPENDENT_FALSE, "Unsupported TileCopyTla, cannot find the specialization."); +}; + +template +struct TileCopyTlaExt { + static_assert(DEPENDENT_FALSE, "Unsupported TileCopyTlaExt, cannot find the specialization."); +}; +} // namespace Act::Gemm::Tile + +#include "../../../act/gemm/helper.hpp" +#include "../../../act/gemm/tile/copy_gm_to_l1.hpp" +#include "../../../act/gemm/tile/copy_gm_to_ub.hpp" +#include "../../../act/gemm/tile/copy_l0c_to_gm.hpp" +#include "../../../act/gemm/tile/copy_l1_to_l0a.hpp" +#include "../../../act/gemm/tile/copy_l1_to_l0b.hpp" +#include "../../../act/gemm/tile/copy_ub_to_gm.hpp" + +namespace Act::Gemm::Tile { + +template < + /// Tag indicating architecture + class ArchTag, + /// GemmType for A matrix operand + class AType, + /// GemmType for B matrix operand + class BType, + /// GemmType for C matrix operand + class CType, + /// GemmType for Bias operand + class BiasType = void> +struct TileCopy { + using ElementA = typename AType::Element; + using ElementB = typename BType::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using CopyGmToL1A = Gemm::Tile::CopyGmToL1; + using CopyGmToL1B = Gemm::Tile::CopyGmToL1; + using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A::L1AType>; + using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B::L1BType>; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGm; +}; + +/// Newly added variant of TileCopy that selects the L1/L0 tile types explicitly. +template < + /// Tag indicating architecture + class ArchTag, + /// GemmType for A matrix operand + class AType, + /// GemmType for B matrix operand + class BType, + /// GemmType for C matrix operand + class CType, + /// GemmType for Bias operand + class BiasType = void> +struct TileCopyGemm { + using ElementA = typename AType::Element; + using ElementB = typename BType::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + // Explicitly select the L1/L0 tile types for this structure + using L1AType = typename helper::L1ATypeSelectorGemm::L1AType; + using L1BType = typename helper::L1BTypeSelectorGemm::L1BType; + using L0AType = typename helper::L0ATypeSelector::L0AType; + using L0BType = typename helper::L0BTypeSelectorGemm::L0BType; + + using CopyGmToL1A = Gemm::Tile::CopyGmToL1; + using CopyGmToL1B = Gemm::Tile::CopyGmToL1; + using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A; + using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGm; +}; + +template < + /// Tag indicating architecture + class ArchTag, class TensorA, class LayoutTagA, class TensorB, class LayoutTagB, class TensorC, class LayoutTagC, + class TensorBias = void, class LayoutTagBias = void> +struct PackedTileCopyTla { + using ElementA = typename TensorA::Element; + using ElementB = typename TensorB::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using LayoutL1A = + detail::TagToLayout_t>::L1AType::Layout>; + using LayoutL1B = + detail::TagToLayout_t>::L1BType::Layout>; + using LayoutL0A = detail::TagToLayout_t; + using LayoutL0B = detail::TagToLayout_t; + using LayoutL0C = typename detail::LayoutL0C; + + using TensorL1A = Tensor, LayoutL1A, AscendC::TPosition::A1>; + using TensorL1B = Tensor, LayoutL1B, AscendC::TPosition::A1>; + using TensorL0A = Tensor,
LayoutL0A, AscendC::TPosition::A2>; + using TensorL0B = Tensor, LayoutL0B, AscendC::TPosition::B2>; + using TensorL0C = Tensor, LayoutL0C, AscendC::TPosition::CO1>; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + using CopyGmToL1A = Gemm::Tile::TileCopyTla; + using CopyGmToL1B = Gemm::Tile::TileCopyTla; + using CopyL1ToL0A = Gemm::Tile::TileCopyTla; + using CopyL1ToL0B = Gemm::Tile::TileCopyTla; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla; +}; + +template < + /// Tag indicating architecture + class ArchTag, class TensorA, class LayoutTagA, class TensorB, class LayoutTagB, class TensorC, class LayoutTagC, + class TensorBias = void, class LayoutTagBias = void, bool IS_PADDING_A = false, bool IS_PADDING_B = false> +struct PaddingPackedTileCopyTla { + static_assert(std::is_same_v || std::is_same_v, + "Unsupporteded layout, only can be RowMajor and ColumnMajor"); + static_assert(std::is_same_v || std::is_same_v, + "Unsupporteded layout, only can be RowMajor and ColumnMajor"); + using ElementA = typename TensorA::Element; + using ElementB = typename TensorB::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using LayoutTagL1A = typename helper::L1ATypeSelector>::L1AType::Layout; + using LayoutTagL1B = typename helper::L1BTypeSelector>::L1BType::Layout; + using LayoutL1A = detail::TagToLayout_t; + using LayoutL1B = detail::TagToLayout_t; + using LayoutL0A = detail::TagToLayout_t; + using LayoutL0B = detail::TagToLayout_t; + using LayoutL0C = typename detail::LayoutL0C; + + using TensorL1A = Tensor, LayoutL1A, AscendC::TPosition::A1>; + using TensorL1B = Tensor, LayoutL1B, AscendC::TPosition::A1>; + using TensorL0A = Tensor, LayoutL0A, AscendC::TPosition::A2>; + using TensorL0B = Tensor, LayoutL0B, AscendC::TPosition::B2>; + using TensorL0C = Tensor, LayoutL0C, AscendC::TPosition::CO1>; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + using LayoutPaddingTagA = std::conditional_t, layout::PaddingRowMajor, + layout::PaddingColumnMajor>; + using LayoutPaddingTagB = std::conditional_t, layout::PaddingRowMajor, + layout::PaddingColumnMajor>; + + using CopyGmToL1A = + std::conditional_t, + Gemm::Tile::TileCopyTla>; + using CopyGmToL1B = + std::conditional_t, + Gemm::Tile::TileCopyTla>; + + using CopyL1ToL0A = Gemm::Tile::TileCopyTla; + using CopyL1ToL0B = Gemm::Tile::TileCopyTla; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla; +}; +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_TILE_COPY_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_mmad.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_mmad.hpp new file mode 100644 index 00000000000..7beacdf7d39 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_mmad.hpp @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_GEMM_TILE_TILE_MMAD_HPP +#define ACT_GEMM_TILE_TILE_MMAD_HPP + +#include "../../../act/act.hpp" +#include "../../../act/gemm/helper.hpp" +namespace Act::Gemm::Tile { + +/////////////////////////////////////////////////////////// + +template < + /// Tag indicating architecture + class ArchTag_, + /// GemmType for A matrix operand + class AType_, + /// GemmType type for B matrix operand + class BType_, + /// GemmType type for Bias operand + class BiasType_> +struct TileMmad { + using ElementA = typename AType_::Element; + using ElementB = typename BType_::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + // Methods + + ACT_DEVICE + TileMmad() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &l0CTensor, + AscendC::LocalTensor const &l0ATensor, AscendC::LocalTensor const &l0BTensor, + uint32_t m, uint32_t n, uint32_t k, bool initC = true, uint8_t unitFlag = 0) + { + AscendC::MmadParams mmadParams; + mmadParams.m = m; + mmadParams.n = n; + mmadParams.k = k; + mmadParams.unitFlag = unitFlag; + mmadParams.cmatrixInitVal = initC; + + AscendC::Mmad(l0CTensor, l0ATensor, l0BTensor, mmadParams); + + const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; + if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < PIPE_M_BARRIER_THRESHOLD) { + AscendC::PipeBarrier(); + } + } +}; + +///////////////////////////////////////////TileMmadTla///////////////////////////////////////////////// + +template < + /// Tag indicating architecture + class ArchTag_, + /// Tensor type for A matrix operand + class TensorA, + /// Tensor type for B matrix operand + class TensorB, + /// Tensor type for C matrix operand + class TensorC, + /// Tensor type for Bias operand + class TensorBias = void> +struct TileMmadTla { + // Methods + + ACT_DEVICE + TileMmadTla() {} + + ACT_DEVICE + void operator()(TensorC const &l0CTensor, TensorA const &l0ATensor, TensorB const &l0BTensor, bool initC = true, + uint8_t unitFlag = 0) + { + const uint32_t m = get<0>(l0ATensor.orgShape()); + const uint32_t n = get<1>(l0BTensor.orgShape()); + const uint32_t k = get<1>(l0ATensor.orgShape()); + + AscendC::MmadParams mmadParams; + mmadParams.m = m; + mmadParams.n = n; + mmadParams.k = k; + mmadParams.unitFlag = unitFlag; + mmadParams.cmatrixInitVal = initC; + + AscendC::Mmad(l0CTensor.data(), l0ATensor.data(), l0BTensor.data(), mmadParams); + + const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; + if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < PIPE_M_BARRIER_THRESHOLD) { + AscendC::PipeBarrier(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_TILE_MMAD_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm_coord.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemm_coord.hpp new file mode 100644 index 00000000000..2e8dbb56f75 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemm_coord.hpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_COORD_HPP +#define ACT_GEMM_COORD_HPP + +#include "../act/coord.hpp" + +namespace Act { + +/// Shape of a matrix multiply-add operation +template < + /// Rows of matrix product + uint32_t M_ = 1, + /// Columns of matrix product + uint32_t N_ = 1, + /// Inner dimension of matrix product + uint32_t K_ = 1> +struct GemmShape { + static constexpr uint32_t M = M_; + static constexpr uint32_t N = N_; + static constexpr uint32_t K = K_; + + static constexpr int64_t MN = M * N; + static constexpr int64_t MK = M * K; + static constexpr int64_t KN = N * K; + static constexpr int64_t MNK = M * N * K; + + static constexpr int64_t COUNT = MNK; + + /// Returns a Coord object + ACT_HOST_DEVICE + static Coord<3> ToCoord() + { + return MakeCoord(M, N, K); + } + + ACT_HOST_DEVICE + static Coord<2> ToCoordMN() + { + return MakeCoord(M, N); + } + + ACT_HOST_DEVICE + static Coord<2> ToCoordMK() + { + return MakeCoord(M, K); + } + + ACT_HOST_DEVICE + static Coord<2> ToCoordKN() + { + return MakeCoord(K, N); + } +}; + +/// GemmCoord is a structure derived from Coord<3> that specifies a location +/// within the coordinate space of a Gemm problem. +struct GemmCoord : public Coord<3, uint32_t> { + /// Integer-valued index + using Index = uint32_t; + + /// Base type is a Coord of rank=3 + using Base = Coord<3, Index>; + + /// Gemm M dimension - rows of the output C matrix + static constexpr int M_INDEX = 0; + + /// Gemm N dimension - columns of the output C matrix + static constexpr int N_INDEX = 1; + + /// Gemm K dimension - inner dimension of the Gemm problem + static constexpr int K_INDEX = 2; + + /// Default ctor + ACT_HOST_DEVICE + GemmCoord() {} + + /// Constructs from Coord<3> and a batch + ACT_HOST_DEVICE + GemmCoord(Coord<3, Index> const &coord) : Base(coord) {} + + /// Helper to construct from a K, N, M, batch variables + ACT_HOST_DEVICE + GemmCoord(Index m, Index n, Index k) : Base(MakeCoord(m, n, k)) {} + + /// Returns the Gemm M coordinate + ACT_HOST_DEVICE + Index const &m() const + { + return this->At(M_INDEX); + } + + /// Returns reference to the Gemm M coordinate + ACT_HOST_DEVICE + Index &m() + { + return this->At(M_INDEX); + } + + /// Returns the Gemm N coordinate + ACT_HOST_DEVICE + Index const &n() const + { + return this->At(N_INDEX); + } + + /// Returns reference to the Gemm N coordinate + ACT_HOST_DEVICE + Index &n() + { + return this->At(N_INDEX); + } + + /// Returns the Gemm K coordinate + ACT_HOST_DEVICE + Index const &k() const + { + return this->At(K_INDEX); + } + + /// Returns reference to the Gemm K coordinate + ACT_HOST_DEVICE + Index &k() + { + return this->At(K_INDEX); + } + + ACT_HOST_DEVICE + auto GetCoordMN() const + { + return this->GetCoordByAxis(); + } + + ACT_HOST_DEVICE + auto GetCoordMK() const + { + return this->GetCoordByAxis(); + } + + ACT_HOST_DEVICE + auto GetCoordKN() const + { + return this->GetCoordByAxis(); + } +}; + +} // namespace Act + +#endif // ACT_GEMM_COORD_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemv_coord.hpp b/csrc/utils/op_kernel/operator/catlass/act/gemv_coord.hpp new file mode 100644 index 00000000000..2e925c4a528 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/gemv_coord.hpp @@ -0,0 +1,107 @@ 
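A minimal usage sketch of the GemmShape/GemmCoord helpers introduced in gemm_coord.hpp above (illustrative only; the 128x256x64 tile extents and the 1024x1024x512 problem size are made-up values, and CeilDiv is assumed to be the two-argument helper already used elsewhere in this patch):

    using TileShape = Act::GemmShape<128, 256, 64>;        // made-up tile extents
    Act::GemmCoord problem(1024, 1024, 512);                // (m, n, k)
    uint32_t tilesM = CeilDiv(problem.m(), TileShape::M);   // 8 tiles along M
    uint32_t tilesN = CeilDiv(problem.n(), TileShape::N);   // 4 tiles along N
    auto mn = problem.GetCoordMN();                         // rank-2 coord holding (m, n)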
+/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMV_COORD_HPP +#define ACT_GEMV_COORD_HPP + +#include "../act/coord.hpp" + +namespace Act { + +/// Shape of a matrix multiply-add operation +template < + /// Rows of matrix product + uint32_t M_ = 1, + /// Columns of the matrix (number of elements in the input vector) + uint32_t N_ = 1> +struct GemvShape { + static constexpr uint32_t M = M_; + static constexpr uint32_t N = N_; + + static constexpr int64_t MN = M * N; + + static constexpr int64_t COUNT = MN; + + /// Returns a Coord object + ACT_HOST_DEVICE + static Coord<2> ToCoord() + { + return MakeCoord(M, N); + } +}; + +/// GemvCoord is a structure derived from Coord<2> that specifies a location +/// within the coordinate space of a GEMV problem. +struct GemvCoord : public Coord<2, uint32_t> { + /// Integer-valued index + using Index = uint32_t; + + /// Base type is a Coord of rank=2 + using Base = Coord<2, Index>; + + /// GEMV M dimension - rows of the output vector (y) + static constexpr int M_INDEX = 0; + + /// GEMV N dimension - columns of the matrix (length of the input vector x) + static constexpr int N_INDEX = 1; + + /// Default ctor + ACT_HOST_DEVICE + GemvCoord() {} + + /// Constructs from Coord<2> and a batch + ACT_HOST_DEVICE + GemvCoord(Coord<2, Index> const &coord) : Base(coord) {} + + /// Helper to construct from M, N coordinates + ACT_HOST_DEVICE + GemvCoord(Index m, Index n) : Base(MakeCoord(m, n)) {} + + /// Returns the GEMV M coordinate (row of the result y) + ACT_HOST_DEVICE + Index const &m() const + { + return this->At(M_INDEX); + } + + /// Returns reference to the GEMV M coordinate + ACT_HOST_DEVICE + Index &m() + { + return this->At(M_INDEX); + } + + /// Returns the GEMV N coordinate (column of the matrix A or the input vector + /// x) + ACT_HOST_DEVICE + Index const &n() const + { + return this->At(N_INDEX); + } + + /// Returns reference to the GEMV N coordinate + ACT_HOST_DEVICE + Index &n() + { + return this->At(N_INDEX); + } + + ACT_HOST_DEVICE + auto GetCoordMN() const + { + return this->GetCoordByAxis(); + } +}; + +} // namespace Act + +#endif // ACT_GEMV_COORD_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/layout/layout.hpp b/csrc/utils/op_kernel/operator/catlass/act/layout/layout.hpp new file mode 100644 index 00000000000..5282545ee33 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/layout/layout.hpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. 
See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_LAYOUT_LAYOUT_HPP +#define ACT_LAYOUT_LAYOUT_HPP + +#include "../../act/act.hpp" +#include "../../act/layout/matrix.hpp" +#include "../../act/layout/vector.hpp" + +#endif // ACT_LAYOUT_LAYOUT_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/layout/matrix.hpp b/csrc/utils/op_kernel/operator/catlass/act/layout/matrix.hpp new file mode 100644 index 00000000000..be705ce0cef --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/layout/matrix.hpp @@ -0,0 +1,1184 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_LAYOUT_MATRIX_HPP +#define ACT_LAYOUT_MATRIX_HPP + +#include "../../act/act.hpp" +#include "../../act/coord.hpp" +#include "../../act/detail/alignment.hpp" +#include "../../act/matrix_coord.hpp" + +namespace Act::layout { + +/// Mapping function for row-major matrices +struct RowMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + /// Constructor + ACT_HOST_DEVICE + RowMajor(Index rows = 0, Index cols = 0) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(cols), LongIndex(1))) + {} + + /// Constructor + ACT_HOST_DEVICE + RowMajor(Index rows, Index cols, LongIndex ldm) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(ldm, LongIndex(1))) + {} + + /// Ctor + ACT_HOST_DEVICE + RowMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + template + ACT_HOST_DEVICE static RowMajor MakeLayoutInUb(MatrixCoord const &shape) + { + return RowMajor(shape.row(), shape.column(), RoundUp(shape.column())); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) * stride_[0] + LongIndex(coord.column()); + } + + /// Returns the layout of a tile. 
+ ACT_HOST_DEVICE + RowMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return RowMajor(tileShape, stride()); + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + // + // Data members + // + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for col-major matrices +struct ColumnMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE + ColumnMajor(Index rows = 0, Index cols = 0) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), LongIndex(rows))) + {} + + /// Constructor + ACT_HOST_DEVICE + ColumnMajor(Index rows, Index cols, LongIndex ldm) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), ldm)) + {} + + /// Ctor + ACT_HOST_DEVICE + ColumnMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) + LongIndex(coord.column()) * stride_[1]; + } + + /// Returns the layout of a tile. 
+ ACT_HOST_DEVICE + ColumnMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return ColumnMajor(tileShape, stride()); + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + // + // Data members + // + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for nZ matrices which is col-major inside fractal and +/// row-major between fractal +struct nZ { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE constexpr nZ( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + ACT_HOST_DEVICE constexpr nZ(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE constexpr static nZ MakeLayout(Index orgRows, Index orgCols) + { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return nZ(orgRows, orgCols, ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, + colsRound / C0_NUM_PER_FRACTAL, 1, colsRound * ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL); + } 
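+
+    // Worked example (illustrative sketch, assuming the usual 32-byte C0 and 512-byte fractal,
+    // i.e. ELE_NUM_PER_C0 == 16 and C0_NUM_PER_FRACTAL == 16 for a 2-byte element type):
+    // MakeLayout<half>(32, 64) yields shape (16, 2, 16, 4) and stride (1, 1024, 16, 256).
+    // GetOffset (below) for coordinate (17, 33) then evaluates to
+    //   17 / 16 * 1024 + 33 / 16 * 256 + (17 % 16) * 1 + (33 % 16) * 16 == 1553.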
+ + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3] + + (LongIndex(coord.row()) % shape_[0]) * stride_[0] + (LongIndex(coord.column()) % shape_[2]) * stride_[2]; + } + + /// Returns the layout of a tile. + ACT_HOST_DEVICE + nZ GetTileLayout(MatrixCoord const &tileOriShape) const + { + auto tileShape = MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2), + CeilDiv(tileOriShape.column(), shape(2))); + return nZ(tileOriShape, tileShape, stride()); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for zN matrices which is row-major inside fractal and +/// col-major between fractal +struct zN { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE constexpr zN( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + 
shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + ACT_HOST_DEVICE constexpr zN(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE constexpr static zN MakeLayout(Index orgRows, Index orgCols) + { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return zN(orgRows, orgCols, C0_NUM_PER_FRACTAL, rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, + colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL, 1, rowsRound * ELE_NUM_PER_C0); + } + + ACT_HOST_DEVICE + static zN MakeLayoutInL0C(MatrixCoord const &shape) + { + return zN(shape.row(), shape.column(), C0_NUM_PER_FRACTAL, CeilDiv(shape.row()), + C0_NUM_PER_FRACTAL, CeilDiv(shape.column()), C0_NUM_PER_FRACTAL, + C0_NUM_PER_FRACTAL * C0_NUM_PER_FRACTAL, 1, + RoundUp(shape.row()) * C0_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3] + + (LongIndex(coord.row()) % shape_[0]) * stride_[0] + (LongIndex(coord.column()) % shape_[2]) * stride_[2]; + } + + /// Returns the layout of a tile. + ACT_HOST_DEVICE + zN GetTileLayout(MatrixCoord const &tileOriShape) const + { + auto tileShape = MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2), + CeilDiv(tileOriShape.column(), shape(2))); + return zN(tileOriShape, tileShape, stride()); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for zN matrices which is row-major inside fractal and +/// row-major between fractal +struct zZ { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// 
Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE constexpr zZ( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + ACT_HOST_DEVICE constexpr zZ(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE constexpr static zZ MakeLayout(Index orgRows, Index orgCols) + { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return zZ(orgRows, orgCols, C0_NUM_PER_FRACTAL, rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, + colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, colsRound * C0_NUM_PER_FRACTAL, 1, ELE_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. 
+ /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for padding rowmajor matrices +/// A special data layout designed to improve the efficiency of matrix +/// operations in non-512B aligned scenarios. This layout is row-major within +/// blocks and also row-major between blocks. +struct PaddingRowMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + /// Constructor + ACT_HOST_DEVICE + PaddingRowMajor(Index orgRows, Index orgCols, Index blockRows, Index blockCols) + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, CeilDiv(orgCols, blockCols))), + stride_(MakeCoord((LongIndex)blockCols, (LongIndex)blockRows * (LongIndex)RoundUp(orgCols, blockCols), + (LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols)) + {} + + /// Returns the offset of a coordinate in linear memory. 
+ /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + LongIndex blockRows = (LongIndex)shape_[0]; + LongIndex blockCols = (LongIndex)shape_[2]; + return (LongIndex)coord.row() / blockRows * stride_[1] + (LongIndex)coord.column() / blockCols * stride_[3] + + (LongIndex)coord.row() % blockRows * stride_[0] + (LongIndex)coord.column() % blockCols; + } + + ACT_HOST_DEVICE + PaddingRowMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return PaddingRowMajor(tileShape.row(), tileShape.column(), shape_[0], shape_[2]); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + // + // Data members + // + + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for padding columnmajor matrices +/// A special data layout designed to improve the efficiency of matrix +/// operations in non-512B aligned scenarios. This layout is column-major within +/// blocks and also column-major between blocks. +struct PaddingColumnMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + /// Constructor + ACT_HOST_DEVICE + PaddingColumnMajor(Index orgRows, Index orgCols, Index blockRows, Index blockCols) + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, CeilDiv(orgCols, blockCols))), + stride_(MakeCoord((LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols, (LongIndex)blockRows, + (LongIndex)RoundUp(orgRows, blockRows) * (LongIndex)blockCols)) + {} + + /// Returns the offset of a coordinate in linear memory. 
+ /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + LongIndex blockRows = (LongIndex)shape_[0]; + LongIndex blockCols = (LongIndex)shape_[2]; + return (LongIndex)coord.row() / blockRows * stride_[1] + (LongIndex)coord.column() / blockCols * stride_[3] + + (LongIndex)coord.row() % blockRows + (LongIndex)coord.column() % blockCols * stride_[2]; + } + + ACT_HOST_DEVICE + PaddingColumnMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return PaddingColumnMajor(tileShape.row(), tileShape.column(), shape_[0], shape_[2]); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + // + // Data members + // + + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/////////////////////// +// new add layout nN +// nN layout +struct nN { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE + nN(Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, 
strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + ACT_HOST_DEVICE + nN(OrgShape orgShape, Shape shape, Stride stride) : orgShape_(orgShape), shape_(shape), stride_(stride) {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE static nN MakeLayout(Index orgRows, Index orgCols) + { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return nN(orgRows, orgCols, + + ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, colsRound / C0_NUM_PER_FRACTAL, + + 1, ELE_NUM_PER_FRACTAL, ELE_NUM_PER_C0, rowsRound * C0_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; +} // namespace Act::layout + +#endif // ACT_LAYOUT_MATRIX_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/layout/vector.hpp b/csrc/utils/op_kernel/operator/catlass/act/layout/vector.hpp new file mode 100644 index 00000000000..8b62f92a62d --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/layout/vector.hpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_LAYOUT_VECTOR_HPP +#define ACT_LAYOUT_VECTOR_HPP + +#include "../../act/act.hpp" +#include "../../act/coord.hpp" + +namespace Act::layout { + +struct VectorLayout { +public: + /// Logical rank of tensor + static constexpr int RANK = 1; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Shape vector + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + + /// Logical coordinate + using TensorCoord = Coord; + +public: + // Methods + + ACT_HOST_DEVICE + VectorLayout(Index size = 0) : shape_(MakeCoord(size)), stride_(MakeCoord(LongIndex(1))) {} + + ACT_HOST_DEVICE + VectorLayout(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + template + ACT_HOST_DEVICE static VectorLayout MakeLayoutInUb(TensorCoord const &tileShape) + { + return VectorLayout{RoundUp(tileShape[0])}; + } + + ACT_HOST_DEVICE + LongIndex GetOffset(TensorCoord const &coord) const + { + return stride_[0] * coord[0]; + } + + /// Returns the layout of a tile. + ACT_HOST_DEVICE + VectorLayout GetTileLayout(TensorCoord const &tileShape) const + { + return VectorLayout(tileShape, stride()); + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Stride data member + Shape shape_; + Stride stride_; +}; + +} // namespace Act::layout + +#endif // ACT_LAYOUT_VECTOR_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/act/matrix_coord.hpp b/csrc/utils/op_kernel/operator/catlass/act/matrix_coord.hpp new file mode 100644 index 00000000000..a9018db48c3 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/act/matrix_coord.hpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
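MakeLayoutInUb above pads a vector's length before it is staged in the unified buffer; a small standalone sketch of that rounding, assuming BYTE_PER_BLK is 32 bytes as the alignment comments later in this patch suggest (RoundUp/CeilDiv here are local helpers, not the Act utilities):

// Illustrative only.
#include <cstdint>
#include <cstdio>

constexpr uint32_t CeilDiv(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
constexpr uint32_t RoundUp(uint32_t a, uint32_t b) { return CeilDiv(a, b) * b; }

int main()
{
    // A unified-buffer (UB) copy works in 32-byte blocks, so a vector of N
    // float32 values is padded to a multiple of 32 / sizeof(float) = 8 elements.
    constexpr uint32_t bytesPerBlock = 32;
    constexpr uint32_t elemsPerBlock = bytesPerBlock / sizeof(float);
    std::printf("%u -> %u\n", 13u, RoundUp(13u, elemsPerBlock));  // 13 -> 16
    return 0;
}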
+ */ + +#ifndef ACT_MATRIX_COORD_HPP +#define ACT_MATRIX_COORD_HPP + +#include "../act/coord.hpp" + +namespace Act { + +template +struct MatrixShape { + static constexpr uint32_t ROW = ROW_; + static constexpr uint32_t COLUMN = COLUMN_; + + static constexpr int64_t COUNT = ROW * COLUMN; + + ACT_HOST_DEVICE + static Coord<2> ToCoord() + { + return MakeCoord(ROW, COLUMN); + } +}; + +/// MatrixCoord wraps Coord<2, uint32_t> to provide a helper for accessing named +/// dimensions. Classes expecting a coordinate in the rank=2 index space of a +/// matrix should use MatrixCoord. +struct MatrixCoord : public Coord<2, uint32_t> { + /// Integer-valued index + using Index = uint32_t; + + /// Base type is a Coord of rank=2 + using Base = Coord<2, Index>; + + /// LongIndex type + using LongIndex = typename Base::LongIndex; + + /// Rows dimension + static constexpr uint32_t ROW_INDEX = 0; + + /// Columns dimension + static constexpr uint32_t COLUMN_INDEX = 1; + + /// Default ctor + ACT_HOST_DEVICE + MatrixCoord() {} + + /// Constructs from Coord<2> + ACT_HOST_DEVICE + MatrixCoord(Coord<2, Index> const &coord) : Base(coord) {} + + /// Helper to construct from a row and column + ACT_HOST_DEVICE + MatrixCoord(Index row, Index column) : Base(MakeCoord(row, column)) {} + + /// Helper to construct from a row and column, which are LongIndex based + ACT_HOST_DEVICE + MatrixCoord(LongIndex row, LongIndex column) : Base(MakeCoord(Index(row), Index(column))) {} + + /// Returns the row of the coordinate + ACT_HOST_DEVICE + Index const &row() const + { + return this->At(ROW_INDEX); + } + + /// Returns the row of the coordinate + ACT_HOST_DEVICE + Index &row() + { + return this->At(ROW_INDEX); + } + + /// Returns the column of the coordinate + ACT_HOST_DEVICE + Index const &column() const + { + return this->At(COLUMN_INDEX); + } + + /// Returns the column of the coordinate + ACT_HOST_DEVICE + Index &column() + { + return this->At(COLUMN_INDEX); + } + + /// Element-wise addition + ACT_HOST_DEVICE + MatrixCoord operator+(Base const &b) const + { + return MatrixCoord(Base::operator+(b)); + } + + /// In-place addition + ACT_HOST_DEVICE + MatrixCoord &operator+=(Base const &b) + { + Base::operator+=(b); + return *this; + } +}; + +} // namespace Act + +#endif diff --git a/csrc/utils/op_kernel/operator/catlass/tla/int_tuple.hpp b/csrc/utils/op_kernel/operator/catlass/tla/int_tuple.hpp new file mode 100644 index 00000000000..f0702b3f9f1 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/tla/int_tuple.hpp @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
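A standalone analogue of how the rank-2 coordinate helper above is used together with a compile-time tile shape to locate a tile's origin (Coord2 below is a stand-in, not Act::MatrixCoord):

// Illustrative only.
#include <cstdint>
#include <cstdio>

struct Coord2 {
    uint32_t row, col;
    Coord2 operator+(Coord2 o) const { return {row + o.row, col + o.col}; }
    Coord2 operator*(Coord2 o) const { return {row * o.row, col * o.col}; }
};

int main()
{
    Coord2 blockShape{128, 256};  // tile size in elements
    Coord2 blockCoord{2, 1};      // which tile is being processed
    Coord2 offset = blockCoord * blockShape;  // element offset of the tile origin
    std::printf("(%u, %u)\n", offset.row, offset.col);  // (256, 256)
    return 0;
}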
+ */ + +#ifndef TLA_INT_TUPLE_HPP +#define TLA_INT_TUPLE_HPP + +#include "../tla/type_traits.hpp" +#include "../tla/tuple.hpp" +#include "../tla/numeric/integral_constant.hpp" +#include "../tla/numeric/integer_sequence.hpp" + +namespace tla { +// +// Apply (Unpack) +// (t, f) => f(t_0,t_1,...,t_n) +// + +namespace detail { +template +ACT_HOST_DEVICE constexpr auto apply(T &&t, F &&f, seq) +{ + return f(get(static_cast(t))...); +} + +template +ACT_HOST_DEVICE constexpr auto tapply(T &&t, F &&f, G &&g, seq) +{ + return g(f(get(static_cast(t)))...); +} + +} // end namespace detail + +template +ACT_HOST_DEVICE constexpr auto apply(T &&t, F &&f) +{ + return detail::apply(static_cast(t), f, tuple_seq{}); +} + +template +ACT_HOST_DEVICE constexpr auto transform_apply(T &&t, F &&f, G &&g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t))); + } +} + +template >::value)> +ACT_HOST_DEVICE constexpr decltype(auto) get(T &&t) noexcept +{ + static_assert(I == 0, "Index out of range"); + return static_cast(t); +} + +template +ACT_HOST_DEVICE constexpr decltype(auto) get(T &&t) noexcept +{ + return get(get(static_cast(t))); +} + +// max +template +ACT_HOST_DEVICE constexpr auto max(T0 const &t0, Ts const &...ts); + +struct UnpackedMax { + template + ACT_HOST_DEVICE constexpr auto operator()(T const &...v) const + { + return tla::max(v...); + } +}; + +template +ACT_HOST_DEVICE constexpr auto max(T0 const &t0, Ts const &...ts) +{ + if constexpr (is_tuple::value) { + return tla::max(tla::apply(t0, UnpackedMax{}), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return tla::max(t0, tla::max(ts...)); + } +} + +// rank +template +ACT_HOST_DEVICE constexpr auto rank(Tuple const &t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int::value>{}; + } else { + return Int<1>{}; + } + } else { + return rank(get(t)); + } +} + +template +using rank_t = decltype(rank(std::declval())); + +template +static constexpr auto rank_v = rank_t::value; + +// depth +template +ACT_HOST_DEVICE constexpr auto depth(Tuple const &t); + +struct UnpackedDepth { + template + ACT_HOST_DEVICE constexpr auto operator()(T const &...v) const + { + return tla::max(depth(v)...); + } +}; + +template +ACT_HOST_DEVICE constexpr auto depth(Tuple const &t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int<1>{} + tla::apply(t, UnpackedDepth{}); + } else { + return Int<0>{}; + } + } else { + return depth(get(t)); + } +} + +template +using depth_t = decltype(depth(std::declval())); + +template +static constexpr auto depth_v = depth_t::value; + +struct MultipliesUnaryLfold { + template + ACT_HOST_DEVICE constexpr auto operator()(T const &...v) const + { + return (... 
* v); + } +}; + +// Implementation of product as a function object +struct Product { + template + ACT_HOST_DEVICE constexpr auto operator()(IntTuple const &a) const + { + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 0) { + return Int<1>{}; + } else { + return tla::transform_apply(a, Product{}, MultipliesUnaryLfold{}); + } + } else if constexpr (tla::is_integral::value) { + return a; + } + } +}; + +} // end namespace tla + +#endif // TLA_INT_TUPLE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/tla/layout.hpp b/csrc/utils/op_kernel/operator/catlass/tla/layout.hpp new file mode 100644 index 00000000000..8f345bf7483 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/tla/layout.hpp @@ -0,0 +1,371 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_LAYOUT_HPP +#define TLA_LAYOUT_HPP + +#include "../act/act.hpp" +#include "../tla/numeric/integral_constant.hpp" +#include "../tla/tuple.hpp" +#include "../tla/int_tuple.hpp" + +using namespace Act; + +namespace tla { + +// Aliases + +template +using Shape = tla::tuple; + +template +using Stride = tla::tuple; + +template +using Coord = tla::tuple; + +template +ACT_HOST_DEVICE constexpr Shape MakeShape(Ts const &...t) +{ + return {t...}; +} +template +ACT_HOST_DEVICE constexpr Stride MakeStride(Ts const &...t) +{ + return {t...}; +} +template +ACT_HOST_DEVICE constexpr Coord MakeCoord(Ts const &...t) +{ + return {t...}; +} + +// +// Layout +// + +template +struct Layout : private tla::tuple { + // NOTE: This defaults static Shapes/Strides correctly, but not dynamic + ACT_HOST_DEVICE constexpr Layout(Shape const &shape = {}, Stride const &stride = {}, OrgShape const &orgShape = {}) + : tla::tuple(shape, stride, orgShape) + {} + + // + // Accessors + // + + static constexpr int rank = rank_v; + static constexpr int depth = depth_v; + + template + ACT_HOST_DEVICE constexpr decltype(auto) shape() + { + return get<0, I...>(static_cast &>(*this)); + } + + template + ACT_HOST_DEVICE constexpr decltype(auto) shape() const + { + return get<0, I...>(static_cast const &>(*this)); + } + + template + ACT_HOST_DEVICE constexpr decltype(auto) stride() + { + return get<1, I...>(static_cast &>(*this)); + } + + template + ACT_HOST_DEVICE constexpr decltype(auto) stride() const + { + return get<1, I...>(static_cast const &>(*this)); + } + + template + ACT_HOST_DEVICE constexpr decltype(auto) orgShape() + { + return get<2, I...>(static_cast &>(*this)); + } + + template + ACT_HOST_DEVICE constexpr decltype(auto) orgShape() const + { + return get<2, I...>(static_cast const &>(*this)); + } + + template + ACT_HOST_DEVICE constexpr auto operator()(Coord const &coord) const + { + return crd2idx(coord, shape(), stride()); + } +}; + +// Layout construction + +template +ACT_HOST_DEVICE constexpr auto MakeLayout(Shape const &shape, Stride const &stride, OrgShape const &orgShape) +{ + static_assert(is_tuple::value || is_integral::value); + static_assert(is_tuple::value || 
is_integral::value); + static_assert(depth_v == 1 && rank_v == rank_v); + return Layout(shape, stride, orgShape); +} + +struct UnpackedMakeShape { + template + ACT_HOST_DEVICE constexpr Shape operator()(T const &...v) const + { + return {v...}; + } +}; + +template +ACT_HOST_DEVICE constexpr auto MakeLayout(Shape const &shape, Stride const &stride) +{ + static_assert(is_tuple::value || is_integral::value); + static_assert(is_tuple::value || is_integral::value); + auto orgShape = tla::transform_apply(shape, Product{}, UnpackedMakeShape{}); + return MakeLayout(shape, stride, orgShape); +} + +// Convenience tags for common layouts + +template +ACT_HOST_DEVICE constexpr auto MakeLayoutFromTag(LayoutTag const &tag) +{ + static_assert(std::is_same_v || std::is_same_v, + "Unsupported LayoutTag for MakeLayoutFromTag, only support layout::RowMajor or layout::ColumnMajor"); + + if constexpr (std::is_same_v) { + return MakeLayout(MakeShape(tag.shape(0), tag.shape(1)), MakeStride(tag.stride(0), Int<1>{})); + } else { + return MakeLayout(MakeShape(tag.shape(0), tag.shape(1)), MakeStride(Int<1>{}, tag.stride(1))); + } +} + +// Return the shape of a mode +template +ACT_HOST_DEVICE constexpr decltype(auto) shape(Layout &layout) +{ + return layout.template shape(); +} + +template +ACT_HOST_DEVICE constexpr decltype(auto) shape(Layout const &layout) +{ + return layout.template shape(); +} + +// Return the stride of a mode +template +ACT_HOST_DEVICE constexpr decltype(auto) stride(Layout &layout) +{ + return layout.template stride(); +} + +template +ACT_HOST_DEVICE constexpr decltype(auto) stride(Layout const &layout) +{ + return layout.template stride(); +} + +// Return the orgShape of a mode +template +ACT_HOST_DEVICE constexpr decltype(auto) orgShape(Layout &layout) +{ + return layout.template orgShape(); +} + +template +ACT_HOST_DEVICE constexpr decltype(auto) orgShape(Layout const &layout) +{ + return layout.template orgShape(); +} + +// Return the rank of layout +template +ACT_HOST_DEVICE constexpr auto rank(Layout const &layout) +{ + return rank(shape(layout)); +} + +// Return the depth of the layout +template +ACT_HOST_DEVICE constexpr auto depth(Layout const &layout) +{ + return depth(shape(layout)); +} + +// Return the offset of coord +template +ACT_HOST_DEVICE constexpr auto crd2idx(Coord const &coord, Shape const &shape, Stride const &stride) +{ + static_assert(is_tuple::value && depth_v == 1 && rank_v == 2); + + constexpr int strideDepth = depth_v; + const uint32_t row = get<0>(coord); + const uint32_t col = get<1>(coord); + if constexpr (strideDepth == 1) { + const int64_t rowStride = get<0>(stride); + const int64_t colStride = get<1>(stride); + return row * rowStride + col * colStride; + } else if constexpr (strideDepth == 2) { + const uint32_t rowsInFractal = get<0, 0>(shape); + const uint32_t colsInFractal = get<1, 0>(shape); + const int64_t strideRowsByFractal = get<0, 1>(stride); + const int64_t strideColsByFractal = get<1, 1>(stride); + return row / rowsInFractal * strideRowsByFractal + col / colsInFractal * strideColsByFractal + + (row % rowsInFractal) * get<0, 0>(stride) + (col % colsInFractal) * get<1, 0>(stride); + } +} + +template +struct is_layout : false_type {}; +template +struct is_layout> : true_type {}; + +namespace detail { + +template +struct isRowMajor { + static bool const value = false; +}; + +template +struct isRowMajor> { + static bool const value = (stride<1>(Layout{}) == 1); +}; + +template +struct isColumnMajor { + static bool const value = false; +}; + +template 
+struct isColumnMajor> { + static bool const value = (stride<0>(Layout{}) == 1); +}; + +template +struct iszN { + static bool const value = false; +}; + +template +struct iszN> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + static bool const value = (shape<0, 0>(Layout{}) == C0_NUM_PER_FRACTAL && shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 && + stride<1, 0>(Layout{}) == 1 && stride<0, 1>(Layout{}) == ELE_NUM_PER_FRACTAL); +}; + +template +struct iszZ { + static bool const value = false; +}; + +template +struct iszZ> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + static bool const value = (shape<0, 0>(Layout{}) == C0_NUM_PER_FRACTAL && shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 && + stride<1, 0>(Layout{}) == 1 && stride<1, 1>(Layout{}) == ELE_NUM_PER_FRACTAL); +}; + +template +struct isnZ { + static bool const value = false; +}; + +template +struct isnZ> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + static bool const value = (shape<0, 0>(Layout{}) == ELE_NUM_PER_C0 && shape<1, 0>(Layout{}) == C0_NUM_PER_FRACTAL && + stride<0, 0>(Layout{}) == 1 && stride<1, 1>(Layout{}) == ELE_NUM_PER_FRACTAL); +}; + +} // end namespace detail + +// Advanced Layout constructions +// Make a inner layout with Rows and Cols. +template +ACT_HOST_DEVICE constexpr auto MakeLayout(uint32_t const &rows, uint32_t const &cols) +{ + static_assert(detail::iszN::value || detail::iszZ::value || + detail::isnZ::value, + "Unsupported Layout for MakeLayout, only support zN or zZ or nZ"); + + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + if constexpr (detail::iszN::value) { + return MakeLayout(MakeShape(MakeShape(Int{}, CeilDiv(rows)), + MakeShape(Int{}, CeilDiv(cols))), + MakeStride(MakeStride(Int{}, Int{}), + MakeStride(Int<1>{}, (int64_t)RoundUp(rows) * ELE_NUM_PER_C0)), + MakeShape(rows, cols)); + } else if constexpr (detail::iszZ::value) { + return MakeLayout( + MakeShape(MakeShape(Int{}, CeilDiv(rows)), + MakeShape(Int{}, CeilDiv(cols))), + MakeStride(MakeStride(Int{}, (int64_t)RoundUp(cols) * C0_NUM_PER_FRACTAL), + MakeStride(Int<1>{}, Int{})), + MakeShape(rows, cols)); + } else { + return MakeLayout(MakeShape(MakeShape(Int{}, CeilDiv(rows)), + MakeShape(Int{}, CeilDiv(cols))), + MakeStride(MakeStride(Int<1>{}, (int64_t)RoundUp(cols) * ELE_NUM_PER_C0), + MakeStride(Int{}, Int{})), + MakeShape(rows, cols)); + } +} + +template +ACT_HOST_DEVICE constexpr auto MakeLayoutTile(Layout const &layout, ShapeNew const &shapeNew) +{ + static_assert(is_tuple::value && depth_v == 1 && rank_v == 2); + + if constexpr (Layout::depth == 1 && Layout::rank == 2) { + return MakeLayout(shapeNew, layout.stride()); + } else if constexpr (is_integral(layout))>::value && + is_integral(layout))>::value) { + const uint32_t rows = get<0>(shapeNew); + const uint32_t cols = get<1>(shapeNew); + constexpr uint32_t dstInnerShapeRow = decltype(shape<0, 0>(layout))::value; + constexpr uint32_t dstInnerShapeCol = decltype(shape<1, 0>(layout))::value; + return MakeLayout(MakeShape(MakeShape(Int{}, CeilDiv(rows)), + MakeShape(Int{}, CeilDiv(cols))), + layout.stride(), shapeNew); + } else { + const uint32_t rows 
= get<0>(shapeNew); + const uint32_t cols = get<1>(shapeNew); + const uint32_t dstInnerShapeRow = shape<0, 0>(layout); + const uint32_t dstInnerShapeCol = shape<1, 0>(layout); + return MakeLayout(MakeShape(MakeShape(dstInnerShapeRow, CeilDiv(rows, dstInnerShapeRow)), + MakeShape(dstInnerShapeCol, CeilDiv(cols, dstInnerShapeCol))), + layout.stride(), shapeNew); + } +} + +ACT_HOST_DEVICE constexpr auto MakeLayoutL0C(uint32_t const &rows, uint32_t const &cols) +{ + constexpr uint32_t ELE_NUM_PER_FRACTAL = 256; + return MakeLayout(MakeShape(MakeShape(Int{}, CeilDiv(rows)), + MakeShape(Int{}, CeilDiv(cols))), + MakeStride(MakeStride(Int{}, Int{}), + MakeStride(Int<1>{}, (int64_t)RoundUp(rows) * C0_NUM_PER_FRACTAL)), + MakeShape(rows, cols)); +} + +} // end namespace tla + +#endif // TLA_LAYOUT_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp b/csrc/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp new file mode 100644 index 00000000000..53be0b4f4b1 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_NUMERIC_INTEGER_SEQUENCE_HPP +#define TLA_NUMERIC_INTEGER_SEQUENCE_HPP + +#include "../../tla/numeric/integral_constant.hpp" +#include "../../tla/type_traits.hpp" + +namespace tla { + +template +struct IntegerSequence { + using value_type = T; + static constexpr size_t size() + { + return sizeof...(Ns); + } +}; + +template +struct MakeIntegerSequenceImpl; + +template +struct MakeIntegerSequenceImpl, T, 0> { + typedef IntegerSequence type; +}; + +template +struct MakeIntegerSequenceImpl, T, N> { + typedef typename MakeIntegerSequenceImpl, T, N - 1>::type type; +}; + +template +using MakeIntegerSequence = typename MakeIntegerSequenceImpl, T, N>::type; + +// index_sequence +template +using index_sequence = IntegerSequence; + +template +using make_index_sequence = MakeIntegerSequence; + +// int_sequence +template +using int_sequence = IntegerSequence; + +template +using make_int_sequence = MakeIntegerSequence; + +// Shortcuts +template +using seq = int_sequence; + +template +using make_seq = make_int_sequence; + +template +using tuple_seq = make_seq>::value>; + +} // end namespace tla + +#endif // TLA_NUMERIC_INTEGER_SEQUENCE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp b/csrc/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp new file mode 100644 index 00000000000..d6a28116578 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
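The apply/tapply helpers above unpack a tuple into a function call through an index sequence; a standard-library sketch of the same pattern (Apply/ApplyImpl are local names, not the tla functions):

// Illustrative only.
#include <cstdio>
#include <tuple>
#include <utility>

template <class Tuple, class F, std::size_t... Is>
constexpr auto ApplyImpl(Tuple&& t, F&& f, std::index_sequence<Is...>)
{
    return f(std::get<Is>(std::forward<Tuple>(t))...);  // f(t_0, t_1, ..., t_n)
}

template <class Tuple, class F>
constexpr auto Apply(Tuple&& t, F&& f)
{
    return ApplyImpl(std::forward<Tuple>(t), std::forward<F>(f),
                     std::make_index_sequence<std::tuple_size_v<std::decay_t<Tuple>>>{});
}

int main()
{
    auto shape = std::make_tuple(16, 32, 4);
    int volume = Apply(shape, [](int m, int n, int k) { return m * n * k; });
    std::printf("%d\n", volume);  // 2048
    return 0;
}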
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_NUMERIC_INTEGER_CONSTANT_HPP +#define TLA_NUMERIC_INTEGER_CONSTANT_HPP + +#include "../../act/act.hpp" +#include "../../tla/type_traits.hpp" +#include "../../tla/numeric/math.hpp" + +namespace tla { + +// A constant value: short name and type-deduction for fast compilation +template +struct C { + using type = C; + static constexpr auto value = v; + using value_type = decltype(v); + ACT_HOST_DEVICE constexpr operator value_type() const noexcept + { + return value; + } + ACT_HOST_DEVICE constexpr value_type operator()() const noexcept + { + return value; + } +}; + +// Deprecate +template +using constant = C; + +template +using bool_constant = C; + +using true_type = bool_constant; +using false_type = bool_constant; + +template +using is_std_integral = std::is_integral; + +// A more std:: conforming integral_constant that enforces type but interops with C +template +struct integral_constant : C { + using type = integral_constant; + static constexpr T value = v; + using value_type = T; + ACT_HOST_DEVICE constexpr value_type operator()() const noexcept + { + return value; + } +}; + +// Use tla::is_std_integral to match built-in integral types (int, int64_t, unsigned, etc) +// Use tla::is_integral to match both built-in integral types AND static integral types. + +template +struct is_integral : bool_constant::value> {}; +template +struct is_integral> : true_type {}; +template +struct is_integral> : true_type {}; + +// is_static detects if an (abstract) value is defined completely by its type (no members) +template +struct is_static : bool_constant>::value> {}; + +// is_constant detects if a type is a static integral type and if v is equal to a value + +template +struct is_constant : false_type {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant> : bool_constant {}; +template +struct is_constant> : bool_constant {}; + +// +// Specializations +// + +template +using Int = C; +using _64 = Int<64>; +using _128 = Int<128>; +using _256 = Int<256>; +using _512 = Int<512>; + +/***************/ +/** Operators **/ +/***************/ + +#define TLA_LEFT_UNARY_OP(OP) \ + template \ + ACT_HOST_DEVICE constexpr C<(OP t)> operator OP(C) \ + { \ + return {}; \ + } +#define TLA_BINARY_OP(OP) \ + template \ + ACT_HOST_DEVICE constexpr C<(t OP u)> operator OP(C, C) \ + { \ + return {}; \ + } + +TLA_LEFT_UNARY_OP(+); +TLA_LEFT_UNARY_OP(-); +TLA_LEFT_UNARY_OP(~); +TLA_LEFT_UNARY_OP(!); +TLA_LEFT_UNARY_OP(*); + +TLA_BINARY_OP(+); +TLA_BINARY_OP(-); +TLA_BINARY_OP(*); +TLA_BINARY_OP(/); +TLA_BINARY_OP(%); +TLA_BINARY_OP(&); +TLA_BINARY_OP(|); +TLA_BINARY_OP(^); +TLA_BINARY_OP(<<); +TLA_BINARY_OP(>>); + +#undef TLA_BINARY_OP +#undef TLA_LEFT_UNARY_OP +#undef TLA_RIGHT_UNARY_OP + +// +// Named functions from math.hpp +// + +#define TLA_NAMED_UNARY_FN(OP) \ + template \ + ACT_HOST_DEVICE constexpr auto OP(C) \ + { \ + return C{}; \ + } +#define TLA_NAMED_BINARY_FN(OP) \ + template \ + ACT_HOST_DEVICE constexpr auto OP(C, C) \ + { \ + return C{}; \ + } \ + template ::value)> \ + ACT_HOST_DEVICE constexpr auto OP(C, U u) \ + { \ + 
return OP(t, u); \ + } \ + template ::value)> \ + ACT_HOST_DEVICE constexpr auto OP(T t, C) \ + { \ + return OP(t, u); \ + } + +TLA_NAMED_BINARY_FN(max); +TLA_NAMED_BINARY_FN(min); + +#undef TLA_NAMED_UNARY_FN +#undef TLA_NAMED_BINARY_FN + +} // end namespace tla + +#endif // TLA_NUMERIC_INTEGER_CONSTANT_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/tla/numeric/math.hpp b/csrc/utils/op_kernel/operator/catlass/tla/numeric/math.hpp new file mode 100644 index 00000000000..94a0982d42a --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/tla/numeric/math.hpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_NUMERIC_MATH_HPP +#define TLA_NUMERIC_MATH_HPP + +#include "../../tla/type_traits.hpp" + +namespace tla { + +// +// Common Operations +// + +template ::value &&std::is_arithmetic::value)> +ACT_HOST_DEVICE constexpr auto max(T const &t, U const &u) +{ + return t < u ? u : t; +} + +template ::value &&std::is_arithmetic::value)> +ACT_HOST_DEVICE constexpr auto min(T const &t, U const &u) +{ + return t < u ? t : u; +} + +} // namespace tla + +#endif // TLA_NUMERIC_MATH_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/tla/tensor.hpp b/csrc/utils/op_kernel/operator/catlass/tla/tensor.hpp new file mode 100644 index 00000000000..3ce4b50bbee --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/tla/tensor.hpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
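A reduced sketch of how the static integer type and its folded operators above behave: arithmetic on two constants produces another constant type, so shapes and strides built from them stay checkable at compile time (C here is a local re-implementation, not tla::C, and only + and * are shown):

// Illustrative only; requires C++17.
#include <type_traits>

template <auto v>
struct C {
    static constexpr auto value = v;
    using value_type = decltype(v);
    constexpr operator value_type() const noexcept { return value; }
};

template <auto a, auto b>
constexpr C<a + b> operator+(C<a>, C<b>) { return {}; }

template <auto a, auto b>
constexpr C<a * b> operator*(C<a>, C<b>) { return {}; }

int main()
{
    // The product of two constants is itself a distinct type, usable in
    // static_asserts and as a template argument.
    constexpr auto bytes = C<16>{} * C<32>{};  // C<512>
    static_assert(std::is_same_v<decltype(bytes), const C<512>>);
    static_assert(bytes == 512);
    return 0;
}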
+ */ + +#ifndef TLA_TENSOR_HPP +#define TLA_TENSOR_HPP + +#include "../tla/layout.hpp" // tla::Shape +#include "../tla/numeric/integral_constant.hpp" // tla::is_integral + +namespace tla { +// +// Tensor +// + +template +struct Tensor { + using Element = typename BuiltinTensor::PrimType; + using Layout = Layout_; + static constexpr AscendC::TPosition position = Position; + + ACT_HOST_DEVICE constexpr Tensor() {} + + ACT_HOST_DEVICE constexpr Tensor(BuiltinTensor const &builtinTensor, Layout const &layout) + : rep_(builtinTensor, layout) + {} + + // + // Accessors + // + + static constexpr int rank = Layout::rank; + + ACT_HOST_DEVICE constexpr decltype(auto) tensor() const + { + return *this; + } + + ACT_HOST_DEVICE constexpr decltype(auto) data() const + { + return get<0>(rep_); + } + + ACT_HOST_DEVICE constexpr decltype(auto) data() + { + return get<0>(rep_); + } + + ACT_HOST_DEVICE constexpr decltype(auto) layout() const + { + return get<1>(rep_); + } + + ACT_HOST_DEVICE constexpr decltype(auto) shape() const + { + return layout().shape(); + } + + ACT_HOST_DEVICE constexpr decltype(auto) stride() const + { + return layout().stride(); + } + + ACT_HOST_DEVICE constexpr decltype(auto) orgShape() const + { + return layout().orgShape(); + } + + tla::tuple rep_; +}; + +template +ACT_HOST_DEVICE constexpr auto MakeTensor(BuiltinTensor const &builtinTensor, Layout const &layout) +{ + return Tensor(builtinTensor, layout); +} + +template +ACT_HOST_DEVICE constexpr auto MakeTensor(BuiltinTensor const &builtinTensor, Layout const &layout, PositionType) +{ + return Tensor(builtinTensor, layout); +} + +template +ACT_DEVICE constexpr auto GetTile(Tensor const &tensor, Coord const &coord, Shape const &shape) +{ + auto layout = tensor.layout(); + auto offset = layout(coord); + auto builtinTensor = tensor.data(); + auto layoutNew = MakeLayoutTile(layout, shape); + return MakeTensor(builtinTensor[offset], layoutNew); +} + +} // end namespace tla + +#endif // TLA_TENSOR_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/tla/tuple.hpp b/csrc/utils/op_kernel/operator/catlass/tla/tuple.hpp new file mode 100644 index 00000000000..105a2c0ecc8 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/tla/tuple.hpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_TUPLE_HPP +#define TLA_TUPLE_HPP + +#include "../tla/numeric/integral_constant.hpp" +#include "../tla/numeric/integer_sequence.hpp" + +namespace tla { + +namespace detail { + +// EBO stands for "empty base optimization." 
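A self-contained illustration of the empty-base-optimization trick named in the comment above: tuple elements whose type is empty contribute no storage, so a tuple of static constants and one runtime value is no larger than the runtime value (EBO, StaticTag and Pair below are stand-ins for the tla::detail types):

// Illustrative only.
#include <cstdio>
#include <type_traits>

template <int N, class T, bool IsEmpty = std::is_empty_v<T>>
struct EBO;

// Empty types contribute no storage; the value is recreated on demand.
template <int N, class T>
struct EBO<N, T, true> {
    constexpr EBO() {}
    constexpr EBO(T const&) {}
};

// Non-empty types are stored as a member.
template <int N, class T>
struct EBO<N, T, false> {
    constexpr EBO() : t_{} {}
    constexpr EBO(T const& t) : t_{t} {}
    T t_;
};

struct StaticTag {};  // an empty "compile-time value", like Int<16>

struct Pair : EBO<0, StaticTag>, EBO<1, long> {
    constexpr Pair(StaticTag a, long b) : EBO<0, StaticTag>(a), EBO<1, long>(b) {}
};

int main()
{
    // Only the long occupies storage; the empty tag is optimized away.
    std::printf("%zu\n", sizeof(Pair));  // typically sizeof(long)
    return 0;
}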
+template ::value> +struct EBO; + +// Specialization for types T that are empty; +template +struct EBO { + ACT_HOST_DEVICE constexpr EBO() {} + + ACT_HOST_DEVICE constexpr EBO(T const &) {} +}; + +template +ACT_HOST_DEVICE constexpr T getv(EBO const &) +{ + return {}; +} + +// Specialization for types T that are not empty; +template +struct EBO { + ACT_HOST_DEVICE constexpr EBO() : t_{} {} + + ACT_HOST_DEVICE constexpr EBO(T const &t) : t_{t} {} + + T t_; +}; + +template +ACT_HOST_DEVICE constexpr T const &getv(EBO const &x) +{ + return x.t_; +} + +template +ACT_HOST_DEVICE constexpr T &getv(EBO &x) +{ + return x.t_; +} + +// TupleBase +template +struct TupleBase; + +template +struct TupleBase, T...> : EBO... { + ACT_HOST_DEVICE constexpr TupleBase() {} + + ACT_HOST_DEVICE constexpr TupleBase(T const &...t) : EBO(t)... {} +}; + +} // end namespace detail + +// tla::tuple class. +template +struct tuple : detail::TupleBase, T...> { + ACT_HOST_DEVICE constexpr tuple() {} + + ACT_HOST_DEVICE constexpr tuple(T const &...t) : detail::TupleBase, T...>(t...) {} +}; + +// get for tla::tuple +template +ACT_HOST_DEVICE constexpr decltype(auto) get(tuple const &t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +ACT_HOST_DEVICE constexpr decltype(auto) get(tuple &t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +ACT_HOST_DEVICE constexpr decltype(auto) get(tuple &&t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(static_cast &&>(t)); +} + +namespace detail { + +template +auto has_tuple_size(T *) -> bool_constant<(0 <= tuple_size::value)>; +auto has_tuple_size(...) -> false_type; + +} // end namespace detail + +template +struct is_tuple : decltype(detail::has_tuple_size((T *)0)){}; + +template +struct tuple_size> : std::integral_constant {}; + +template +struct tuple_size> : std::integral_constant {}; + +} // end namespace tla + +#endif // TLA_TUPLE_HPP diff --git a/csrc/utils/op_kernel/operator/catlass/tla/type_traits.hpp b/csrc/utils/op_kernel/operator/catlass/tla/type_traits.hpp new file mode 100644 index 00000000000..7bbd9e37f85 --- /dev/null +++ b/csrc/utils/op_kernel/operator/catlass/tla/type_traits.hpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_UTIL_TYPE_TRAITS_HPP +#define TLA_UTIL_TYPE_TRAITS_HPP + +#undef inline +#include +#define inline __inline__ __attribute__((always_inline)) + +#define __TLA_REQUIRES(...) 
typename std::enable_if<(__VA_ARGS__)>::type * = nullptr + +namespace tla { + +// using std::remove_cvref; +template +struct remove_cvref { + using type = std::remove_cv_t>; +}; + +// using std::remove_cvref_t; +template +using remove_cvref_t = typename remove_cvref::type; + +// tuple_size, tuple_element +template +struct tuple_size; + +template +struct tuple_size::type>> + : std::integral_constant::value> {}; + +template +constexpr size_t tuple_size_v = tuple_size::value; + +} // end namespace tla + +#endif // TLA_UTIL_TYPE_TRAITS_HPP diff --git a/csrc/utils/op_kernel/operator/epilogue/block/block_epilogue.h b/csrc/utils/op_kernel/operator/epilogue/block/block_epilogue.h new file mode 100644 index 00000000000..8bcc86513e4 --- /dev/null +++ b/csrc/utils/op_kernel/operator/epilogue/block/block_epilogue.h @@ -0,0 +1,13 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/act/epilogue/block/block_epilogue.hpp" + +#include "block_epilogue_per_token_dequant_swiglu.h" diff --git a/csrc/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h b/csrc/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h new file mode 100644 index 00000000000..972e1628e6a --- /dev/null +++ b/csrc/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h @@ -0,0 +1,326 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
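A standard-library sketch of the two utilities this header provides: stripping cv and reference qualifiers, and detecting tuple-like types by whether std::tuple_size is well formed (all names below are local stand-ins, not the tla declarations):

// Illustrative only.
#include <tuple>
#include <type_traits>

template <class T>
using RemoveCvrefT = std::remove_cv_t<std::remove_reference_t<T>>;

// A type counts as "tuple-like" if std::tuple_size<T>::value is well formed.
template <class T>
auto HasTupleSize(T*) -> std::bool_constant<(0 <= std::tuple_size<T>::value)>;
auto HasTupleSize(...) -> std::false_type;

template <class T>
constexpr bool IsTupleLike = decltype(HasTupleSize(static_cast<RemoveCvrefT<T>*>(nullptr)))::value;

static_assert(IsTupleLike<std::tuple<int, float>&>);
static_assert(!IsTupleLike<int>);

int main() { return 0; }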
+ */ +#pragma once +#include "../../catlass/act/act.hpp" +#include "../../catlass/act/arch/resource.hpp" +#include "../../catlass/act/epilogue/dispatch_policy.hpp" +#include "../../catlass/act/gemm_coord.hpp" +#include "../../catlass/act/matrix_coord.hpp" +#include "../../catlass/act/layout/layout.hpp" +#include "../../catlass/act/detail/callback.hpp" + +#include "../../epilogue/tile/tile_stride_muls.h" +#include "../../epilogue/tile/tile_stride_binary.h" + +namespace Act::Epilogue::Block { + +template +class BlockEpilogue, CType_, + Gemm::GemmType, Gemm::GemmType, DType_, + TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, + EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + static_assert(TileShape::ROW * sizeof(float) % BYTE_PER_BLK == 0, + "The per token scale granularity for word calculation must be 32 bytes aligned."); + static_assert(TileShape::COLUMN % 2 == 0, "The n-axis needs to be divided into two parts."); + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static constexpr uint32_t CHUNK_TILE_COLUMN = TileShape::COLUMN / 2; + using ChunkTileShape = MatrixShape; + + using TileStrideMuls = Tile::TileStrideMuls; + using TileStrideDiv = Tile::TileStrideDiv; + using TileStrideMul = Tile::TileStrideMul; + + static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough."); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + ChunkTileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= + ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + ACT_DEVICE + Params() {}; + + ACT_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + 
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + ACT_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const &params = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubTmpMxN = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubTmpMx32B = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubTmpMxChunkN = resource.ubBuf.template GetBufferByByte(ubOffset); + } + + ACT_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + ACT_DEVICE + void UpdateParams(Params const &params_) + { + params = params_; + } + + ACT_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (0 == actualBlockShapeMNK.k()) { + return; + } + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto ubChunkTileStride = MakeCoord(static_cast(ChunkTileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = 0; // originally AscendC::GetSubBlockIdx(); + uint32_t subblockNum = 1; //
originally AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto actualChunkTileShape = MakeCoord(actualTileShape.row(), actualTileShape.column() >> 1); + auto chunkTileOffset = MakeCoord(tileOffset.row(), tileOffset.column() >> 1); + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubScale); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B); + AscendC::PipeBarrier(); + tileStrideMuls(ubTmpMxChunkN, ubTmpMxN, -1.0f); + AscendC::PipeBarrier(); + AscendC::Exp(ubTmpMxChunkN, ubTmpMxChunkN, ChunkTileShape::COUNT); + AscendC::PipeBarrier(); + AscendC::Adds(ubTmpMxChunkN, ubTmpMxChunkN, 1.0f, ChunkTileShape::COUNT); + AscendC::PipeBarrier(); + tileStrideDiv(ubTmpMxChunkN, ubTmpMxN, ubTmpMxChunkN); + AscendC::PipeBarrier(); + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualChunkTileShape, ubChunkTileStride}; + + auto ubTmpMxNR = ubTmpMxN[ChunkTileShape::COLUMN]; + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + tileStrideMul(ubD, ubTmpMxNR, ubTmpMxChunkN); +
AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto gmTileD = gmD[params.layoutD.GetOffset(chunkTileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualChunkTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubTmpMxN; + AscendC::LocalTensor ubTmpMx32B; + AscendC::LocalTensor ubTmpMxChunkN; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + TileStrideMuls tileStrideMuls; + TileStrideDiv tileStrideDiv; + TileStrideMul tileStrideMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Act::Epilogue::Block diff --git a/csrc/utils/op_kernel/operator/epilogue/dispatch_policy.h b/csrc/utils/op_kernel/operator/epilogue/dispatch_policy.h new file mode 100644 index 00000000000..3d2d5e86aa9 --- /dev/null +++ b/csrc/utils/op_kernel/operator/epilogue/dispatch_policy.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../catlass/act/epilogue/dispatch_policy.hpp" + +namespace Act::Epilogue { + +template +struct EpilogueAtlasA2PerTokenDequantSwiglu { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; +}; + +} // namespace Act::Epilogue diff --git a/csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h b/csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h new file mode 100644 index 00000000000..ede8e1122ce --- /dev/null +++ b/csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
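Reading the tile operations in the epilogue above, each output element is silu(x1) * x2, where x = scale * perTokenScale * C, x1 is the first half of the columns (the gate) and x2 the second half; a scalar host-side reference of that math, a sketch under that reading rather than the device code (all names below are local stand-ins):

// Illustrative only.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// c: int32 GEMM accumulators, row-major [rows x cols] with cols even.
// scale: per-output-channel dequant scale [cols]; perTokenScale: per-row scale [rows].
// Output d: [rows x cols/2], d = silu(x1) * x2 with x = c * scale * perTokenScale.
std::vector<float> DequantSwiglu(const std::vector<int32_t>& c,
                                 const std::vector<float>& scale,
                                 const std::vector<float>& perTokenScale,
                                 int rows, int cols)
{
    const int half = cols / 2;
    std::vector<float> d(rows * half);
    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < half; ++j) {
            float x1 = c[i * cols + j] * scale[j] * perTokenScale[i];
            float x2 = c[i * cols + half + j] * scale[half + j] * perTokenScale[i];
            float silu = x1 / (1.0f + std::exp(-x1));  // matches the -1, exp, +1, div steps above
            d[i * half + j] = silu * x2;
        }
    }
    return d;
}

int main()
{
    auto d = DequantSwiglu({2, 4, 6, 8}, {0.5f, 0.5f, 0.25f, 0.25f}, {1.0f}, 1, 4);
    std::printf("%f %f\n", d[0], d[1]);
    return 0;
}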
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/act/act.hpp" + +namespace Act::Epilogue::Tile { + +template +struct TileStrideBinary { + using ArchTag = ArchTag_; + using ElementCompute = ElementCompute_; + using TileShape = TileShape_; + static constexpr int64_t DST_STRIDE = DST_STRIDE_; + static constexpr int64_t SRC0_STRIDE = SRC0_STRIDE_; + static constexpr int64_t SRC1_STRIDE = SRC1_STRIDE_; + + static constexpr uint32_t MAX_REPEAT_TIMES = 255; + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementCompute); + + static constexpr uint32_t DST_BLK_NUM_PER_COLUMN = DST_STRIDE / ELE_NUM_PER_BLK; + static constexpr uint32_t SRC0_BLK_NUM_PER_COLUMN = SRC0_STRIDE / ELE_NUM_PER_BLK; + static constexpr uint32_t SRC1_BLK_NUM_PER_COLUMN = SRC1_STRIDE / ELE_NUM_PER_BLK; + + static constexpr uint32_t ROW_NUM_PER_COMPUTE = MAX_REPEAT_TIMES; + static constexpr uint32_t COL_NUM_PER_COMPUTE = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + + ACT_DEVICE + TileStrideBinary() + { + repeatParams.dstBlkStride = 1; + repeatParams.src0BlkStride = 1; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = DST_BLK_NUM_PER_COLUMN; + repeatParams.src0RepStride = SRC0_BLK_NUM_PER_COLUMN; + repeatParams.src1RepStride = SRC1_BLK_NUM_PER_COLUMN; + } + + AscendC::BinaryRepeatParams repeatParams; +}; + +template +struct TileStrideMul + : TileStrideBinary { + using Base = TileStrideBinary; + + ACT_DEVICE + TileStrideMul() : Base() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubDst, + AscendC::LocalTensor const &ubSrc0, + AscendC::LocalTensor const &ubSrc1) + { + for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) { + uint32_t residueM = Base::TileShape::ROW - rowOffset; + uint8_t repeatTimes = + static_cast((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM); + for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) { + uint32_t residueN = Base::TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? Base::COL_NUM_PER_COMPUTE : residueN; + AscendC::Mul(ubDst[rowOffset * Base::DST_STRIDE + colOffset], + ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset], + ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams); + } + } + } +}; + +template +struct TileStrideDiv + : TileStrideBinary { + using Base = TileStrideBinary; + + ACT_DEVICE + TileStrideDiv() : Base() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubDst, + AscendC::LocalTensor const &ubSrc0, + AscendC::LocalTensor const &ubSrc1) + { + for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) { + uint32_t residueM = Base::TileShape::ROW - rowOffset; + uint8_t repeatTimes = + static_cast((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM); + for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) { + uint32_t residueN = Base::TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? 
Base::COL_NUM_PER_COMPUTE : residueN; + AscendC::Div(ubDst[rowOffset * Base::DST_STRIDE + colOffset], + ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset], + ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams); + } + } + } +}; + +} // namespace Act::Epilogue::Tile diff --git a/csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h b/csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h new file mode 100644 index 00000000000..15fb71e2b3d --- /dev/null +++ b/csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/act/act.hpp" + +namespace Act::Epilogue::Tile { + +template +struct TileStrideMuls { + using ArchTag = ArchTag_; + using ElementCompute = ElementCompute_; + using TileShape = TileShape_; + using DstTileShape = DstTileShape_; + using SrcTileShape = SrcTileShape_; + + static_assert(DstTileShape::ROW == SrcTileShape::ROW && DstTileShape::ROW == TileShape::ROW, "Error"); + + ACT_DEVICE + TileStrideMuls() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubDst, + AscendC::LocalTensor const &ubSrc, ElementCompute scalar) + { + constexpr uint32_t maxRepeatTimes = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t dstBlkNumPerColumn = DstTileShape::COLUMN / eleNumPerBlk; + constexpr uint32_t srcBlkNumPerColumn = SrcTileShape::COLUMN / eleNumPerBlk; + AscendC::UnaryRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.srcBlkStride = 1; + repeatParams.dstRepStride = dstBlkNumPerColumn; + repeatParams.srcRepStride = srcBlkNumPerColumn; + + constexpr uint32_t rowNumPerCompute = maxRepeatTimes; + constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint8_t repeatTimes = static_cast((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM); + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > colNumPerCompute) ? colNumPerCompute : residueN; + AscendC::Muls(ubDst[rowOffset * DstTileShape::COLUMN + colOffset], + ubSrc[rowOffset * SrcTileShape::COLUMN + colOffset], scalar, mask, repeatTimes, + repeatParams); + } + } + } +}; + +} // namespace Act::Epilogue::Tile diff --git a/csrc/utils/op_kernel/operator/gemm/block/block_mmad.h b/csrc/utils/op_kernel/operator/gemm/block/block_mmad.h new file mode 100644 index 00000000000..42f684808fc --- /dev/null +++ b/csrc/utils/op_kernel/operator/gemm/block/block_mmad.h @@ -0,0 +1,13 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. 
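The TileStride* helpers above apply an elementwise operation over a rows x cols tile whose destination and source may have different row pitches; a scalar reference loop of that access pattern (StrideMuls is a local stand-in, not the Act tile op):

// Illustrative only.
#include <cstdio>

// dst[r][c] = src[r][c] * scalar for a rows x cols tile, where dstStride/srcStride
// are the distances (in elements) between the starts of consecutive rows.
void StrideMuls(float* dst, const float* src, float scalar,
                int rows, int cols, int dstStride, int srcStride)
{
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            dst[r * dstStride + c] = src[r * srcStride + c] * scalar;
        }
    }
}

int main()
{
    // Take the left 2-column half of a 2x4 tile and negate it into a packed 2x2 tile,
    // mirroring how the epilogue extracts the gate half before its exp step.
    float src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float dst[4];
    StrideMuls(dst, src, -1.0f, 2, 2, /*dstStride=*/2, /*srcStride=*/4);
    std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);  // -1 -2 -5 -6
    return 0;
}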
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/act/gemm/block/block_mmad.hpp" + +#include "block_mmad_preload_async_with_callback_resident_a.h" diff --git a/csrc/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h b/csrc/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h new file mode 100644 index 00000000000..09f7b067578 --- /dev/null +++ b/csrc/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h @@ -0,0 +1,420 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/act/act.hpp" +#include "../../catlass/act/arch/resource.hpp" +#include "../../catlass/act/coord.hpp" +#include "../../catlass/act/detail/callback.hpp" +#include "../../catlass/act/gemm_coord.hpp" +#include "../../catlass/act/gemm/dispatch_policy.hpp" +#include "../../catlass/act/gemm/helper.hpp" + +namespace Act::Gemm::Block { + +template +struct BlockMmad< + MmadAtlasA2PreloadAsyncWithCallbackResidentA, + L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = + MmadAtlasA2PreloadAsyncWithCallbackResidentA; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1A_STAGES = 
DispatchPolicy::L1A_STAGES; + static constexpr uint32_t L1B_STAGES = DispatchPolicy::L1B_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert(L1A_TILE_SIZE * L1A_STAGES + L1B_TILE_SIZE * L1B_STAGES <= ArchTag::L1_SIZE, + "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + + ACT_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); + } + + ACT_DEVICE + ~BlockMmad() + { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1A_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + } + for (uint32_t i = 0; i < L1B_STAGES; ++i) { + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } + } + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + GemmCoord const &actualShape, Callback const &callbackBeforeFixpipe, + Callback const &callbackAfterFixpipe) + { + uint32_t kTileCount = CeilDiv(actualShape.k()); + bool useResidentA = + (kTileCount == L1A_STAGES) && (!isFirstLoad) && (gmBlockA.GetPhyAddr() == lastGmBlockA.GetPhyAddr()); + isFirstLoad = false; + lastGmBlockA = gmBlockA; + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = 
(startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx) + : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = + (kTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emission load instruction from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1AListId]); + if (!useResidentA) { + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1AListId], gmTileA, L1A_LAYOUT, layoutTileA); + } + AscendC::SetFlag(l1AEventList[l1AListId]); + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1BListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1BListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1BListId]); + + // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) + ? (l1TileMmadParamsId + preloadCount) + : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1AListId = l1AListId; + l1TileMmadParams.l1BListId = l1BListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe; + l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + } + l1AListId = (l1AListId + 1 < L1A_STAGES) ? (l1AListId + 1) : 0; + l1BListId = (l1BListId + 1 < L1B_STAGES) ? (l1BListId + 1) : 0; + } + } + + ACT_DEVICE + void SynchronizeBlock() + { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? 
(l1TileMmadParamsId + 1) : 0; + --preloadCount; + } + } + +private: + struct L1TileMmadParams { + uint32_t l1AListId; + uint32_t l1BListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + LayoutC layoutCInGm; + Callback callbackBeforeFixpipe; + Callback callbackAfterFixpipe; + + ACT_DEVICE + L1TileMmadParams() = default; + }; + + ACT_DEVICE + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) + { + uint32_t l1AOffset = l1BufAddrStart; + for (uint32_t i = 0; i < L1A_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1AEventList[i] = i; + AscendC::SetFlag(l1AEventList[i]); + } + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1A_STAGES; + for (uint32_t i = 0; i < L1B_STAGES; ++i) { + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l1BEventList[i] = i + L1A_STAGES; + AscendC::SetFlag(l1BEventList[i]); + } + } + + ACT_DEVICE + void InitL0A(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + } + + ACT_DEVICE + void InitL0B(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + } + + ACT_DEVICE + void InitL0C(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + } + + ACT_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) + { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1AListId]; + auto &l1BTensor = l1BTensorList[params.l1BListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); + } + } + + for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[params.l1AListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[params.l1AListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? 
L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[params.l1BListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[params.l1BListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && + (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; + } + } + + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; + + params.callbackBeforeFixpipe(); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0; + + params.callbackAfterFixpipe(); + } + } + + AscendC::LocalTensor l1ATensorList[L1A_STAGES]; + AscendC::LocalTensor l1BTensorList[L1B_STAGES]; + int32_t l1AEventList[L1A_STAGES]; + int32_t l1BEventList[L1B_STAGES]; + uint32_t l1AListId{0}; + uint32_t l1BListId{0}; + + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + uint32_t l0AListId{0}; + + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + uint32_t l0BListId{0}; + + AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; + int32_t l0CEventList[L0C_STAGES_]; + uint32_t l0CListId{0}; + + L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; + uint32_t l1TileMmadParamsId{0}; + uint32_t preloadCount{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; + + bool isFirstLoad{true}; + AscendC::GlobalTensor lastGmBlockA; +}; + +} // namespace Act::Gemm::Block diff --git a/csrc/utils/op_kernel/operator/gemm/dispatch_policy.h b/csrc/utils/op_kernel/operator/gemm/dispatch_policy.h new file mode 100644 index 00000000000..8f2cdec034e --- /dev/null +++ b/csrc/utils/op_kernel/operator/gemm/dispatch_policy.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../catlass/act/gemm/dispatch_policy.hpp" + +namespace Act::Gemm { + +template +struct MmadAtlasA2PreloadAsyncWithCallbackResidentA : public MmadAtlasA2Async { + static constexpr uint32_t PRELOAD_STAGES = PRELOAD_STAGES_; // Stages of emitting load instruction in advance + static constexpr uint32_t L1A_STAGES = L1A_STAGES_; + static constexpr uint32_t L1B_STAGES = L1B_STAGES_; + static constexpr uint32_t L0A_STAGES = L0A_STAGES_; + static constexpr uint32_t L0B_STAGES = L0B_STAGES_; + static constexpr uint32_t L0C_STAGES = L0C_STAGES_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; + +} // namespace Act::Gemm diff --git a/csrc/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h new file mode 100644 index 00000000000..53a4dac75d8 --- /dev/null +++ b/csrc/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -0,0 +1,2023 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+ * Description: FusedDeepMoe operator kernel function implementation file
+ * Author: WANG Qiankun
+ * Create: 2025-07-19
+ * Note:
+ * History: 2025-07-19 create FusedDeepMoe operator kernel function implementation file
+ */
+#pragma once
+
+#include "../../catlass/act/act.hpp"
+#include "../../catlass/act/arch/cross_core_sync.hpp"
+#include "../../catlass/act/arch/resource.hpp"
+#include "../../catlass/act/coord.hpp"
+#include "../../catlass/act/detail/callback.hpp"
+#include "../../catlass/act/gemm_coord.hpp"
+#include "../../catlass/act/matrix_coord.hpp"
+#include "../../catlass/act/epilogue/tile/tile_swizzle.hpp"
+#include "../../catlass/act/epilogue/tile/tile_copy.hpp"
+
+#include "../../../../../op_kernel/fused_deep_moe_base.h"
+
+constexpr uint32_t STATE_OFFSET = 512;
+constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024;
+constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024;
+constexpr uint64_t GROUP_TOKEN_NUM_OFFSET = 932 * 1024;
+constexpr uint64_t SOFT_SYNC_OFFSET = 964 * 1024;
+constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024;
+constexpr uint32_t SUM_TMP_TENSOR_SIZE = 1024;
+constexpr uint32_t UB_ALIGN = 32;
+constexpr uint32_t TOKEN_EXTRA_SPACE = 512;
+constexpr uint32_t INT32_COUNT_PER_BLOCK = 8;
+constexpr uint32_t SOFT_SYNC_SPACE_SIZE = 512;
+constexpr uint32_t COMP_AIV_CORE_NUM = 24;  // 24 AIVs do the dequant-swiglu compute; not user-adjustable yet
+constexpr uint32_t SEND_AIV_CORE_NUM = 48;  // all cores send/receive with one expert per rank; halved for multiple experts
+constexpr uint32_t RECV_AIV_CORE_NUM = 48;  // all cores send/receive with one expert per rank; halved for multiple experts
+constexpr int64_t LOOP_TMP_SIZE = 4096;     // scratch space for the address-offset computation optimization
+constexpr int32_t SUB_AIV_NUM = 2;          // 1C with 2V, i.e. one cube core paired with two vector cores
+constexpr int32_t ODD_EVEN_BASE = 2;        // base used for odd/even checks
+constexpr int32_t BUFFER_NUM = 2;
+constexpr int32_t GATHER_SECOND_NUM = 2;
+constexpr uint32_t MAX_QUANT_ROW_ONCE = 8;
+constexpr uint32_t QUANT_SPACE_FACTOR = 176 * 1024 / 11;  // quantization uses no more than 176 KB of UB
+#define OPT_RANK_OFFSET 512
+
+#define CEIL_UP(x) ((x + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN)
+#define CEIL(x, y) (((x) + (y - 1)) / (y))
+#define UB_BLOCK_SIZE (32)
+#define GET_WIND_STATE_ADDR_BY_RANK_ID(rankId) \
+    (((epRankId == rankId) \
+          ? ((GM_ADDR)(winContext_->localWindowsExp)) \
+          : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsExp))) + \
+     dataState * WIN_STATE_OFFSET)
+#define GET_WIND_ADDR_BY_RANK_ID(rankId) \
+    (((epRankId == rankId) \
+          ?
((GM_ADDR)(winContext_->localWindowsIn)) \ + : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsIn))) + \ + winDataSizeOffset + rankId * OPT_RANK_OFFSET) +#define TOKEN_FLAG_1 (0x55555555) +#define TOKEN_FLAG_2 (0x33333333) +#define V_TO_C_FLAG_1 (0x03030303) +#define V_TO_C_FLAG_2 (0x05050505) +#define AIC_STATE_SPACE_IDNEX (48) +#define AIV_STATE_SPACE_IDNEX (72) +#define CV_FLAG_INDEX 0 +#define GROUP_ID_INDEX 1 +#define PRE_COUNT_INDEX 2 +#define SELF_COUNT_INDEX 3 +#define TOTAL_COUNT_INDEX 4 +#define GROUP_TOKEN_COUNT 3 // 等于SELF_COUNT_INDEX +#define GROUP_INFO_SIZE 32 + +#define REACH_STEP_1_SEND_COUNT +#define REACH_STEP_2_SEND_TOKEN +#define REACH_STEP_3_RECV_COUNT +#define REACH_STEP_4_RECV_TOKEN +#define REACH_STEP_5_WAIT_RECV_CORE +#define REACH_STEP_6_GMM1_DEQ_SWIGLU +#define REACH_STEP_7_UPDATE_INFO +#define REACH_STEP_8_QUANT + +#define SEND_TOKEN_RETURN // 这个宏好像比较影响性能,待确认 + +namespace Act::Gemm::Kernel { + +template +class BlockQuant +{ +public: + using ElementInput = float; + using LayoutInput = layout::RowMajor; + using ElementDequantScale = float; + using LayoutDequantScale = layout::VectorLayout; + using ElementOutput = int8_t; + using LayoutOutput = layout::RowMajor; + + using InputType = GemmType; + using DequantScaleType = GemmType; + using OutputType = GemmType; + + using EpilogueTileSwizzle = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + struct Params { + __gm__ ElementInput *ptrInput{nullptr}; + LayoutInput layoutInput; + __gm__ ElementDequantScale *ptrDequantScale{nullptr}; + LayoutDequantScale layoutDequantScale; + __gm__ ElementOutput *ptrOutput{nullptr}; + LayoutOutput layoutOutput; + uint32_t tileRow; + uint32_t tileColumn; + + ACT_DEVICE + Params() {}; + + ACT_DEVICE + Params(__gm__ ElementInput *ptrInput_, LayoutInput const &layoutInput_, + __gm__ ElementDequantScale *ptrQuantScale_, LayoutDequantScale const &layoutQuantScale_, + __gm__ ElementOutput *ptrOutput_, LayoutOutput const layoutOutput_, const uint32_t tileRow_, + const uint32_t tileColumn_) + : ptrInput(ptrInput_), + layoutInput(layoutInput_), + ptrDequantScale(ptrQuantScale_), + layoutDequantScale(layoutQuantScale_), + ptrOutput(ptrOutput_), + layoutOutput(layoutOutput_), + tileRow(tileRow_), + tileColumn(tileColumn_) + {} + }; + + ACT_DEVICE + BlockQuant(Arch::Resource const &resource, Params const ¶ms_) : params(params_) + { + int64_t ubOffset = 0; + tileRow = params_.tileRow; + tileColumn = params_.tileColumn; + tileCount = tileRow * tileColumn; + halfTileColumn = tileColumn / 2; + halfTileCount = tileRow * halfTileColumn; + + ubInput = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(ElementInput); + ubDequantScale = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tileRow * sizeof(ElementDequantScale)); + ubOutput = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(ElementOutput); + + ubAbs = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(float); + ubMax = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += halfTileCount * sizeof(float); + ubReduceMax = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tileRow * sizeof(float)); + ubQuantScale = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tileRow * sizeof(float)); + ubInputTmp = ubAbs; + ubQuantF32 = ubAbs; + ubQuantS32 = ubAbs.ReinterpretCast(); + ubQuantF16 = ubAbs.ReinterpretCast(); + + 
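+        // Note: ubInputTmp, ubQuantF32, ubQuantS32 and ubQuantF16 above are reinterpreted views of the same
+        // UB region as ubAbs, so the abs/max/scale/cast steps reuse one scratch buffer instead of four.
+        // The SetFlag calls below pre-arm the pipe events that operator() waits on in its first loop
+        // iteration; the destructor issues the matching WaitFlag calls before teardown.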
AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::SetFlag(1); + } + + ACT_DEVICE + ~BlockQuant() + { + AscendC::WaitFlag(0); + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + + ACT_DEVICE + void operator()(MatrixCoord const &blockShape, MatrixCoord const &blockCoord, MatrixCoord const &actualBlockShape) + { + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmInput; + gmInput.SetGlobalBuffer(params.ptrInput); + AscendC::GlobalTensor gmDequantScale; + gmDequantScale.SetGlobalBuffer(params.ptrDequantScale); + AscendC::GlobalTensor gmOutput; + gmOutput.SetGlobalBuffer(params.ptrOutput); + + auto ubTileStride = MakeCoord(static_cast(tileColumn), 1L); + auto ubHalfTileStride = MakeCoord(static_cast(halfTileColumn), 1L); + auto tileShape = MakeCoord(tileRow, tileColumn); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileInput = gmInput[params.layoutInput.GetOffset(tileOffset)]; + auto layoutGmTileInput = params.layoutInput.GetTileLayout(actualTileShape); + + layout::RowMajor layoutUbInput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(0); + copyGmToUbInput(ubInput, gmTileInput, layoutUbInput, layoutGmTileInput); + AscendC::SetFlag(0); + + AscendC::WaitFlag(0); + AscendC::Abs(ubAbs, ubInput, tileCount); + AscendC::PipeBarrier(); + + for (uint32_t rowIdx = 0; rowIdx < tileRow; ++rowIdx) { + AscendC::Max(ubMax[rowIdx * halfTileColumn], ubAbs[rowIdx * tileColumn], + ubAbs[rowIdx * tileColumn + halfTileColumn], halfTileColumn); + } + + AscendC::PipeBarrier(); + AscendC::Muls(ubInputTmp, ubInput, 127.f, tileCount); + + constexpr uint32_t elementPerBlk = BYTE_PER_BLK / sizeof(float); + constexpr int32_t mask = 64; + + AscendC::BinaryRepeatParams maxParams; + maxParams.dstBlkStride = halfTileColumn / elementPerBlk; + maxParams.src0BlkStride = halfTileColumn / elementPerBlk; + maxParams.src1BlkStride = halfTileColumn / elementPerBlk; + maxParams.dstRepStride = 1; + maxParams.src0RepStride = 1; + maxParams.src1RepStride = 1; + constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(float); + uint32_t reduceWidth = halfTileColumn; + while (reduceWidth > (BLK_NUM_PER_VECTOR_FRACTAL * BYTE_PER_BLK / sizeof(float))) { + reduceWidth >>= 1; + AscendC::Max(ubMax, ubMax, ubMax[reduceWidth], mask, reduceWidth / elementPerBlk, maxParams); + AscendC::PipeBarrier(); + } + + AscendC::WholeReduceMax(ubReduceMax, ubMax, mask, tileRow, 1, 1, halfTileColumn / elementPerBlk, + AscendC::ReduceOrder::ORDER_ONLY_VALUE); + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::PipeBarrier(); + + AscendC::WaitFlag(0); + AscendC::Muls(ubDequantScale, ubReduceMax, 1.0f / 127.0f, tileRow); + AscendC::SetFlag(0); + + auto dequantScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto dequantScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTileDequantScale = gmDequantScale[params.layoutDequantScale.GetOffset(dequantScaleTileOffset)]; + auto layoutGmTileDequantScale = 
params.layoutDequantScale.GetTileLayout(dequantScaleTileShape); + + auto layoutUbDequantScale = + LayoutDequantScale::template MakeLayoutInUb(dequantScaleTileShape); + + AscendC::WaitFlag(0); + copyUbToGmDequantScale(gmTileDequantScale, ubDequantScale, layoutGmTileDequantScale, layoutUbDequantScale); + AscendC::SetFlag(0); + + AscendC::WaitFlag(0); + for (uint32_t rowIdx = 0; rowIdx < tileRow; ++rowIdx) { + AscendC::Muls(ubQuantF32[rowIdx * tileColumn], ubInputTmp[rowIdx * tileColumn], + 1.f / ubReduceMax.GetValue(rowIdx), tileColumn); + } + + AscendC::PipeBarrier(); + AscendC::Cast(ubQuantS32, ubQuantF32, AscendC::RoundMode::CAST_RINT, tileCount); + AscendC::PipeBarrier(); + AscendC::SetDeqScale(static_cast(1.0)); + AscendC::Cast(ubQuantF16, ubQuantS32, AscendC::RoundMode::CAST_RINT, tileCount); + AscendC::PipeBarrier(); + + AscendC::WaitFlag(1); + AscendC::Cast(ubOutput, ubQuantF16, AscendC::RoundMode::CAST_RINT, tileCount); + AscendC::SetFlag(1); + + auto gmTileOutput = gmOutput[params.layoutOutput.GetOffset(tileOffset)]; + auto layoutGmTileOutput = params.layoutOutput.GetTileLayout(actualTileShape); + + LayoutOutput layoutUbOutput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(1); + copyUbToGmOutput(gmTileOutput, ubOutput, layoutGmTileOutput, layoutUbOutput); + AscendC::SetFlag(1); + } + } + +private: + Params params; + uint32_t tileRow; + uint32_t tileColumn; + uint32_t tileCount; + uint32_t halfTileColumn; + uint32_t halfTileCount; + + AscendC::LocalTensor ubInput; + AscendC::LocalTensor ubDequantScale; + AscendC::LocalTensor ubOutput; + + AscendC::LocalTensor ubAbs; + AscendC::LocalTensor ubMax; + AscendC::LocalTensor ubReduceMax; + AscendC::LocalTensor ubQuantScale; + AscendC::LocalTensor ubQuantScaleBrcb; + AscendC::LocalTensor ubInputTmp; + AscendC::LocalTensor ubQuantF32; + AscendC::LocalTensor ubQuantS32; + AscendC::LocalTensor ubQuantF16; + + Epilogue::Tile::CopyGm2Ub copyGmToUbInput; + Epilogue::Tile::CopyUb2Gm copyUbToGmDequantScale; + Epilogue::Tile::CopyUb2Gm copyUbToGmOutput; +}; + +__aicore__ inline static void EncreaseSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx) +{ + // flag++,类似set flag + AscendC::PipeBarrier(); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + global); + __asm__ __volatile__(""); + uint8_t value = global.GetValue(0); + global.SetValue(0, value + 1); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + global); + __asm__ __volatile__(""); + AscendC::PipeBarrier(); +} + +__aicore__ inline static void CheckSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx, uint32_t target) +{ + // 查看flag,类似wait flag + AscendC::PipeBarrier(); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); + uint8_t value = global.GetValue(0); + if (value >= target) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); + break; + } + } + AscendC::PipeBarrier(); +} + +__aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row) +{ + row = QUANT_SPACE_FACTOR / column; + row = row < MAX_QUANT_ROW_ONCE ? 
row : MAX_QUANT_ROW_ONCE; +} + +template +class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using ElementDequantScale = typename BlockQuant::ElementDequantScale; + using LayoutDequantScale = typename BlockQuant::LayoutDequantScale; + using ElementOutput = typename BlockQuant::ElementOutput; + using LayoutOutput = typename BlockQuant::LayoutOutput; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + using XType = XType_; + + // Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementOutput *ptrOutput; + LayoutOutput layoutOutput; + __gm__ ElementDequantScale *ptrDequantScale; + LayoutDequantScale layoutDequantScale; + GM_ADDR ptrWorkspace; + GM_ADDR gmX; + GM_ADDR debugGm; + GM_ADDR gmexpertIds; + + GM_ADDR gmExpandIdx; + GM_ADDR gmEpSendCount; + GM_ADDR gmResvered; + GM_ADDR gmOutputRecvCount; + + uint32_t epRankSize; + uint32_t epRankId; + uint32_t moeExpertNum; + uint32_t moeExpertNumPerRank; + uint32_t sharedExpertNum; + uint32_t sharedExpertRankNum; + uint32_t quantMode; + uint32_t globalBs; + uint32_t bs; + uint32_t topK; + uint32_t tokenLen; + // Methods + ACT_DEVICE + Params() {} + + ACT_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_, + LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, + GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_, + GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, + GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_, + uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, + uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_, + uint32_t h) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList 
*>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)), + layoutOutput(layoutOutput_), + ptrDequantScale(reinterpret_cast<__gm__ ElementDequantScale *>(ptrDequantScale_)), + layoutDequantScale(layoutDequantScale_), + ptrWorkspace(ptrWorkspace_), + gmX(gmX_), + debugGm(debugGm_), + gmexpertIds(gmexpertIds_), + gmExpandIdx(gmExpandIdx_), + gmEpSendCount(gmEpSendCount_), + gmOutputRecvCount(gmOutputRecvCount_), + gmResvered(gmResvered_), + epRankSize(epRankSize_), + epRankId(epRankId_), + moeExpertNum(moeExpertNum_), + moeExpertNumPerRank(moeExpertNumPerRank_), + sharedExpertNum(sharedExpertNum_), + sharedExpertRankNum(sharedExpertRankNum_), + quantMode(quantMode_), + globalBs(globalBs_), + bs(bs_), + topK(topK_), + tokenLen(h) + {} + }; + + // Methods + ACT_DEVICE + GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace() {} + + template + ACT_DEVICE void operator()(Params const ¶ms); + + template <> + ACT_DEVICE void operator()(Params const ¶ms) + { + aicIdx = AscendC::GetBlockIdx(); + subBlockNum = AscendC::GetSubBlockNum(); + aiCoreGroupNum = AscendC::GetBlockNum(); + aicNum = aiCoreGroupNum; + aicStateGlobalCoreIdx = AIC_STATE_SPACE_IDNEX + aicIdx; + moeExpertNumPerRank = params.moeExpertNumPerRank; + isShareExpert = (params.epRankId < params.sharedExpertRankNum); + localExpertNum = isShareExpert ? 1 : moeExpertNumPerRank; + // 单卡单专家48发48收 + recvCoreNum = RECV_AIV_CORE_NUM; + // 单卡多专家24收24发 + if (localExpertNum > 1) { + recvCoreNum = RECV_AIV_CORE_NUM / SUB_AIV_NUM; + } + uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; // 这里假设可以整除 + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + + // 更新状态,影响CV交互使用的信号值 + statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp); + AscendC::GlobalTensor selfDataStatusTensor; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aicStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + cvDataState = selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN); + if (cvDataState == 0) { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 1; + vToCFlag = V_TO_C_FLAG_1; + } else { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 0; + vToCFlag = V_TO_C_FLAG_2; + } + + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * aicNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + AscendC::GlobalTensor groupTokenNumStateTensor; + aicSetFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, + static_cast(aicNum + AscendC::GetBlockIdx())}; // 
AIV等待的信息在24~48 + uint32_t target = 1; + for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + + groupIdx * GROUP_INFO_SIZE); + // 等待AIV的token收齐信号后,再往下走 + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor); + __asm__ __volatile__(""); + if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) { + break; + } + } + + uint32_t currentM = groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((aicIdx < startCoreIdx) ? (aicIdx + aicNum) : aicIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + // 使用软同步 + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + aicWaitFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(AscendC::GetBlockIdx()), + target}; // AIC等待的信号在前24个 + target += 1; + callbackBeforeFixpipe = MakeCallback(&aicWaitFunc1); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFunc1); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * aicNum + aicIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % aicNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? 
(stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + target += 1; // 追平AIV多余的软同步 + --stageUsed; + } + AscendC::SyncAll(); + } + + ACT_DEVICE + void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset) + { + // 使用AIV计算发送到对端的偏移量 + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor dstExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor subExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor workLocalTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::Duplicate(dstExpIdTensor_, dstExpertId, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Sub(subExpIdTensor_, expertIdsTensor_, dstExpIdTensor_, tokenIndex); + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpFp32 = subExpIdTensor_.ReinterpretCast(); + AscendC::LocalTensor tmpoutFp32 = dstExpIdTensor_.ReinterpretCast(); + AscendC::Abs(tmpoutFp32, tmpFp32, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Mins(subExpIdTensor_, dstExpIdTensor_, 1, tokenIndex); + AscendC::PipeBarrier(); + AscendC::ReduceSum(tmpoutFp32, tmpFp32, workLocalTensor_, tokenIndex); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + int32_t curOtherExpertCnt = dstExpIdTensor_(0); + if (tokenIndex > curOtherExpertCnt) { + curExpertCnt = tokenIndex - curOtherExpertCnt; + } + } + + ACT_DEVICE + void CalAndSendTokenCount() + { + // 计算发送token的数量,并且发送出去 + uint32_t totalExpertNum = sharedExpertRankNum + moeExpertNum; + uint32_t sendCountExpertNum = totalExpertNum / sendCoreNum; // 每个aiv需要处理的专家数 + uint32_t remainderRankNum = totalExpertNum % sendCoreNum; + uint32_t startExpertId = sendCountExpertNum * sendCoreIdx; // sharedExpertRankNum, 每个aiv发送的起始rankid + if (sendCoreIdx < remainderRankNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendCountExpertNum += 1; + startExpertId += sendCoreIdx; + } else { + startExpertId += remainderRankNum; + } + uint32_t endExpertId = startExpertId + sendCountExpertNum; + if (startExpertId >= totalExpertNum) { + return; + } + // 计算count及偏移:开始 + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(CEIL(expertCntUp, INT32_COUNT_PER_BLOCK) * INT32_COUNT_PER_BLOCK * UB_BLOCK_SIZE); + AscendC::Duplicate(statusTensor_, (int32_t)0, + expertCntUp * INT32_COUNT_PER_BLOCK); // 先清零再赋值,清零一定要做 + if (state == 0) { + // 一次性操作256字节,也是64个int32_t,每8个数将首个设置为0x3F800000,即浮点数的1.0 + uint64_t mask[2] = {0x101010101010101, 0}; + AscendC::PipeBarrier(); + // 这里原版代码有bug,block数量不是8的倍数时,后面的尾巴没法更新 + AscendC::Duplicate(statusTensor_, 0x3F800000, mask, CEIL(expertCntUp, 8), 1, 8); + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + if (!isShareExpert) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId) * axisBS / sharedExpertRankNum - + (curSatatusExpId + epRankId) * axisBS / sharedExpertRankNum; + statusTensor_((curSatatusExpId)*INT32_COUNT_PER_BLOCK + 1) = curExpertCnt; + } + } + + for (uint32_t curExpertId = startExpertId; curExpertId < endExpertId; ++curExpertId) { + if (curExpertId < sharedExpertRankNum) { + continue; + } + int32_t curExpertCnt = 0; + int32_t dstExpertId = curExpertId - sharedExpertRankNum; + CalExpandxIdx(dstExpertId, expertIdsCnt, curExpertCnt, ubOffset); + int32_t cntPosIndex = curExpertId * INT32_COUNT_PER_BLOCK + 1; + 
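+            // Each expert owns one 32-byte status block (8 x int32) in statusTensor_: element 0 holds the
+            // handshake flag (written as 0x3F800000, i.e. 1.0f, above when state == 0) and element 1 holds
+            // the token count stored below; the block is later DataCopy'd to the peer rank's status window.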
statusTensor_(cntPosIndex) = curExpertCnt; + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + AscendC::GlobalTensor rankGMTensor; + uint32_t offset = stateOffset * epRankId; + for (uint32_t rankIndex = startExpertId; rankIndex < endExpertId; ++rankIndex) { + uint32_t dstRankId = rankIndex; + if (moeExpertNumPerRank > 1 && (rankIndex >= sharedExpertRankNum)) { + dstRankId = ((rankIndex - sharedExpertRankNum) / moeExpertNumPerRank + sharedExpertRankNum); + offset = + (epRankId + (rankIndex - sharedExpertRankNum) % moeExpertNumPerRank * epRankSize) * stateOffset; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_STATE_ADDR_BY_RANK_ID(dstRankId) + offset); // 计算地址偏移 + rankGMTensor.SetGlobalBuffer((__gm__ int32_t *)rankGM); + AscendC::DataCopy(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); // 8时数据大小,按32对齐拷贝 + } + } + + ACT_DEVICE + void QuantToken(AscendC::LocalTensor &xInTensor, AscendC::LocalTensor &yInt8Tensor, int64_t ubOffset) + { + // 量化token的函数,这里UB空间基本用完就释放了,所以在内部计算UB偏移 + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor xFp32TmpTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(tokenLength * sizeof(float)); + AscendC::LocalTensor xFp32AbsTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(tokenLength * sizeof(float)); + AscendC::LocalTensor xRowMaxTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor ytmpInt32Tensor = xFp32TmpTensor.template ReinterpretCast(); + AscendC::LocalTensor yHalfTensor = xFp32TmpTensor.template ReinterpretCast(); + AscendC::LocalTensor yFp32Tensor = yInt8Tensor.template ReinterpretCast(); + AscendC::LocalTensor yInt32Tensor = yInt8Tensor.template ReinterpretCast(); + + AscendC::Cast(xFp32TmpTensor, xInTensor, AscendC::RoundMode::CAST_NONE, tokenLength); + AscendC::PipeBarrier(); + AscendC::Abs(xFp32AbsTensor, xFp32TmpTensor, tokenLength); + AscendC::PipeBarrier(); + AscendC::ReduceMax(xRowMaxTensor, xFp32AbsTensor, xFp32AbsTensor, tokenLength, false); + AscendC::PipeBarrier(); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + float dynamicQuantScale = float(127.0) / xRowMaxTensor.GetValue(0); + yFp32Tensor.SetValue(tokenLength / sizeof(float), float(1.0) / dynamicQuantScale); + yInt32Tensor.SetValue(tokenLength / sizeof(int32_t) + 1, tokenFlag); + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + AscendC::Muls(xFp32TmpTensor, xFp32TmpTensor, dynamicQuantScale, tokenLength); + AscendC::PipeBarrier(); + AscendC::Cast(ytmpInt32Tensor, xFp32TmpTensor, AscendC::RoundMode::CAST_RINT, tokenLength); + AscendC::PipeBarrier(); + AscendC::Cast(yHalfTensor, ytmpInt32Tensor, AscendC::RoundMode::CAST_ROUND, tokenLength); + AscendC::PipeBarrier(); + AscendC::Cast(yInt8Tensor, yHalfTensor, AscendC::RoundMode::CAST_TRUNC, tokenLength); + } + + ACT_DEVICE + void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale) + { + // 给共享专家发送token + uint32_t newAivId = sendCoreIdx - sendToMoeAivNum; + uint32_t sendTokenNum = axisBS / sendToShareAivNum; // 每个aiv需要发送的token数 + uint32_t remainderTokenNum = axisBS % sendToShareAivNum; // 余数 + uint32_t startTokenId = sendTokenNum * newAivId; // 每个aiv发送时的起始rankid + if (newAivId < remainderTokenNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendTokenNum += 1; + startTokenId += newAivId; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; +#ifdef SEND_TOKEN_RETURN + if (startTokenId >= axisBS) { + return; + } 
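+        // The loop below is a two-buffer (BUFFER_NUM == 2) ping-pong pipeline over this AIV's tokens:
+        // copy a token into UB (MTE2), quantize it (vector), then write the int8 payload and its scale
+        // either to the local output (shared-expert rank) or to the destination rank's window (MTE3);
+        // buffer and event ids alternate with the parity of tokenIndex so the copy-in of one token
+        // overlaps the send of the previous one.
+        // QuantToken performs per-token dynamic quantization:
+        //     scale = 127.0f / max(|x|);  y = int8(round(x * scale))
+        // and stores 1.0f / scale (the dequant scale) plus tokenFlag right after the token payload.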
+#endif + AscendC::LocalTensor xInTensor[BUFFER_NUM]; + AscendC::LocalTensor yInt8Tensor[BUFFER_NUM]; + AscendC::LocalTensor yFp32Tensor[BUFFER_NUM]; + + AscendC::GlobalTensor srcWinGMTensor; // token输入 + srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX); + + xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + yInt8Tensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + yFp32Tensor[0] = yInt8Tensor[0].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + yInt8Tensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + yFp32Tensor[1] = yInt8Tensor[1].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + AscendC::GlobalTensor dstWinGMTensor; // token输出 + AscendC::GlobalTensor expandXOutGlobal; + expandXOutGlobal.SetGlobalBuffer((__gm__ int8_t *)(gmX1)); + AscendC::GlobalTensor dynamicScalesOutGMTensor_; + dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)(gmX1Scale)); +#ifndef SEND_TOKEN_RETURN + if (startTokenId < axisBS) { +#endif + // 输入输出开double buffer + AscendC::SetFlag(0); // MTE2等MTE3 + AscendC::SetFlag(1); // MTE2等MTE3 + AscendC::SetFlag(0); + AscendC::SetFlag(1); + + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + uint32_t index = (tokenIndex & 1) ? 0 : 1; + int32_t eventId = (tokenIndex & 1) ? 0 : 1; + // 下面的计算有点绕,目的是计算目的专家卡和偏移 + uint32_t temp = (epRankId * axisBS) / sharedExpertRankNum; + // 当前token发给哪个共享专家 + uint32_t moeOnShareRank = CEIL((tokenIndex + 1 + temp) * sharedExpertRankNum, axisBS) - 1 - epRankId; + // 发给该共享专家已经有多少token数据 + uint32_t preCnt = (moeOnShareRank + epRankId) * axisBS / sharedExpertRankNum - + epRankId * axisBS / sharedExpertRankNum; + dstWinGMTensor.SetGlobalBuffer( + (__gm__ int8_t *)(GET_WIND_ADDR_BY_RANK_ID(moeOnShareRank) + expertPerSizeOnWin * epRankId)); + + AscendC::WaitFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex * tokenLength], tokenLength); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + QuantToken(xInTensor[index], yInt8Tensor[index], ubOffset); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(0); + + AscendC::WaitFlag(eventId); + if (isShareExpert) { + AscendC::DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + AscendC::DataCopy(expandXOutGlobal[tokenIndex * tokenLength], yInt8Tensor[index], tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(dynamicScalesOutGMTensor_[tokenIndex], + yFp32Tensor[index][tokenLength / sizeof(float)], dataCopyParamsFloat); + } else { + // 怀疑有时序问题,所以分开发送 + AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu], yInt8Tensor[index], + tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu + tokenLength], + yInt8Tensor[index][tokenLength], scaleParamPad); + } + AscendC::SetFlag(eventId); + AscendC::SetFlag(eventId); + } + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); +#ifndef SEND_TOKEN_RETURN + } +#endif + } + + ACT_DEVICE + void SendToMoeExprt(GM_ADDR gmX, GM_ADDR gmExpandIdx) + { + // 给路由专家发送token + uint32_t sendTokenNum = expertIdsCnt / sendToMoeAivNum; + uint32_t remainderTokenNum = expertIdsCnt % sendToMoeAivNum; + uint32_t startTokenId = sendTokenNum * sendCoreIdx; + if (sendCoreIdx < 
remainderTokenNum) { + sendTokenNum += 1; + startTokenId += sendCoreIdx; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; +#ifdef SEND_TOKEN_RETURN + if (startTokenId >= expertIdsCnt) { + return; + } +#else + if (startTokenId < expertIdsCnt) { +#endif + AscendC::LocalTensor expertCountTensor = (resource.ubBuf.template GetBufferByByte(ubOffset)); + ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t)); + AscendC::Duplicate(expertCountTensor, (int32_t)0, expertIdsCnt); // 清零 + AscendC::SetFlag(1); + AscendC::WaitFlag(1); + + AscendC::LocalTensor xInTensor[BUFFER_NUM]; + AscendC::LocalTensor yInt8Tensor[BUFFER_NUM]; + AscendC::LocalTensor yFp32Tensor[BUFFER_NUM]; + + AscendC::GlobalTensor srcWinGMTensor; // token输入 + srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX); + + xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + yInt8Tensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + yInt8Tensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + AscendC::GlobalTensor dstWinGMTensor; // token输出 + // 输入输出开double buffer + AscendC::SetFlag(0); // MTE2等MTE3 + AscendC::SetFlag(1); // MTE2等MTE3 + AscendC::SetFlag(0); + AscendC::SetFlag(1); + uint32_t sendValidTokenIndex = 0; + for (uint32_t sendGroupIndex = 0; sendGroupIndex < moeExpertNumPerRank; ++sendGroupIndex) { + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + int32_t dstExpertId = expertIdsTensor_(tokenIndex); + if (dstExpertId < 0) { + continue; + } + if ((dstExpertId % moeExpertNumPerRank) != sendGroupIndex) { // 优先发送指定专家的token + continue; + } + uint32_t index = (sendValidTokenIndex & 1) ? 0 : 1; + int32_t eventId = (sendValidTokenIndex & 1) ? 
0 : 1; + sendValidTokenIndex += 1; + int32_t curExpertCnt = 0; + CalExpandxIdx(dstExpertId, tokenIndex, curExpertCnt, ubOffset); + expertCountTensor(tokenIndex - startTokenId) = curExpertCnt; + uint32_t tempRankId = dstExpertId / moeExpertNumPerRank + sharedExpertRankNum; + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(tempRankId) + + (expertPerSizeOnWin * (epRankId * moeExpertNumPerRank + + dstExpertId % moeExpertNumPerRank)) + + hCommuSize * curExpertCnt); + dstWinGMTensor.SetGlobalBuffer((__gm__ int8_t *)rankGM); + + AscendC::WaitFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex / axisK * tokenLength], tokenLength); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + QuantToken(xInTensor[index], yInt8Tensor[index], ubOffset); + AscendC::SetFlag(eventId); + + AscendC::WaitFlag(0); + AscendC::WaitFlag(eventId); + + // 担心有时序问题,所以分开发送 + AscendC::DataCopy(dstWinGMTensor, yInt8Tensor[index], tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopy(dstWinGMTensor[tokenLength], yInt8Tensor[index][tokenLength], scaleParamPad); + AscendC::SetFlag(eventId); + AscendC::SetFlag(eventId); + } + } + AscendC::WaitFlag(0); // MTE2等MTE3 + AscendC::WaitFlag(1); // MTE2等MTE3 + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + + AscendC::GlobalTensor expandIdxGMTensor; + expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t *)gmExpandIdx + startTokenId); + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(sendTokenNum * sizeof(uint32_t)), 0U, + 0U, 0U}; + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyPad(expandIdxGMTensor, expertCountTensor, expertIdsCntParams); +#ifndef SEND_TOKEN_RETURN + } +#endif +} + +ACT_DEVICE void +SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx) +{ + ubOffset = 0; + expertIdsCnt = axisBS * axisK; + + AscendC::GlobalTensor expertIdsGMTensor_; + expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds); + expertIdsTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t)); + + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, + 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + CalAndSendTokenCount(); + AscendC::PipeBarrier(); + if (hasShareExpert) { + sendToShareAivNum = sendCoreNum / (axisK + 1); // 均等分,取整 + if (sendToShareAivNum == 0) { + sendToShareAivNum = 1; + } + } + sendToMoeAivNum = sendCoreNum - sendToShareAivNum; + + AscendC::SetDeqScale((half)1.000000e+00f); + if (hasShareExpert && sendCoreIdx >= sendToMoeAivNum) { + SendToShareExprt(gmX, gmX1, gmX1Scale); + } else { + SendToMoeExprt(gmX, gmExpandIdx); + } + AscendC::PipeBarrier(); +} + +ACT_DEVICE +void RecvCount(int64_t ubOffset) +{ + // 接收count数据 + uint32_t recStatusNumPerCore = isShareExpert ? 
epRankSize : expertCntUp;
+    uint32_t startStatusIndex = 0;  // currently every core collects all the counts
+
+    int64_t subUbOffset = ubOffset;
+    AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE);
+    AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset));
+    subUbOffset += CEIL_UP(UB_BLOCK_SIZE);
+    AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(expertCntUp * sizeof(float));
+    AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast();
+
+    AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(UB_BLOCK_SIZE);
+    AscendC::LocalTensor sumTmpTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE);
+    gatherTmpTensor.SetValue(0, 1);
+
+    uint32_t mask = 1;  // parameters for GatherMask + Sum
+    uint64_t rsvdCnt = 0;
+    AscendC::SumParams sumParams{1, recStatusNumPerCore, recStatusNumPerCore};
+    float sumOfFlag = static_cast(-1.0);
+    float minTarget = (sumTarget * recStatusNumPerCore) - (float)0.5;
+    float maxTarget = (sumTarget * recStatusNumPerCore) + (float)0.5;
+    AscendC::DataCopyParams intriParams{static_cast(recStatusNumPerCore), 1, static_cast(15),
+                                        0};  // srcStride is 15 blocks
+    AscendC::GlobalTensor windowInstatusFp32Tensor_;
+    windowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId));
+    AscendC::SetFlag(0);
+    AscendC::WaitFlag(0);
+
+    uint32_t preRecvTokenCount = 0;
+    while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
+        AscendC::DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset / sizeof(float)],
+                          intriParams);
+        AscendC::SetFlag(0);
+        AscendC::WaitFlag(0);
+        AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask,
+                            {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt);
+        AscendC::PipeBarrier();
+        AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumTmpTensor, sumParams);
+        AscendC::SetFlag(0);
+        AscendC::WaitFlag(0);
+        sumOfFlag = statusSumOutTensor.GetValue(0);
+    }
+}
+
+ACT_DEVICE
+void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset, GM_ADDR gmOutputRecvCount)
+{
+    // Compute the prefix sum so that this core knows the offset of its received tokens in the output.
+    int64_t subUbOffset = ubOffset;
+    uint32_t recStatusNumPerCore = isShareExpert ?
epRankSize : expertCntUp;
+    AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE);
+    AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset));
+    subUbOffset += CEIL_UP(UB_BLOCK_SIZE);
+    AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(expertCntUp * sizeof(float));
+    AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast();
+    if (isShareExpert) {
+        for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum; ++curSatatusExpId) {
+            int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId) * axisBS / sharedExpertRankNum -
+                                   (curSatatusExpId + epRankId) * axisBS / sharedExpertRankNum;
+            statusTensor_((curSatatusExpId)*INT32_COUNT_PER_BLOCK + 1) = curExpertCnt;
+        }
+    }
+
+    uint64_t rsvdCnt = 0;
+    gatherTmpTensor.SetValue(0, GATHER_SECOND_NUM);
+    AscendC::SetFlag(0);
+    AscendC::WaitFlag(0);
+    AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM,
+                        {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt);
+    if (isRecvCore && recvCoreIdx == 0) {
+        AscendC::GlobalTensor recvCountTensor;
+        recvCountTensor.SetGlobalBuffer((__gm__ int32_t *)gmOutputRecvCount);
+        AscendC::DataCopyExtParams dataCopyParams = {
+            1U, static_cast(localExpertNum * epRankSize * sizeof(int32_t)), 0U, 0U, 0U};
+        AscendC::SetFlag(0);
+        AscendC::WaitFlag(0);
+        AscendC::DataCopyPad(recvCountTensor, gatherMaskOutTensor.ReinterpretCast(), dataCopyParams);
+    }
+    // Scratch space for ReduceSum. Ideally the required size would be computed, but since this is only an offset
+    // into the UB buffer and it is released right after use, the calculation is skipped.
+    AscendC::LocalTensor workLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    AscendC::PipeBarrier();
+    AscendC::ReduceSum(gatherMaskOutTensor, gatherMaskOutTensor, workLocalTensor,
+                       (startRankId + 1) <= recvExpertNum ?
(startRankId + 1) : recvExpertNum);
+    AscendC::SetFlag(0);
+    AscendC::WaitFlag(0);
+}
+
+ACT_DEVICE
+void RecvToken(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, uint32_t &coreTokenCount, uint32_t startRankId,
+               uint32_t endRankId, uint32_t recvRankNumPerCore, int64_t ubOffset)
+{
+    // receive tokens
+    int64_t subUbOffset = ubOffset;
+    AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE);
+    AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset));
+    subUbOffset += CEIL_UP(UB_BLOCK_SIZE);
+    AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(expertCntUp * sizeof(float));
+    AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast();
+
+    AscendC::DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U};
+    AscendC::LocalTensor xTmpTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(axisHCommu * sizeof(int8_t));
+    AscendC::LocalTensor xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast();
+    AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset);
+    subUbOffset += CEIL_UP(UB_BLOCK_SIZE);
+    AscendC::LocalTensor gatherMaskOutCountTensor = (gatherMaskOutTensor.template ReinterpretCast());
+    AscendC::GlobalTensor tokGlobal;
+    AscendC::GlobalTensor tokGlobalInt32;
+    AscendC::GlobalTensor expandXOutGlobal;
+    AscendC::GlobalTensor dynamicScalesOutGMTensor_;
+    dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)(gmX1Scale));
+    uint32_t beginIdx = 0;
+    for (uint32_t index = startRankId; index < endRankId; index++) {
+        uint32_t i = index - startRankId;
+        if (i > 0) {
+            gatherMaskOutCountTensor.SetValue(
+                i, gatherMaskOutCountTensor.GetValue(i - 1) + gatherMaskOutCountTensor.GetValue(index));
+        }
+        uint32_t count = statusTensor_.GetValue(index * INT32_COUNT_PER_BLOCK + 1);
+        coreTokenCount += count;
+        beginIdx = gatherMaskOutCountTensor.GetValue(i) - count;
+        if (isShareExpert && index < sharedExpertRankNum) {
+            beginIdx += count;
+            continue;
+        }
+        uint32_t winOffset = index;
+        if (!isShareExpert && moeExpertNumPerRank > 1) {
+            // The layout of the counts differs from that of the token data, so convert to the data-area offset.
+            // srcRank: index % epRankSize
+            // localExpertId: index / epRankSize
+            // Addr: (srcRank * moeExpertNumPerRank + localExpertId) * expertPerSizeOnWin
+            winOffset = (index % epRankSize) * moeExpertNumPerRank + index / epRankSize;
+        }
+        GM_ADDR wAddr = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(epRankId)) + winOffset * expertPerSizeOnWin;
+        AscendC::SetFlag(0);
+        for (uint32_t j = 0; j < count; j++) {
+            tokGlobal.SetGlobalBuffer((__gm__ int8_t *)(wAddr + j * hCommuSize));
+            tokGlobalInt32.SetGlobalBuffer((__gm__ int32_t *)(wAddr + j * hCommuSize + hOutSize));
+            expandXOutGlobal.SetGlobalBuffer((__gm__ int8_t *)(gmX1) + (beginIdx + j) * tokenLength, tokenLength);
+
+            while (true) {
+                AscendC::DataCopy(tmpLocalTensor, tokGlobalInt32, INT32_COUNT_PER_BLOCK);
+                AscendC::SetFlag(0);
+                AscendC::WaitFlag(0);
+                if (tmpLocalTensor.GetValue(1) == tokenFlag) {
+                    tokGlobalInt32.SetValue(1, 0);
+                    __asm__ __volatile__("");
+                    AscendC::DataCacheCleanAndInvalid(tokGlobalInt32[1]);
+                    __asm__ __volatile__("");
+                    break;
+                }
+            }
+            AscendC::PipeBarrier();
+
+            AscendC::WaitFlag(0);
+            AscendC::DataCopy(xTmpTensor_, tokGlobal, axisHCommu);
+            AscendC::SetFlag(0);
+            AscendC::WaitFlag(0);
+            AscendC::DataCopyPad(dynamicScalesOutGMTensor_[beginIdx + j],
xOutFp32Tensor_[tokenLength / sizeof(float)],
+                                 dataCopyParamsFloat);
+            AscendC::DataCopy(expandXOutGlobal, xTmpTensor_, tokenLength);
+            AscendC::SetFlag(0);
+        }
+        AscendC::WaitFlag(0);
+        beginIdx += count;
+    }
+    AscendC::PipeBarrier();
+
+    AscendC::SetFlag(0);
+    AscendC::WaitFlag(0);
+    AscendC::DataCopyExtParams dataCopyOutParams = {1U, static_cast(recvRankNumPerCore * sizeof(int32_t)), 0U,
+                                                    0U, 0U};
+    AscendC::GlobalTensor sendCountsGlobal;
+    sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount));
+    AscendC::DataCopyPad(sendCountsGlobal[startRankId], gatherMaskOutCountTensor, dataCopyOutParams);
+}
+
+ACT_DEVICE
+void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, GM_ADDR gmOutputRecvCount)
+{
+    ubOffset = 0;
+    RecvCount(ubOffset);
+
+    // Split cores by local expert first, then split further within each expert.
+    uint32_t recvExpertNum = isShareExpert ? epRankSize : expertCntUp;
+    uint32_t recvCoreNumPerGroup = recvCoreNum / localExpertNum;     // cores per group; assumed divisible and non-zero
+    uint32_t recvRankNumPerCore = epRankSize / recvCoreNumPerGroup;  // number of ranks handled by each core
+    uint32_t remainderRankNum = epRankSize % recvCoreNumPerGroup;
+
+    uint32_t groupId = recvCoreIdx / recvCoreNumPerGroup;             // which group this core handles
+    uint32_t recvCoreIdxInGroup = recvCoreIdx % recvCoreNumPerGroup;  // index of this core within the group
+    uint32_t startRankIdInGroup = recvRankNumPerCore * recvCoreIdxInGroup;  // first rank handled by this core
+    if (recvCoreIdxInGroup < remainderRankNum) {
+        recvRankNumPerCore += 1;
+        startRankIdInGroup += recvCoreIdxInGroup;
+    } else {
+        startRankIdInGroup += remainderRankNum;
+    }
+    uint32_t endRankIdInGroup = startRankIdInGroup + recvRankNumPerCore;
+    uint32_t startRankId = epRankSize * groupId + startRankIdInGroup;
+    uint32_t endRankId = epRankSize * groupId + endRankIdInGroup;
+
+    uint32_t coreTokenCount = 0;
+
+    if (startRankId < recvExpertNum) {
+        // Compute the prefix sum and receive the tokens. Implicit constraint: the two calls below must use the
+        // same ubOffset argument as RecvCount, otherwise they will not see valid data.
+        GetCumSum(startRankId, recvExpertNum, ubOffset, gmOutputRecvCount);
+        RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore, ubOffset);
+    }
+
+    // Reception finished; notify the C cores and the compute V cores by writing to GM.
+    AscendC::PipeBarrier();
+    AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(0);
+    ubOffset += CEIL_UP(UB_BLOCK_SIZE);
+    tmpLocalTensor.SetValue(CV_FLAG_INDEX, vToCFlag);
+    tmpLocalTensor.SetValue(GROUP_ID_INDEX, groupId);
+    tmpLocalTensor.SetValue(SELF_COUNT_INDEX, coreTokenCount);
+    AscendC::SetFlag(0);
+
+    AscendC::GlobalTensor groupTokenNumStateTensor;
+    groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET));
+    AscendC::WaitFlag(0);
+    AscendC::SetAtomicAdd();
+    // With the atomic add, the token counts received by the individual cores sum up to the expert's token count.
+    AscendC::DataCopy(groupTokenNumStateTensor[groupId * GROUP_INFO_SIZE], tmpLocalTensor, INT32_COUNT_PER_BLOCK);
+    AscendC::SetAtomicNone();
+    AscendC::PipeBarrier();
+}
+
+ACT_DEVICE
+void CompCoreFunc(GM_ADDR gmCVSwapBuff, __gm__ ElementScale *gmScale, __gm__ ElementPerTokenScale *gmTokenScale,
+                  __gm__ float *gmSwigluOutput, uint32_t n, uint32_t k, LayoutScale layoutScale,
+                  LayoutPerTokenScale wholeLayoutPerTokenScale, LayoutOutput layoutOutput)
+{
+    uint32_t nOut = n / 2;
+    uint32_t coreNumPerGroup = recvCoreNum / localExpertNum;  // assumed to divide evenly
+    int64_t gmGroupOffsetScale = 0;
+    int64_t gmGroupOffsetPerTokenScale = 0;
+    int64_t gmGroupOffsetD = 0;
+
+    AscendC::GlobalTensor gmC;
+    gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(gmCVSwapBuff));
+    auto layoutC = layout::RowMajor{L1TileShape::M * aiCoreGroupNum *
WORKSPACE_STAGES, L1TileShape::N};
+    {
+        BlockScheduler blockScheduler;
+        BlockEpilogue blockEpilogue(resource);
+
+        uint32_t stageId = 0;
+        uint32_t target = 1;
+        uint32_t startCoreIdx = 0;
+
+        AscendC::GlobalTensor groupTokenNumStateTensor;
+        for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) {
+            // Same flow as on the C cores: wait for the expert's token data, then compute with soft synchronization.
+            groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) +
+                                                     groupIdx * GROUP_INFO_SIZE);
+            while (true) {
+                __asm__ __volatile__("");
+                AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor);
+                __asm__ __volatile__("");
+                if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) {
+                    break;
+                }
+            }
+            uint32_t currentM = groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT);
+            GemmCoord inGroupProblemShape{currentM, n, k};
+            LayoutPerTokenScale layoutPerTokenScale =
+                wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>());
+            LayoutD layoutD = layoutOutput.GetTileLayout(MakeCoord(currentM, nOut));
+            EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale,
+                                          layoutScale,
+                                          gmTokenScale + gmGroupOffsetPerTokenScale,
+                                          layoutPerTokenScale,
+                                          gmSwigluOutput + gmGroupOffsetD,
+                                          layoutD};
+            blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN());
+            blockEpilogue.UpdateParams(epilogueParams);
+            uint32_t coreLoops = blockScheduler.GetCoreLoops();
+
+            GemmCoord blockShapeMNK = L1TileShape::ToCoord();
+            uint32_t startLoopIdx =
+                ((compCoreIdx < startCoreIdx) ? (compCoreIdx + aiCoreGroupNum) : compCoreIdx) - startCoreIdx;
+            for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aiCoreGroupNum) {
+                GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx);
+                GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK);
+
+                MatrixCoord offsetC{(stageId * aiCoreGroupNum + aiCoreGroupIdx) * L1TileShape::M, 0};
+                int64_t gmOffsetC = layoutC.GetOffset(offsetC);
+                auto gmBlockC = gmC[gmOffsetC];
+                auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN());
+                CheckSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET,
+                              static_cast(COMP_AIV_CORE_NUM + compCoreIdx), target);  // AIV waits on flags 24~48
+                target += 1;
+                blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC);
+                EncreaseSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(compCoreIdx));
+                stageId = (stageId + 1 < WORKSPACE_STAGES) ?
(stageId + 1) : 0;
+            }
+
+            gmGroupOffsetScale += inGroupProblemShape.n();
+            gmGroupOffsetPerTokenScale += inGroupProblemShape.m();
+            gmGroupOffsetD += currentM * nOut;
+
+            startCoreIdx = (startCoreIdx + coreLoops) % aiCoreGroupNum;
+        }
+    }
+    // Clear the leftover soft-sync state so it does not affect other code or the next run.
+    AscendC::PipeBarrier();
+    AscendC::GlobalTensor softSyncTensor;
+    softSyncTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + SOFT_SYNC_OFFSET));
+    AscendC::LocalTensor tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte(0);
+    AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, INT32_COUNT_PER_BLOCK);
+    AscendC::SetFlag(0);
+    AscendC::WaitFlag(0);
+    AscendC::DataCopy(softSyncTensor[compCoreIdx * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)], tmpZeroLocalTensor,
+                      INT32_COUNT_PER_BLOCK);
+    AscendC::DataCopy(softSyncTensor[(compCoreIdx + COMP_AIV_CORE_NUM) * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)],
+                      tmpZeroLocalTensor, INT32_COUNT_PER_BLOCK);
+}
+
+ACT_DEVICE
+void AivInitParams(Params const &params)
+{
+    aiCoreGroupNum = AscendC::GetBlockNum();
+    subBlockNum = AscendC::GetSubBlockNum();
+    aivIdx = AscendC::GetBlockIdx();
+    aiCoreGroupIdx = aivIdx / subBlockNum;
+    aivStateGlobalCoreIdx = AIV_STATE_SPACE_IDNEX + aivIdx;
+
+    isCompCore = (aivIdx % SUB_AIV_NUM) == 0;  // even cores do the computation
+    compCoreNum = COMP_AIV_CORE_NUM;
+    compCoreIdx = aiCoreGroupIdx;
+    // single MoE expert per rank: 48 send cores and 48 receive cores
+    isRecvCore = true;
+    isSendCore = true;
+    recvCoreIdx = aivIdx;
+    sendCoreIdx = aivIdx;
+    sendCoreNum = SEND_AIV_CORE_NUM;
+    recvCoreNum = RECV_AIV_CORE_NUM;
+
+    moeExpertNumPerRank = params.moeExpertNumPerRank;
+
+    epRankSize = params.epRankSize;
+    epRankId = params.epRankId;
+    expertCntUp = epRankSize * moeExpertNumPerRank;
+    sharedExpertRankNum = params.sharedExpertRankNum;
+    hasShareExpert = (sharedExpertRankNum > 0);
+    isShareExpert = (epRankId < sharedExpertRankNum);
+    localExpertNum = isShareExpert ?
1 : moeExpertNumPerRank;
+    moeExpertNum = params.moeExpertNum;
+    tokenLength = params.tokenLen;
+
+    // multiple experts per rank: switch to 24 receive cores and 24 send cores
+    if (localExpertNum > 1) {
+        isRecvCore = ((aivIdx % ODD_EVEN_BASE) == 0);  // even cores receive
+        isSendCore = ((aivIdx % ODD_EVEN_BASE) == 1);  // odd cores send
+        recvCoreIdx = aivIdx / SUB_AIV_NUM;
+        sendCoreIdx = aivIdx / SUB_AIV_NUM;
+        sendCoreNum = SEND_AIV_CORE_NUM / SUB_AIV_NUM;
+        recvCoreNum = RECV_AIV_CORE_NUM / SUB_AIV_NUM;
+    }
+
+    hOutSize = tokenLength * sizeof(int8_t);
+    scaleParamPad = TOKEN_EXTRA_SPACE;  // 512 B reserved for the quantization parameters; only 4 B (fp32) are used
+    hCommuSize = hOutSize + scaleParamPad;
+    axisHCommu = hCommuSize / sizeof(int8_t);
+    axisBS = params.bs;
+    axisK = params.topK;
+    uint32_t maxAxisBs = params.globalBs / epRankSize;
+
+    stateOffset = STATE_OFFSET;
+    expertPerSizeOnWin = maxAxisBs * tokenLength * sizeof(XType);
+    winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext();
+    statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp);
+}
+
+ACT_DEVICE
+void AivInitState()
+{
+    // Update the per-core state: decides which half of the buffer is used and toggles the various flags.
+    AscendC::GlobalTensor selfDataStatusTensor;
+    selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET));
+    __asm__ __volatile__("");
+    AscendC::DataCacheCleanAndInvalid(
+        selfDataStatusTensor[aivIdx * UB_ALIGN]);
+    __asm__ __volatile__("");
+    dataState = selfDataStatusTensor(aivIdx * UB_ALIGN);
+    if (dataState == 0) {
+        selfDataStatusTensor(aivIdx * UB_ALIGN) = 1;
+    } else {
+        selfDataStatusTensor(aivIdx * UB_ALIGN) = 0;
+    }
+    __asm__ __volatile__("");
+    AscendC::DataCacheCleanAndInvalid(
+        selfDataStatusTensor[aivIdx * UB_ALIGN]);
+    __asm__ __volatile__("");
+
+    // flag for the expert token data
+    __asm__ __volatile__("");
+    AscendC::DataCacheCleanAndInvalid(
+        selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]);
+    __asm__ __volatile__("");
+    cvDataState = selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN);
+    if (cvDataState == 0) {
+        selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 1;
+        vToCFlag = V_TO_C_FLAG_1;
+    } else {
+        selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 0;
+        vToCFlag = V_TO_C_FLAG_2;
+    }
+    __asm__ __volatile__("");
+    AscendC::DataCacheCleanAndInvalid(
+        selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]);
+    __asm__ __volatile__("");
+
+    AscendC::PipeBarrier();
+    winDataSizeOffset = dataState * epRankSize * expertPerSizeOnWin * moeExpertNumPerRank;
+    GM_ADDR statusSpaceGm_ = GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId);
+    AscendC::GlobalTensor selfStatusTensor;
+    selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET));
+    __asm__ __volatile__("");
+    AscendC::DataCacheCleanAndInvalid(
+        selfStatusTensor[aivIdx * UB_ALIGN]);
+    __asm__ __volatile__("");
+    state = selfStatusTensor(aivIdx * UB_ALIGN);
+    if (state == 0) {
+        sumTarget = (float)1.0;
+        tokenFlag = TOKEN_FLAG_1;
+        selfStatusTensor(aivIdx * UB_ALIGN) = 0x3F800000;  // bit pattern of float 1.0
+    } else {
+        sumTarget = 0.0;
+        tokenFlag = TOKEN_FLAG_2;
+        selfStatusTensor(aivIdx * UB_ALIGN) = 0;
+    }
+    __asm__ __volatile__("");
+    AscendC::DataCacheCleanAndInvalid(
+        selfStatusTensor[aivIdx * UB_ALIGN]);
+    __asm__ __volatile__("");
+}
+
+ACT_DEVICE
+void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount)
+{
+    if (aivIdx == aiCoreGroupNum * subBlockNum - 1) {
+        // clear the per-expert token count info
+        AscendC::GlobalTensor groupTokenNumStateTensor;
+        groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET));
+        AscendC::LocalTensor tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte(0);
AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, GROUP_INFO_SIZE * localExpertNum);
+        AscendC::SetFlag(0);
+        AscendC::WaitFlag(0);
+        AscendC::DataCopy(groupTokenNumStateTensor, tmpZeroLocalTensor, GROUP_INFO_SIZE * localExpertNum);
+    }
+
+    if (isRecvCore && recvCoreIdx == (recvCoreNum - 1)) {
+        // update the group_list info
+        AscendC::GlobalTensor expertTokenNumsOutGMTensor_;
+        expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)(ptrGroupList));
+        AscendC::GlobalTensor sendCountsGlobal;
+        sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount));
+        for (uint32_t localMoeIndex = 0; localMoeIndex < localExpertNum; ++localMoeIndex) {
+            __asm__ __volatile__("");
+            AscendC::DataCacheCleanAndInvalid(
+                sendCountsGlobal[localMoeIndex * epRankSize + epRankSize - 1]);
+            __asm__ __volatile__("");
+            uint32_t tokenNum = sendCountsGlobal.GetValue(localMoeIndex * epRankSize + epRankSize - 1);
+            expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenNum);
+            __asm__ __volatile__("");
+            AscendC::DataCacheCleanAndInvalid(
+                expertTokenNumsOutGMTensor_[localMoeIndex]);
+            __asm__ __volatile__("");
+        }
+    }
+}
+
+template <>
+ACT_DEVICE void operator()(Params const &params)
+{
+    AivInitParams(params);
+    AivInitState();
+    if (isSendCore) {
+        SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA,
+                     (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx);
+    }
+    if (isRecvCore) {
+        RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount,
+                     (GM_ADDR)params.gmOutputRecvCount);
+    }
+
+    auto gmSwigluOutput = reinterpret_cast<__gm__ float *>(
+        params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * aiCoreGroupNum * WORKSPACE_STAGES * L1TileShape::N));
+    if (isCompCore) {
+        CompCoreFunc(params.ptrWorkspace, params.ptrScale, params.ptrPerTokenScale, gmSwigluOutput,
+                     params.problemShape.n(), params.problemShape.k(), params.layoutScale, params.layoutPerTokenScale,
+                     params.layoutOutput);
+    }
+
+    icache_preload(8);
+    AscendC::SyncAll();
+    AscendC::PipeBarrier();
+
+    UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount);
+    {
+        // quantization
+        AscendC::GlobalTensor sendCountsGlobal;
+        sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.gmEpSendCount));
+        __asm__ __volatile__("");
+        AscendC::DataCacheCleanAndInvalid(sendCountsGlobal);
+        __asm__ __volatile__("");
+        totalTokenCount = sendCountsGlobal.GetValue(localExpertNum * epRankSize - 1);
+        AscendC::PipeBarrier();
+        uint32_t nOut = params.problemShape.n() / 2;
+        uint32_t quantRowOnce = 0;
+        CalQuantRow(nOut, quantRowOnce);
+        typename BlockQuant::Params quantParams{
+            gmSwigluOutput, params.layoutOutput, params.ptrDequantScale, params.layoutDequantScale,
+            params.ptrOutput, params.layoutOutput, quantRowOnce, nOut};
+
+        BlockQuant blockQuant(resource, quantParams);
+        MatrixCoord quantShape(totalTokenCount, nOut);
+        MatrixCoord quantBlockShape((uint16_t)(subBlockNum * quantRowOnce), nOut);
+        Epilogue::Tile::EpilogueHorizontalTileSwizzle quantSwizzle(quantShape, quantBlockShape);
+        for (uint32_t loopIdx = aiCoreGroupIdx; loopIdx < quantSwizzle.GetLoops(); loopIdx += aiCoreGroupNum) {
+            auto blockCoord = quantSwizzle.GetTileCoord(loopIdx);
+            auto actualBlockShape = quantSwizzle.GetActualTileShape(blockCoord);
+            blockQuant(quantBlockShape, blockCoord, actualBlockShape);
+        }
+    }
+}
+
+private:
+friend struct AicWaitFunc1;
+friend struct AicSetFunc1;
+
+struct AicWaitFunc1 {
+    ACT_DEVICE
+    AicWaitFunc1() = default;
+
+    ACT_DEVICE
+    void operator()()
const
+    {
+        CheckSyncFlag(flagAddr, idx, target);
+    }
+
+    __gm__ uint8_t *flagAddr;
+    uint8_t idx;
+    uint32_t target;
+};
+
+struct AicSetFunc1 {
+    ACT_DEVICE
+    AicSetFunc1() = default;
+
+    ACT_DEVICE
+    void operator()() const
+    {
+        EncreaseSyncFlag(flagAddr, idx);
+    }
+
+    __gm__ uint8_t *flagAddr;
+    uint8_t idx;
+};
+
+AicWaitFunc1 aicWaitFunc1;
+AicSetFunc1 aicSetFunc1;
+Arch::Resource resource;
+
+AscendC::LocalTensor expertIdsTensor_;
+
+// rank / expert related
+uint32_t epRankSize{0};
+uint32_t epRankId{0};
+bool hasShareExpert{false};
+bool isShareExpert{false};
+uint32_t expertCntUp{0};
+uint32_t localExpertNum{0};
+uint32_t sharedExpertRankNum{0};
+uint32_t moeExpertNumPerRank{0};
+uint32_t moeExpertNum{0};
+
+// token related
+uint32_t hOutSize{0};
+uint32_t scaleParamPad{0};
+uint32_t hCommuSize{0};
+uint32_t axisHCommu{0};
+uint32_t axisBS{0};
+uint32_t axisK{0};
+uint32_t totalTokenCount{0};
+uint32_t expertIdsCnt{0};
+uint32_t tokenLength{0};
+
+// state related
+int32_t tokenFlag{0};    // flag signalling token arrival
+int32_t vToCFlag{0};     // flag used by V to notify C
+int32_t dataState{0};    // state of this core, coordinated with combine
+int32_t cvDataState{0};  // state of this core, for CV coordination
+int32_t state{0};        // selects which count flag to use
+float sumTarget{0.0};    // expected count value
+
+// shared-memory (window) related
+__gm__ HcclOpResParam *winContext_;
+GM_ADDR statusDataSpaceGm;
+uint32_t stateOffset{0};
+uint64_t expertPerSizeOnWin{0};
+uint64_t winDataSizeOffset{0};
+
+// on-core resources
+int64_t ubOffset;
+
+// core partitioning
+bool isSendCore{false};
+bool isRecvCore{false};
+bool isCompCore{false};  // takes part in the deq_swiglu computation
+uint32_t aiCoreGroupNum{0};
+uint32_t aiCoreGroupIdx{0};
+uint32_t subBlockNum{0};
+uint32_t aicNum{0};
+uint32_t sendCoreNum{0};
+uint32_t recvCoreNum{0};
+uint32_t compCoreNum{0};
+uint32_t aivIdx{0};
+uint32_t aicIdx{0};
+uint32_t sendCoreIdx{0};
+uint32_t recvCoreIdx{0};
+uint32_t compCoreIdx{0};
+uint32_t aivStateGlobalCoreIdx{0};
+uint32_t aicStateGlobalCoreIdx{0};
+uint32_t sendToMoeAivNum{0};
+uint32_t sendToShareAivNum{0};
+};
+
+} // namespace Act::Gemm::Kernel
+
+namespace Act::Gemm::Kernel {
+
+template
+class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch
+{
+public:
+    using BlockMmad = BlockMmad_;
+    using ArchTag = typename BlockMmad::ArchTag;
+    using L1TileShape = typename BlockMmad::L1TileShape;
+    using ElementA = typename BlockMmad::ElementA;
+    using LayoutA = typename BlockMmad::LayoutA;
+    using ElementB = typename BlockMmad::ElementB;
+    using LayoutB = typename BlockMmad::LayoutB;
+    using ElementC = typename BlockMmad::ElementC;
+    using LayoutC = typename BlockMmad::LayoutC;
+    using ElementAccumulator = typename BlockMmad::ElementAccumulator;
+
+    using BlockEpilogue = BlockEpilogue_;
+    using ElementScale = typename BlockEpilogue::ElementScale;
+    using LayoutScale = typename BlockEpilogue::LayoutScale;
+    using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
+    using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
+    using ElementD = typename BlockEpilogue::ElementD;
+    using LayoutD = typename BlockEpilogue::LayoutD;
+    using EpilogueParams = typename BlockEpilogue::Params;
+
+    using ElementDequantScale = typename BlockQuant::ElementDequantScale;
+    using LayoutDequantScale = typename BlockQuant::LayoutDequantScale;
+    using ElementOutput = typename BlockQuant::ElementOutput;
+    using LayoutOutput = typename BlockQuant::LayoutOutput;
+
+    using BlockScheduler = BlockScheduler_;
+    static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_;
+    using ElementGroupList = ElementGroupList_;
+
+    /// Parameters structure
+    struct Params {
+ // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementOutput *ptrOutput; + LayoutOutput layoutOutput; + __gm__ ElementDequantScale *ptrDequantScale; + LayoutDequantScale layoutDequantScale; + GM_ADDR ptrWorkspace; + + // Methods + ACT_DEVICE + Params() {} + + ACT_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_, + LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, + GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)), + layoutOutput(layoutOutput_), + ptrDequantScale(reinterpret_cast<__gm__ ElementDequantScale *>(ptrDequantScale_)), + layoutDequantScale(layoutDequantScale_), + ptrWorkspace(ptrWorkspace_) + {} + }; + + // Methods + ACT_DEVICE + GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + ACT_DEVICE void operator()(Params const ¶ms); + + template <> + ACT_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? 
groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? 
(stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + } + + template <> + ACT_DEVICE void operator()(Params const ¶ms) + { + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + auto ptrD = reinterpret_cast<__gm__ float *>( + params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * coreNum * WORKSPACE_STAGES * L1TileShape::N)); + + uint32_t mActual = groupList.GetValue(params.problemCount - 1); + uint32_t nOut = params.problemShape.n() / 2; + + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutOutput.GetTileLayout(MakeCoord(currentM, nOut)); + + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + ptrD + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? 
(stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += currentM * nOut; + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + + { + uint32_t quantRowOnce = 0; + CalQuantRow(nOut, quantRowOnce); + typename BlockQuant::Params quantParams{ptrD, + params.layoutOutput, + params.ptrDequantScale, + params.layoutDequantScale, + params.ptrOutput, + params.layoutOutput, + quantRowOnce, + nOut}; + + BlockQuant blockQuant(resource, quantParams); + MatrixCoord quantShape(mActual, nOut); + MatrixCoord quantBlockShape((uint16_t)(AscendC::GetSubBlockNum() * quantRowOnce), nOut); + Epilogue::Tile::EpilogueHorizontalTileSwizzle quantSwizzle(quantShape, quantBlockShape); + for (uint32_t loopIdx = coreIdx; loopIdx < quantSwizzle.GetLoops(); loopIdx += coreNum) { + auto blockCoord = quantSwizzle.GetTileCoord(loopIdx); + auto actualBlockShape = quantSwizzle.GetActualTileShape(blockCoord); + + blockQuant(quantBlockShape, blockCoord, actualBlockShape); + } + } + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< + BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + ACT_DEVICE + AicWaitFunc() = default; + + ACT_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< + BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + ACT_DEVICE + AicSetFunc() = default; + + ACT_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Act::Gemm::Kernel From 953251b27297654ecb2b88b249a4eb0215b4efd8 Mon Sep 17 00:00:00 2001 From: GuoRen868 <1269192170@qq.com> Date: Mon, 10 Nov 2025 11:38:45 +0800 Subject: [PATCH 2/7] binding --- csrc/pytorch_npu_helper.hpp | 557 ++++++++++++++++++++++++++++++++++++ csrc/torch_binding.cpp | 52 ++++ csrc/torch_binding_meta.cpp | 27 ++ 3 files changed, 636 insertions(+) create mode 100644 csrc/pytorch_npu_helper.hpp diff --git a/csrc/pytorch_npu_helper.hpp b/csrc/pytorch_npu_helper.hpp new file mode 100644 index 00000000000..5ce9725d234 --- /dev/null +++ b/csrc/pytorch_npu_helper.hpp @@ -0,0 +1,557 @@ +#ifndef PYTORCH_NPU_HELPER_HPP_ +#define PYTORCH_NPU_HELPER_HPP_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/interface/EnvVariables.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" + +#define NPU_NAME_SPACE at_npu::native + +#define __FILENAME__ (strrchr("/" __FILE__, '/') + 1) + 
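+// Note: the aclnn entry points declared below are resolved at runtime through dlopen/dlsym
+// (see GetOpApiFuncAddr further down), so this header adds no link-time dependency on
+// libopapi.so or libcust_opapi.so.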
+typedef struct aclOpExecutor aclOpExecutor; +typedef struct aclTensor aclTensor; +typedef struct aclScalar aclScalar; +typedef struct aclIntArray aclIntArray; +typedef struct aclFloatArray aclFloatArray; +typedef struct aclBoolArray aclBoolArray; +typedef struct aclTensorList aclTensorList; + +typedef aclTensor *(*_aclCreateTensor)(const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type, + const int64_t *stride, int64_t offset, aclFormat format, + const int64_t *storage_dims, uint64_t storage_dims_num, void *tensor_data); +typedef aclScalar *(*_aclCreateScalar)(void *value, aclDataType data_type); +typedef aclIntArray *(*_aclCreateIntArray)(const int64_t *value, uint64_t size); +typedef aclFloatArray *(*_aclCreateFloatArray)(const float *value, uint64_t size); +typedef aclBoolArray *(*_aclCreateBoolArray)(const bool *value, uint64_t size); +typedef aclTensorList *(*_aclCreateTensorList)(const aclTensor *const *value, uint64_t size); + +typedef int (*_aclDestroyTensor)(const aclTensor *tensor); +typedef int (*_aclDestroyScalar)(const aclScalar *scalar); +typedef int (*_aclDestroyIntArray)(const aclIntArray *array); +typedef int (*_aclDestroyFloatArray)(const aclFloatArray *array); +typedef int (*_aclDestroyBoolArray)(const aclBoolArray *array); +typedef int (*_aclDestroyTensorList)(const aclTensorList *array); + +constexpr int kHashBufSize = 8192; +constexpr int kHashBufMaxSize = kHashBufSize + 1024; +extern thread_local char g_hashBuf[kHashBufSize]; +extern thread_local int g_hashOffset; + +#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \ + _(at::ScalarType::Byte, ACL_UINT8) \ + _(at::ScalarType::Char, ACL_INT8) \ + _(at::ScalarType::Short, ACL_INT16) \ + _(at::ScalarType::Int, ACL_INT32) \ + _(at::ScalarType::Long, ACL_INT64) \ + _(at::ScalarType::Half, ACL_FLOAT16) \ + _(at::ScalarType::Float, ACL_FLOAT) \ + _(at::ScalarType::Double, ACL_DOUBLE) \ + _(at::ScalarType::ComplexHalf, ACL_DT_UNDEFINED) \ + _(at::ScalarType::ComplexFloat, ACL_COMPLEX64) \ + _(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \ + _(at::ScalarType::Bool, ACL_BOOL) \ + _(at::ScalarType::QInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QInt32, ACL_DT_UNDEFINED) \ + _(at::ScalarType::BFloat16, ACL_BF16) \ + _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED) \ + _(at::ScalarType::Undefined, ACL_DT_UNDEFINED) \ + _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED) + +constexpr aclDataType kATenScalarTypeToAclDataTypeTable[static_cast(at::ScalarType::NumOptions) + 1] = { +#define DEFINE_ENUM(_1, n) n, + AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM) +#undef DEFINE_ENUM +}; + +#define GET_OP_API_FUNC(apiName) reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName)) + +#define MEMCPY_TO_BUF(data_expression, size_expression) \ + if (g_hashOffset + (size_expression) > kHashBufSize) { \ + g_hashOffset = kHashBufMaxSize; \ + return; \ + } \ + memcpy(g_hashBuf + g_hashOffset, data_expression, size_expression); \ + g_hashOffset += size_expression; + +inline const char *GetOpApiLibName(void) +{ + return "libopapi.so"; +} + +inline const char *GetCustOpApiLibName(void) +{ + return "libcust_opapi.so"; +} + +inline void *GetOpApiFuncAddrInLib(void *handler, const char *libName, const char *apiName) +{ + auto funcAddr = dlsym(handler, apiName); + if (funcAddr == nullptr) { + ASCEND_LOGW("dlsym %s from %s failed, error:%s.", apiName, libName, dlerror()); + } + return funcAddr; +} + +inline void *GetOpApiLibHandler(const 
char *libName) +{ + auto handler = dlopen(libName, RTLD_LAZY); + if (handler == nullptr) { + ASCEND_LOGW("dlopen %s failed, error:%s.", libName, dlerror()); + } + return handler; +} + +inline void *GetOpApiFuncAddr(const char *apiName) +{ + static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName()); + if (custOpApiHandler != nullptr) { + auto funcAddr = GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName); + if (funcAddr != nullptr) { + return funcAddr; + } + } + + static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName()); + if (opApiHandler == nullptr) { + return nullptr; + } + return GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName); +} + +inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor) +{ + c10::Scalar expScalar; + const at::Tensor *aclInput = &tensor; + if (aclInput->scalar_type() == at::ScalarType::Double) { + double value = *(double *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Long) { + int64_t value = *(int64_t *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Float) { + float value = *(float *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Int) { + int value = *(int *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Half) { + c10::Half value = *(c10::Half *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::Bool) { + int8_t value = *(int8_t *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::ComplexDouble) { + c10::complex value = *(c10::complex *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::ComplexFloat) { + c10::complex value = *(c10::complex *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } else if (aclInput->scalar_type() == at::ScalarType::BFloat16) { + c10::BFloat16 value = *(c10::BFloat16 *)aclInput->data_ptr(); + c10::Scalar scalar(value); + expScalar = scalar; + } + return expScalar; +} + +inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor) +{ + at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory(); + int deviceIndex = 0; + return cpuPinMemTensor.to(c10::Device(torch_npu::utils::get_npu_device_type(), deviceIndex), + cpuPinMemTensor.scalar_type(), true, true); +} + +inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar, at::ScalarType scalar_data_type) +{ + return CopyTensorHostToDevice(scalar_to_tensor(cpu_scalar).to(scalar_data_type)); +} + +inline aclTensor *ConvertType(const at::Tensor &at_tensor) +{ + static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor); + if (aclCreateTensor == nullptr) { + return nullptr; + } + + if (!at_tensor.defined()) { + return nullptr; + } + at::ScalarType scalar_data_type = at_tensor.scalar_type(); + aclDataType acl_data_type = kATenScalarTypeToAclDataTypeTable[static_cast(scalar_data_type)]; + TORCH_CHECK(acl_data_type != ACL_DT_UNDEFINED, + std::string(c10::toString(scalar_data_type)) + " has not been supported") + c10::SmallVector storageDims; + // if acl_data_type is ACL_STRING, storageDims is empty. 
+    auto itemsize = at_tensor.itemsize();
+    if (itemsize == 0) {
+        AT_ERROR("When ConvertType, tensor item size cannot be zero.");
+        return nullptr;
+    }
+    if (acl_data_type != ACL_STRING) {
+        storageDims.push_back(at_tensor.storage().nbytes() / itemsize);
+    }
+
+    const auto dimNum = at_tensor.sizes().size();
+    aclFormat format = ACL_FORMAT_ND;
+    switch (dimNum) {
+        case 3:
+            // matches the operator inputs of matmul_allreduce_add_rmsnorm.py
+            format = ACL_FORMAT_ND;
+            break;
+        case 4:
+            format = ACL_FORMAT_NCHW;
+            break;
+        case 5:
+            format = ACL_FORMAT_NCDHW;
+            break;
+        default:
+            format = ACL_FORMAT_ND;
+    }
+    // matches the weight inputs of the fused_deep_moe operator
+    if (acl_data_type == ACL_INT8 && dimNum == 3) {
+        format = ACL_FORMAT_FRACTAL_NZ;
+    }
+
+    if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
+        c10::Scalar expScalar = ConvertTensorToScalar(at_tensor);
+        at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type);
+        return aclCreateTensor(aclInput.sizes().data(), aclInput.sizes().size(), acl_data_type,
+                               aclInput.strides().data(), aclInput.storage_offset(), format, storageDims.data(),
+                               storageDims.size(), const_cast(aclInput.storage().data()));
+    }
+
+    auto acl_tensor =
+        aclCreateTensor(at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type, at_tensor.strides().data(),
+                        at_tensor.storage_offset(), format, storageDims.data(), storageDims.size(),
+                        const_cast(at_tensor.storage().data()));
+    return acl_tensor;
+}
+
+inline aclScalar *ConvertType(const at::Scalar &at_scalar)
+{
+    static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
+    if (aclCreateScalar == nullptr) {
+        return nullptr;
+    }
+
+    at::ScalarType scalar_data_type = at_scalar.type();
+    aclDataType acl_data_type = kATenScalarTypeToAclDataTypeTable[static_cast(scalar_data_type)];
+    TORCH_CHECK(acl_data_type != ACL_DT_UNDEFINED,
+                std::string(c10::toString(scalar_data_type)) + " has not been supported")
+    aclScalar *acl_scalar = nullptr;
+    switch (scalar_data_type) {
+        case at::ScalarType::Double: {
+            double value = at_scalar.toDouble();
+            acl_scalar = aclCreateScalar(&value, acl_data_type);
+            break;
+        }
+        case at::ScalarType::Long: {
+            int64_t value = at_scalar.toLong();
+            acl_scalar = aclCreateScalar(&value, acl_data_type);
+            break;
+        }
+        case at::ScalarType::Bool: {
+            bool value = at_scalar.toBool();
+            acl_scalar = aclCreateScalar(&value, acl_data_type);
+            break;
+        }
+        case at::ScalarType::ComplexDouble: {
+            auto value = at_scalar.toComplexDouble();
+            acl_scalar = aclCreateScalar(&value, acl_data_type);
+            break;
+        }
+        default:
+            acl_scalar = nullptr;
+            break;
+    }
+    return acl_scalar;
+}
+
+inline aclIntArray *ConvertType(const at::IntArrayRef &at_array)
+{
+    static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
+    if (aclCreateIntArray == nullptr) {
+        return nullptr;
+    }
+    auto array = aclCreateIntArray(at_array.data(), at_array.size());
+    return array;
+}
+
+template inline aclBoolArray *ConvertType(const std::array &value)
+{
+    static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
+    if (aclCreateBoolArray == nullptr) {
+        return nullptr;
+    }
+
+    auto array = aclCreateBoolArray(value.data(), value.size());
+    return array;
+}
+
+inline aclBoolArray *ConvertType(const at::ArrayRef &value)
+{
+    static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
+    if (aclCreateBoolArray == nullptr) {
+        return nullptr;
+    }
+
+    auto array = aclCreateBoolArray(value.data(), value.size());
+    return array;
+}
+
+inline aclTensorList *ConvertType(const at::TensorList &at_tensor_list)
+{ + static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList); + if (aclCreateTensorList == nullptr) { + return nullptr; + } + + std::vector tensor_list(at_tensor_list.size()); + for (size_t i = 0; i < at_tensor_list.size(); i++) { + tensor_list[i] = ConvertType(at_tensor_list[i]); + } + auto acl_tensor_list = aclCreateTensorList(tensor_list.data(), tensor_list.size()); + return acl_tensor_list; +} + +inline aclTensor *ConvertType(const c10::optional &opt_tensor) +{ + if (opt_tensor.has_value() && opt_tensor.value().defined()) { + return ConvertType(opt_tensor.value()); + } + return nullptr; +} + +inline aclIntArray *ConvertType(const c10::optional &opt_array) +{ + if (opt_array.has_value()) { + return ConvertType(opt_array.value()); + } + return nullptr; +} + +inline aclScalar *ConvertType(const c10::optional &opt_scalar) +{ + if (opt_scalar.has_value()) { + return ConvertType(opt_scalar.value()); + } + return nullptr; +} + +inline aclDataType ConvertType(const at::ScalarType scalarType) +{ + return kATenScalarTypeToAclDataTypeTable[static_cast(scalarType)]; +} + +template T ConvertType(T value) +{ + return value; +} + +template +auto ConvertToOpApiFunc(const Tuple ¶ms, void *opApiAddr, std::index_sequence) +{ + typedef int (*OpApiFunc)(typename std::decay(params))>::type...); + auto func = reinterpret_cast(opApiAddr); + return func; +} + +template auto ConvertToOpApiFunc(const Tuple ¶ms, void *opApiAddr) +{ + static constexpr auto size = std::tuple_size::value; + return ConvertToOpApiFunc(params, opApiAddr, std::make_index_sequence{}); +} + +inline void Release(aclTensor *p) +{ + static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor); + if (aclDestroyTensor == nullptr) { + return; + } + aclDestroyTensor(p); +} + +inline void Release(aclScalar *p) +{ + static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar); + if (aclDestroyScalar == nullptr) { + return; + } + aclDestroyScalar(p); +} + +inline void Release(aclIntArray *p) +{ + static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray); + if (aclDestroyIntArray == nullptr) { + return; + } + + aclDestroyIntArray(p); +} + +inline void Release(aclBoolArray *p) +{ + static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray); + if (aclDestroyBoolArray == nullptr) { + return; + } + + aclDestroyBoolArray(p); +} + +inline void Release(aclTensorList *p) +{ + static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList); + if (aclDestroyTensorList == nullptr) { + return; + } + + aclDestroyTensorList(p); +} + +template void Release(T value) +{ + (void)value; +} + +template void CallRelease(Tuple t, std::index_sequence) +{ + (void)std::initializer_list{(Release(std::get(t)), 0)...}; +} + +template void ReleaseConvertTypes(Tuple &t) +{ + static constexpr auto size = std::tuple_size::value; + CallRelease(t, std::make_index_sequence{}); +} + +template constexpr auto ConvertTypes(Ts &...args) +{ + return std::make_tuple(ConvertType(args)...); +} + +template auto call(Function f, Tuple t, std::index_sequence) +{ + return f(std::get(t)...); +} + +template auto call(Function f, Tuple t) +{ + static constexpr auto size = std::tuple_size::value; + return call(f, t, std::make_index_sequence{}); +} + +template void AddParamToBuf(const std::array &value) +{ + MEMCPY_TO_BUF(value.data(), value.size() * sizeof(bool)); +} + +template void AddParamToBuf(const T &value) +{ + MEMCPY_TO_BUF(&value, sizeof(T)); +} + +void AddParamToBuf(const at::Tensor &); +void 
AddParamToBuf(const at::Scalar &); +void AddParamToBuf(const at::IntArrayRef &); +void AddParamToBuf(const at::ArrayRef &); +void AddParamToBuf(const at::TensorList &); +void AddParamToBuf(const c10::optional &); +void AddParamToBuf(const c10::optional &); +void AddParamToBuf(const c10::optional &); +void AddParamToBuf(const at::ScalarType); +void AddParamToBuf(const string &); +void AddParamToBuf(); + +template void AddParamToBuf(const T &arg, Args &...args) +{ + AddParamToBuf(arg); + AddParamToBuf(args...); +} + +uint64_t CalcHashId(); +typedef int (*InitHugeMemThreadLocal)(void *, bool); +typedef void (*UnInitHugeMemThreadLocal)(void *, bool); +typedef void (*ReleaseHugeMem)(void *, bool); + +#define EXEC_NPU_CMD(aclnn_api, ...) \ + do { \ + static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \ + static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \ + static const auto initMemAddr = GetOpApiFuncAddr("InitHugeMemThreadLocal"); \ + static const auto unInitMemAddr = GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \ + static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \ + TORCH_CHECK(getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, #aclnn_api, " or ", \ + #aclnn_api "GetWorkspaceSize", " not in ", GetOpApiLibName(), ", or ", GetOpApiLibName(), \ + "not found."); \ + auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ + uint64_t workspace_size = 0; \ + uint64_t *workspace_size_addr = &workspace_size; \ + aclOpExecutor *executor = nullptr; \ + aclOpExecutor **executor_addr = &executor; \ + InitHugeMemThreadLocal initMemFunc = reinterpret_cast(initMemAddr); \ + UnInitHugeMemThreadLocal unInitMemFunc = reinterpret_cast(unInitMemAddr); \ + if (initMemFunc) { \ + initMemFunc(nullptr, false); \ + } \ + auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \ + static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \ + auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \ + TORCH_CHECK(workspace_status == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \ + void *workspace_addr = nullptr; \ + if (workspace_size != 0) { \ + at::TensorOptions options = at::TensorOptions(torch_npu::utils::get_npu_device_type()); \ + auto workspace_tensor = at::empty({workspace_size}, options.dtype(kByte)); \ + workspace_addr = const_cast(workspace_tensor.storage().data()); \ + } \ + auto acl_call = [converted_params, workspace_addr, workspace_size, acl_stream, executor]() -> int { \ + typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, const aclrtStream); \ + OpApiFunc opApiFunc = reinterpret_cast(opApiFuncAddr); \ + auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \ + TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \ + ReleaseConvertTypes(converted_params); \ + ReleaseHugeMem releaseMemFunc = reinterpret_cast(releaseMemAddr); \ + if (releaseMemFunc) { \ + releaseMemFunc(nullptr, false); \ + } \ + return api_ret; \ + }; \ + at_npu::native::OpCommand cmd; \ + cmd.Name(#aclnn_api); \ + cmd.SetCustomHandler(acl_call); \ + cmd.Run(); \ + if (unInitMemFunc) { \ + unInitMemFunc(nullptr, false); \ + } \ + } while (false) + +#endif // PYTORCH_NPU_HELPER_HPP_ \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 90e7f03afac..bcb3111bf14 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp 
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include
 #include "torch_npu/csrc/core/npu/NPUGuard.h"
 #include
 #include "acl/acl.h"
@@ -27,6 +28,7 @@
 #include "ops.h"
 #include "utils.h"
 #include "mla_preprocess/op_host/mla_preprocess.h"
+#include "pytorch_npu_helper.hpp"
 #include
 #include
@@ -520,6 +522,41 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
     cmd.Run();
     return y_out;
 }
+
+std::tuple<at::Tensor, at::Tensor> fused_deep_moe(const at::Tensor &x, const at::Tensor &expert_ids,
+                                                  const at::Tensor &gmm1_permuted_weight,
+                                                  const at::Tensor &gmm1_permuted_weight_scale,
+                                                  const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale,
+                                                  const at::Tensor &expert_scales_optional,
+                                                  c10::optional hcom_ep_name,
+                                                  int64_t num_ranks, int64_t rank,
+                                                  int64_t shared_expert_num, int64_t shared_expert_rank_num,
+                                                  int64_t num_experts, int64_t global_bs,
+                                                  int64_t quant_mode)
+{
+    auto x_shape = x.sizes();
+    auto experts_shape = expert_ids.sizes();
+    int h = x_shape[1];
+    int bs = experts_shape[0];
+
+    at::Tensor output = at::empty({bs, h}, x.options());
+
+    bool is_shared_expert = (rank < shared_expert_rank_num);
+    int64_t num_local_experts = is_shared_expert ? 1 : num_experts / (num_ranks - shared_expert_rank_num);
+    at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options());
+
+    EXEC_NPU_CMD(aclnnFusedDeepMoe,
+                 // input
+                 x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight,
+                 gmm2_weight_scale, static_cast(nullptr), expert_scales_optional,
+                 // attr
+                 hcom_ep_name, num_ranks, rank, num_experts, shared_expert_num, shared_expert_rank_num, quant_mode,
+                 global_bs,
+                 // output
+                 output, ep_recv_count);
+    return {output, ep_recv_count};
+}
+
 } // namespace vllm_ascend
 
 TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
@@ -576,4 +613,19 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
 
     ops.def("swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()");
     ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks);
+
+
+    ops.def(
+        "fused_deep_moe(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight,"
+        " Tensor gmm1_permuted_weight_scale,"
+        " Tensor gmm2_weight, Tensor gmm2_weight_scale,"
+        " Tensor expert_scales_optional,"
+        " str?
hcom_ep_name," + " int num_ranks, int rank," + " int shared_expert_num, int shared_expert_rank_num," + " int num_experts, int global_bs," + " int quant_mode) -> (Tensor output, Tensor ep_recv_count)" + ); + + ops.impl("fused_deep_moe", torch::kPrivateUse1, &vllm_ascend::fused_deep_moe); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index dbb056be89c..ab3df89d0ca 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -69,6 +69,31 @@ std::tuple get_masked_input_and_mask_meta( return {masked_input, mask}; } +std::tuple fused_deep_moe_meta(const at::Tensor &x, const at::Tensor &expert_ids, + const at::Tensor &gmm1_permuted_weight, + const at::Tensor &gmm1_permuted_weight_scale, + const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale, + const at::Tensor &expert_scales_optional, + c10::optional hcom_ep_name, + int64_t num_ranks, int64_t rank, + int64_t shared_expert_num, int64_t shared_expert_rank_num, + int64_t num_experts, int64_t global_bs, + int quant_mode) +{ + auto x_shape = x.sizes(); + auto experts_shape = expert_ids.sizes(); + int h = x_shape[1]; + int bs = experts_shape[0]; + + at::Tensor output = at::empty({bs, h}, x.options().device(at::kMeta)); + + bool is_shared_expert = (rank < shared_expert_rank_num); + int64_t num_local_experts = is_shared_expert ? 1 : num_experts / (num_ranks - shared_expert_rank_num); + at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options().device(at::kMeta)); + + return {output, ep_recv_count}; +} + at::Tensor bgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y, int64_t slice_offset, int64_t slice_size) { at::Tensor y_out = at::empty_like(y); @@ -132,5 +157,7 @@ namespace { ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta); // MLA preprocess ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess); + // Masked fused_deep_moe_meta meta implementation + ops.impl("fused_deep_moe", &vllm_ascend::meta::fused_deep_moe_meta); } } From 1550c865abff5b60fd813829fd85a7d97ad44683 Mon Sep 17 00:00:00 2001 From: GuoRen868 <1269192170@qq.com> Date: Mon, 10 Nov 2025 16:52:56 +0800 Subject: [PATCH 3/7] aclnn compile. 
--- csrc/build_aclnn.sh | 9 + csrc/custom_ops/build.sh | 73 +++++++ csrc/custom_ops/kernels/AddCustom.json | 40 ++++ .../fused_deep_moe/op_host/fused_deep_moe.cpp | 0 .../op_host/fused_deep_moe_infer.cpp | 0 .../op_host/fused_deep_moe_tiling.cpp | 0 .../op_kernel/fused_deep_moe.cpp | 0 .../fused_deep_moe/op_kernel/fused_deep_moe.h | 32 +-- .../op_kernel/fused_deep_moe_base.h | 0 .../op_kernel/fused_deep_moe_tiling.h | 0 .../op_kernel/moe_distribute_base.h | 199 +++++++++++++++++ .../pregen/aclnn/aclnn_fused_deep_moe.cpp | 66 ++++++ .../pregen/aclnn/aclnn_fused_deep_moe.h | 42 ++++ .../kernels/scripts/op_host/CMakeLists.txt | 171 +++++++++++++++ .../kernels/scripts/op_kernel/CMakeLists.txt | 8 + .../kernels}/utils/op_host/error_log.h | 0 .../op_kernel/a3/cam_moe_distribute_combine.h | 4 +- .../a3/cam_moe_distribute_dispatch.h | 4 +- .../op_kernel/operator/catlass/act/act.hpp | 0 .../operator/catlass/act/arch/arch.hpp | 0 .../catlass/act/arch/cross_core_sync.hpp | 0 .../catlass/act/arch/local_tensor_buffer.hpp | 0 .../operator/catlass/act/arch/resource.hpp | 0 .../op_kernel/operator/catlass/act/coord.hpp | 0 .../operator/catlass/act/detail/alignment.hpp | 0 .../operator/catlass/act/detail/callback.hpp | 0 .../catlass/act/detail/dependent_false.hpp | 0 .../operator/catlass/act/detail/macros.hpp | 0 .../catlass/act/detail/tag_to_layout.hpp | 0 .../act/epilogue/block/block_epilogue.hpp | 0 .../block_epilogue_per_token_dequant.hpp | 0 .../catlass/act/epilogue/dispatch_policy.hpp | 0 .../act/epilogue/tile/copy_gm_to_ub.hpp | 0 .../act/epilogue/tile/copy_ub_to_gm.hpp | 0 .../tile/tile_broadcast_inplace_by_column.hpp | 0 .../tile/tile_broadcast_inplace_by_row.hpp | 0 .../act/epilogue/tile/tile_broadcast_mul.hpp | 0 .../epilogue/tile/tile_broadcast_one_blk.hpp | 0 .../catlass/act/epilogue/tile/tile_cast.hpp | 0 .../catlass/act/epilogue/tile/tile_copy.hpp | 0 .../act/epilogue/tile/tile_elemwise_add.hpp | 0 .../act/epilogue/tile/tile_elemwise_mul.hpp | 0 .../act/epilogue/tile/tile_elemwise_muls.hpp | 0 .../act/epilogue/tile/tile_swizzle.hpp | 0 .../catlass/act/gemm/block/block_mmad.hpp | 0 ...block_mmad_preload_async_with_callback.hpp | 0 .../catlass/act/gemm/block/block_swizzle.hpp | 0 .../catlass/act/gemm/dispatch_policy.hpp | 0 .../operator/catlass/act/gemm/gemm_type.hpp | 0 .../operator/catlass/act/gemm/helper.hpp | 0 ...per_token_dequant_multistage_workspace.hpp | 0 .../catlass/act/gemm/tile/copy_gm_to_l1.hpp | 0 .../catlass/act/gemm/tile/copy_gm_to_ub.hpp | 0 .../catlass/act/gemm/tile/copy_l0c_to_gm.hpp | 0 .../catlass/act/gemm/tile/copy_l1_to_l0a.hpp | 0 .../catlass/act/gemm/tile/copy_l1_to_l0b.hpp | 0 .../catlass/act/gemm/tile/copy_ub_to_gm.hpp | 0 .../catlass/act/gemm/tile/tile_copy.hpp | 0 .../catlass/act/gemm/tile/tile_mmad.hpp | 0 .../operator/catlass/act/gemm_coord.hpp | 0 .../operator/catlass/act/gemv_coord.hpp | 0 .../operator/catlass/act/layout/layout.hpp | 0 .../operator/catlass/act/layout/matrix.hpp | 0 .../operator/catlass/act/layout/vector.hpp | 0 .../operator/catlass/act/matrix_coord.hpp | 0 .../operator/catlass/tla/int_tuple.hpp | 0 .../op_kernel/operator/catlass/tla/layout.hpp | 0 .../catlass/tla/numeric/integer_sequence.hpp | 0 .../catlass/tla/numeric/integral_constant.hpp | 0 .../operator/catlass/tla/numeric/math.hpp | 0 .../op_kernel/operator/catlass/tla/tensor.hpp | 0 .../op_kernel/operator/catlass/tla/tuple.hpp | 0 .../operator/catlass/tla/type_traits.hpp | 0 .../operator/epilogue/block/block_epilogue.h | 0 .../block_epilogue_per_token_dequant_swiglu.h | 0 
.../operator/epilogue/dispatch_policy.h | 0 .../epilogue/tile/tile_stride_binary.h | 0 .../operator/epilogue/tile/tile_stride_muls.h | 0 .../operator/gemm/block/block_mmad.h | 0 ...d_preload_async_with_callback_resident_a.h | 0 .../op_kernel/operator/gemm/dispatch_policy.h | 0 ...equant_swiglu_quant_multistage_workspace.h | 2 +- csrc/custom_ops/scripts/build.sh | 58 +++++ .../custom_ops/scripts/compile_ascend_proj.sh | 65 ++++++ csrc/custom_ops/scripts/set_conf.py | 62 ++++++ csrc/pytorch_npu_helper.hpp | 204 +++++++++--------- csrc/torch_binding.cpp | 12 +- csrc/torch_binding_meta.cpp | 2 +- 88 files changed, 918 insertions(+), 135 deletions(-) create mode 100644 csrc/build_aclnn.sh create mode 100644 csrc/custom_ops/build.sh create mode 100644 csrc/custom_ops/kernels/AddCustom.json rename csrc/{ => custom_ops/kernels}/fused_deep_moe/op_host/fused_deep_moe.cpp (100%) rename csrc/{ => custom_ops/kernels}/fused_deep_moe/op_host/fused_deep_moe_infer.cpp (100%) rename csrc/{ => custom_ops/kernels}/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp (100%) rename csrc/{ => custom_ops/kernels}/fused_deep_moe/op_kernel/fused_deep_moe.cpp (100%) rename csrc/{ => custom_ops/kernels}/fused_deep_moe/op_kernel/fused_deep_moe.h (94%) rename csrc/{ => custom_ops/kernels}/fused_deep_moe/op_kernel/fused_deep_moe_base.h (100%) rename csrc/{ => custom_ops/kernels}/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h (100%) create mode 100644 csrc/custom_ops/kernels/fused_deep_moe/op_kernel/moe_distribute_base.h create mode 100644 csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.cpp create mode 100644 csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.h create mode 100644 csrc/custom_ops/kernels/scripts/op_host/CMakeLists.txt create mode 100644 csrc/custom_ops/kernels/scripts/op_kernel/CMakeLists.txt rename csrc/{ => custom_ops/kernels}/utils/op_host/error_log.h (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h (99%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h (99%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/act.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/arch/arch.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/arch/cross_core_sync.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/arch/local_tensor_buffer.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/arch/resource.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/coord.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/detail/alignment.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/detail/callback.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/detail/dependent_false.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/detail/macros.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/detail/tag_to_layout.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue_per_token_dequant.hpp (100%) rename csrc/{ => 
custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/dispatch_policy.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_gm_to_ub.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_ub_to_gm.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_cast.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_copy.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_add.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_mul.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_muls.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad_preload_async_with_callback.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/dispatch_policy.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/helper.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_l1.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_ub.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l0c_to_gm.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0a.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0b.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/tile/copy_ub_to_gm.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/tile/tile_copy.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm/tile/tile_mmad.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemm_coord.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/gemv_coord.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/layout/layout.hpp (100%) rename csrc/{ => 
custom_ops/kernels}/utils/op_kernel/operator/catlass/act/layout/matrix.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/layout/vector.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/act/matrix_coord.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/tla/int_tuple.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/tla/layout.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/tla/numeric/math.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/tla/tensor.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/tla/tuple.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/catlass/tla/type_traits.hpp (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/epilogue/block/block_epilogue.h (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/epilogue/dispatch_policy.h (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/gemm/block/block_mmad.h (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/gemm/dispatch_policy.h (100%) rename csrc/{ => custom_ops/kernels}/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h (99%) create mode 100644 csrc/custom_ops/scripts/build.sh create mode 100644 csrc/custom_ops/scripts/compile_ascend_proj.sh create mode 100644 csrc/custom_ops/scripts/set_conf.py diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh new file mode 100644 index 00000000000..02a8946da08 --- /dev/null +++ b/csrc/build_aclnn.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# build custom ops +cd custom_ops/ +bash build.sh custom_ops -cascend910_93 + +# install custom ops +# ./output/CANN-custom_ops--linux.x86_64.run +# export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH} diff --git a/csrc/custom_ops/build.sh b/csrc/custom_ops/build.sh new file mode 100644 index 00000000000..aef758462e8 --- /dev/null +++ b/csrc/custom_ops/build.sh @@ -0,0 +1,73 @@ +#!/bin/bash +SCRIPT_PATH=$(cd "$(dirname "$0")" && pwd)/$(basename "$0") +export ROOT_PATH=$(dirname "$SCRIPT_PATH") +echo ROOT_PATH: $ROOT_PATH +if [ ! -d "./build_out" ]; then + mkdir build_out +fi +export SRC_PATH="${ROOT_PATH}" +export BUILD_OUT_PATH="${ROOT_PATH}/build_out" +export SCRIPTS_PATH="${ROOT_PATH}/scripts" + +export BUILD_TYPE="Release" +MODULE_NAME="all" +MODULE_BUILD_ARG="" +IS_MODULE_EXIST=0 + +function PrintHelp() { + echo " + ./build.sh [module name] ... 
+ If there are no parameters, all modules are compiled in default mode + module list: [custom_ops] + + opt: + -d: Enable debug + " +} + +function ProcessArg() { + while getopts "dh" opt; do + case $opt in + d) + export BUILD_TYPE="Debug" + ;; + h) + PrintHelp + exit 0 + ;; + esac + done + shift $(($OPTIND-1)) +} + +function IsModuleName() { + if [ -z "$1" ]; then + return 1 + fi + + if [[ $1 == -* ]]; then + return 1 + else + return 0 + fi +} + +if IsModuleName $@; then + MODULE_NAME=$1 + shift +else + ProcessArg $@ +fi + +if [[ "$MODULE_NAME" == "all" || "$MODULE_NAME" == "custom_ops" ]]; then + IS_MODULE_EXIST=1 + echo "./scripts/build.sh $@" + ./scripts/build.sh $@ + if [ $? -ne 0 ]; then + exit 1 + fi +fi + +if [ $IS_MODULE_EXIST -eq 0 ]; then + echo "module not exist" +fi \ No newline at end of file diff --git a/csrc/custom_ops/kernels/AddCustom.json b/csrc/custom_ops/kernels/AddCustom.json new file mode 100644 index 00000000000..dce1ed85f74 --- /dev/null +++ b/csrc/custom_ops/kernels/AddCustom.json @@ -0,0 +1,40 @@ +[ + { + "op": "AddCustom", + "language": "cpp", + "input_desc": [ + { + "name": "x", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "float16" + ] + }, + { + "name": "y", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "float16" + ] + } + ], + "output_desc": [ + { + "name": "z", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "float16" + ] + } + ] + } +] \ No newline at end of file diff --git a/csrc/fused_deep_moe/op_host/fused_deep_moe.cpp b/csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe.cpp similarity index 100% rename from csrc/fused_deep_moe/op_host/fused_deep_moe.cpp rename to csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe.cpp diff --git a/csrc/fused_deep_moe/op_host/fused_deep_moe_infer.cpp b/csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_infer.cpp similarity index 100% rename from csrc/fused_deep_moe/op_host/fused_deep_moe_infer.cpp rename to csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_infer.cpp diff --git a/csrc/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp b/csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp similarity index 100% rename from csrc/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp rename to csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp diff --git a/csrc/fused_deep_moe/op_kernel/fused_deep_moe.cpp b/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.cpp similarity index 100% rename from csrc/fused_deep_moe/op_kernel/fused_deep_moe.cpp rename to csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.cpp diff --git a/csrc/fused_deep_moe/op_kernel/fused_deep_moe.h b/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.h similarity index 94% rename from csrc/fused_deep_moe/op_kernel/fused_deep_moe.h rename to csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.h index 2a6cbb68c1b..03ea1800a3b 100644 --- a/csrc/fused_deep_moe/op_kernel/fused_deep_moe.h +++ b/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.h @@ -12,22 +12,22 @@ #include "lib/matmul_intf.h" #include -#include "../utils/op_kernel/operator/catlass/act/act.hpp" -#include "../utils/op_kernel/operator/catlass/act/arch/arch.hpp" -#include "../utils/op_kernel/operator/catlass/act/layout/layout.hpp" -#include "../utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp" -#include 
"../utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp" -#include "../utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp" -#include "../utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp" -#include "../utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp" -#include "../utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp" -#include "../utils/op_kernel/operator/epilogue/dispatch_policy.h" -#include "../utils/op_kernel/operator/gemm/dispatch_policy.h" -#include "../utils/op_kernel/operator/epilogue/block/block_epilogue.h" -#include "../utils/op_kernel/operator/gemm/block/block_mmad.h" -#include "../utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h" - -#include "../utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h" +#include "operator/catlass/act/act.hpp" +#include "operator/catlass/act/arch/arch.hpp" +#include "operator/catlass/act/layout/layout.hpp" +#include "operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp" +#include "operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp" +#include "operator/catlass/act/epilogue/tile/tile_swizzle.hpp" +#include "operator/catlass/act/gemm/block/block_swizzle.hpp" +#include "operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp" +#include "operator/catlass/act/gemm/gemm_type.hpp" +#include "operator/epilogue/dispatch_policy.h" +#include "operator/gemm/dispatch_policy.h" +#include "operator/epilogue/block/block_epilogue.h" +#include "operator/gemm/block/block_mmad.h" +#include "operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h" + +#include "operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h" #include "fused_deep_moe_tiling.h" #include "fused_deep_moe_base.h" diff --git a/csrc/fused_deep_moe/op_kernel/fused_deep_moe_base.h b/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_base.h similarity index 100% rename from csrc/fused_deep_moe/op_kernel/fused_deep_moe_base.h rename to csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_base.h diff --git a/csrc/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h b/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h similarity index 100% rename from csrc/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h rename to csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h diff --git a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/moe_distribute_base.h b/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/moe_distribute_base.h new file mode 100644 index 00000000000..b899a0e4956 --- /dev/null +++ b/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/moe_distribute_base.h @@ -0,0 +1,199 @@ +/*! 
+ * \file moe_distribute_base.h + * \brief + */ + +#ifndef MOE_DISTRIBUTE_BASE_H +#define MOE_DISTRIBUTE_BASE_H + +/* system tick: 50MHz */ +#define CAL_US(tick) (((tick) * 2) / 100) + +/* performance macro */ +// #define USE_256_TO_1__ +#ifdef USE_256_TO_1__ +#pragma message("use 256 to 1") +#else +#define USE_FOR_OPT__ +#define DISPATCH_USE_WRITE_SHUFFLE__ +#define USE_TOKEN_COUNT_SPLIT__ +#define USE_ONE_CORE_WAIT__ + +#ifdef USE_ONE_CORE_WAIT__ +#pragma message("use one core wait") + +// #define USE_ONE_CORE_GETCUMSUM__ +#endif +#ifdef USE_FOR_OPT__ +#pragma message("use for optimization") +#define FOR_OPT_MAX_BS__ 64 +#define FOR_OPT_MAX_MOE_RANK__ 256 +#endif +// #define COMBINE_USE_DYNAMIC_QUANT +#define OPT_RANK_OFFSET 512 +#define USE_WRITE_SHUFFLE +#endif + +constexpr uint32_t LOCAL_NOTIFY_MAX_NUM = 64; +constexpr uint32_t LOCAL_STREAM_MAX_NUM = 19; +constexpr uint32_t AICPU_OP_NOTIFY_MAX_NUM = 2; +constexpr uint32_t AICPU_MAX_RANK_NUM = 128 * 1024; + +struct HcclSignalInfo { + uint64_t resId; + uint64_t addr; + uint32_t devId; + uint32_t tsId; + uint32_t rankId; + uint32_t flag; +}; + +struct ListCommon { + uint64_t nextHost; + uint64_t preHost; + uint64_t nextDevice; + uint64_t preDevice; +}; + +struct HcclStreamInfo { + int32_t streamIds; + uint32_t sqIds; + uint32_t cqIds; + uint32_t logicCqids; +}; + +struct LocalResInfoV2 { + uint32_t streamNum; + uint32_t signalNum; + HcclSignalInfo localSignals[LOCAL_NOTIFY_MAX_NUM]; + HcclStreamInfo streamInfo[LOCAL_STREAM_MAX_NUM]; + HcclStreamInfo mainStreamInfo; + HcclSignalInfo aicpuOpNotify[AICPU_OP_NOTIFY_MAX_NUM]; + ListCommon nextTagRes; // HccltagLocalResV2 +}; + +enum class rtFloatOverflowMode_t { + RT_OVERFLOW_MODE_SATURATION = 0, + RT_OVERFLOW_MODE_INFNAN, + RT_OVERFLOW_MODE_UNDEF, +}; + +struct AlgoTopoInfo { + uint32_t userRank; // RankID + uint32_t userRankSize; // Rank Number + int32_t deviceLogicId; + bool isSingleMeshAggregation; + uint32_t deviceNumPerAggregation; + uint32_t superPodNum; + uint32_t devicePhyId; + uint32_t topoType; // TopoType + uint32_t deviceType; + uint32_t serverNum; + uint32_t meshAggregationRankSize; + uint32_t multiModuleDiffDeviceNumMode; + uint32_t multiSuperPodDiffServerNumMode; + uint32_t realUserRank; + bool isDiffDeviceModule; + bool isDiffDeviceType; + uint32_t gcdDeviceNumPerAggregation; + uint32_t moduleNum; + uint32_t isUsedRdmaRankPairNum; + uint64_t isUsedRdmaRankPair; + uint32_t pairLinkCounterNum; + uint64_t pairLinkCounter; + uint32_t nicNum; + uint64_t nicList; + uint64_t complanRankLength; + uint64_t complanRank; + uint64_t bridgeRankNum; + uint64_t bridgeRank; + uint64_t serverAndsuperPodRankLength; + uint64_t serverAndsuperPodRank; +}; + +struct HcclOpConfig { + uint8_t deterministic; + uint8_t retryEnable; + uint8_t highPerfEnable; + uint8_t padding[5]; + uint8_t linkTimeOut[8]; + uint64_t notifyWaitTime; + uint32_t retryHoldTime; + uint32_t retryIntervalTime; + bool interHccsDisable = false; + rtFloatOverflowMode_t floatOverflowMode = rtFloatOverflowMode_t::RT_OVERFLOW_MODE_UNDEF; + uint32_t multiQpThreshold = 512; +}; + +struct HcclMC2WorkSpace { + uint64_t workSpace; + uint64_t workSpaceSize; +}; + +struct RemoteResPtr { + uint64_t nextHostPtr; + uint64_t nextDevicePtr; +}; + +struct HDCommunicateParams { + uint64_t hostAddr{0}; + uint64_t deviceAddr{0}; + uint64_t readCacheAddr{0}; + uint32_t devMemSize{0}; + uint32_t buffLen{0}; + uint32_t flag{0}; +}; + +struct HcclRankRelationResV2 { + uint32_t remoteUsrRankId; + uint32_t remoteWorldRank; + uint64_t windowsIn; + 
uint64_t windowsOut; + uint64_t windowsExp; + ListCommon nextTagRes; +}; + +struct HcclOpResParam { + // local resource + HcclMC2WorkSpace mc2WorkSpace; + uint32_t localUsrRankId; // usrrankid + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + // aicore detect remote window + uint64_t winExpSize; + uint64_t localWindowsExp; + uint32_t rWinStart; + uint32_t rWinOffset; + uint64_t version; + LocalResInfoV2 localRes; + AlgoTopoInfo topoInfo; + + // config parameters + HcclOpConfig config; + uint64_t hostStateInfo; + uint64_t aicpuStateInfo; + uint64_t lockAddr; + uint32_t rsv[16]; + uint32_t notifysize; + uint32_t remoteResNum; + RemoteResPtr remoteRes[AICPU_MAX_RANK_NUM]; + + // communicate retry + HDCommunicateParams kfcControlTransferH2DParams; + HDCommunicateParams kfcStatusTransferD2HParams; + uint64_t tinyMem; // for all2all + uint64_t tinyMemSize; + // zero-copy + uint64_t zeroCopyHeadPtr; + uint64_t zeroCopyTailPtr; + uint64_t zeroCopyRingBuffer; + uint64_t zeroCopyIpcPtrs[16]; + uint32_t zeroCopyDevicePhyId[16]; + + bool utraceStatusFlag; +}; + +#endif // MOE_DISTRIBUTE_BASE_H diff --git a/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.cpp b/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.cpp new file mode 100644 index 00000000000..be88b904544 --- /dev/null +++ b/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.cpp @@ -0,0 +1,66 @@ +#include +#include "graph/types.h" +#include "aclnn/opdev/platform.h" +#include "aclnn_fused_deep_moe.h" +#include "aclnnInner_fused_deep_moe.h" + +enum NnopbaseHcclServerType { + NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, + NNOPBASE_HCCL_SERVER_TYPE_MTE, + NNOPBASE_HCCL_SERVER_TYPE_END +}; +extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); + +#ifdef __cplusplus +extern "C" { +#endif + +aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize( + const aclTensor *x, + const aclTensor *expertIds, + const aclTensor *gmm1PermutedWeight, + const aclTensor *gmm1PermutedWeightScale, + const aclTensor *gmm2Weight, + const aclTensor *gmm2WeightScale, + const aclTensor *expertSmoothScalesOptional, + const aclTensor *expertScalesOptional, + char *groupEp, + int64_t epRankSize, + int64_t epRankId, + int64_t moeExpertNum, + int64_t shareExpertNum, + int64_t shareExpertRankNum, + int64_t quantMode, + int64_t globalBs, + const aclTensor *output, + const aclTensor *epRecvCount, + uint64_t *workspaceSize, + aclOpExecutor **executor) +{ + return aclnnInnerFusedDeepMoeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, + gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, + epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs, + output, epRecvCount, workspaceSize, executor); +} + +aclnnStatus aclnnFusedDeepMoe( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream) +{ + if (NnopbaseSetHcclServerType) { + if (op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B) { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_AICPU); + } else { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); + } + } + return aclnnInnerFusedDeepMoe(workspace, workspaceSize, executor, stream); +} + +#ifdef __cplusplus +} +#endif + + diff --git a/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.h b/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.h new file 
mode 100644 index 00000000000..435ec98e50a --- /dev/null +++ b/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.h @@ -0,0 +1,42 @@ +#ifndef FUSED_DEEP_MOE +#define FUSED_DEEP_MOE + +#include "aclnn/acl_meta.h" + +#ifdef __cplusplus +extern "C" { +#endif + +__attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize( + const aclTensor *x, + const aclTensor *expertIds, + const aclTensor *gmm1PermutedWeight, + const aclTensor *gmm1PermutedWeightScale, + const aclTensor *gmm2Weight, + const aclTensor *gmm2WeightScale, + const aclTensor *expertSmoothScalesOptional, + const aclTensor *expertScalesOptional, + char *groupEp, + int64_t epRankSize, + int64_t epRankId, + int64_t moeExpertNum, + int64_t shareExpertNum, + int64_t shareExpertRankNum, + int64_t quantMode, + int64_t globalBs, + const aclTensor *output, + const aclTensor *epRecvCount, + uint64_t *workspaceSize, + aclOpExecutor **executor); + +__attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoe( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/csrc/custom_ops/kernels/scripts/op_host/CMakeLists.txt b/csrc/custom_ops/kernels/scripts/op_host/CMakeLists.txt new file mode 100644 index 00000000000..d906f795b7a --- /dev/null +++ b/csrc/custom_ops/kernels/scripts/op_host/CMakeLists.txt @@ -0,0 +1,171 @@ +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} ops_srcs) + +opbuild(OPS_SRC ${ops_srcs} + OUT_DIR ${ASCEND_AUTOGEN_PATH} +) + +file(GLOB group_proto_src ${ASCEND_AUTOGEN_PATH}/group_proto/*.cc) + +add_library(cust_op_proto SHARED + $<$:${group_proto_src}> + ${ops_srcs} + ${ASCEND_AUTOGEN_PATH}/op_proto.cc +) +target_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB) +target_compile_options(cust_op_proto PRIVATE + -fvisibility=hidden +) +if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_op_proto PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) +endif() +target_link_libraries(cust_op_proto PRIVATE + intf_pub + exe_graph + register + tiling_api + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive +) +set_target_properties(cust_op_proto PROPERTIES OUTPUT_NAME + cust_opsproto_rt2.0 +) +file(GLOB fallback_src ${ASCEND_AUTOGEN_PATH}/fallback_*.cpp) +add_library(cust_optiling SHARED ${ops_srcs}) +if (${fallback_src}) + target_sources(cust_optiling PRIVATE ${fallback_src}) +endif() +target_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB) +target_compile_options(cust_optiling PRIVATE + -fvisibility=hidden +) +if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_optiling PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) +endif() +target_link_libraries(cust_optiling PRIVATE + nnopbase + intf_pub + exe_graph + register + tiling_api + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive +) +set_target_properties(cust_optiling PROPERTIES OUTPUT_NAME + cust_opmaster_rt2.0 +) + +file(GLOB pregen_file "../pregen/aclnn/*") +file(COPY ${pregen_file} DESTINATION ${ASCEND_AUTOGEN_PATH}) +file(GLOB aclnn_src ${ASCEND_AUTOGEN_PATH}/aclnn*.cpp) +file(GLOB aclnn_inc ${ASCEND_AUTOGEN_PATH}/aclnn_*.h) + +if(NOT ASCEND_PACK_SHARED_LIBRARY) + add_library(cust_opapi SHARED ${aclnn_src}) +else() + file(GLOB op_registry ${ASCEND_AUTOGEN_PATH}/custom_op_registry.cpp) + add_library(cust_opapi SHARED ${aclnn_src} ${op_registry}) + target_compile_definitions(cust_opapi PRIVATE ACLNN_WITH_BINARY) 
+endif() + +target_include_directories(cust_opapi PRIVATE $ENV{ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform/) + +if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_opapi PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) +endif() +if(NOT ASCEND_PACK_SHARED_LIBRARY) + target_link_libraries(cust_opapi PRIVATE intf_pub ascendcl nnopbase) +else() + add_library(cust_op_proto_obj OBJECT + $<$:${group_proto_src}> + ${ops_srcs} + ${ASCEND_AUTOGEN_PATH}/op_proto.cc + ) + target_compile_definitions(cust_op_proto_obj PRIVATE OP_PROTO_LIB) + target_compile_options(cust_op_proto_obj PRIVATE + -fvisibility=hidden + ) + if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_op_proto_obj PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) + endif() + target_link_libraries(cust_op_proto_obj PRIVATE + intf_pub + exe_graph + register + tiling_api + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive + ) + add_library(cust_optiling_obj OBJECT ${ops_srcs}) + target_compile_definitions(cust_optiling_obj PRIVATE OP_TILING_LIB) + target_compile_options(cust_optiling_obj PRIVATE + -fvisibility=hidden + ) + if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_optiling_obj PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) + endif() + target_link_libraries(cust_optiling_obj PRIVATE + intf_pub + exe_graph + register + tiling_api + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive + ) + target_compile_options(cust_opapi PRIVATE -DLOG_CPP) + target_include_directories(cust_opapi INTERFACE ${CMAKE_SOURCE_DIR}/build_out/library/) + target_link_libraries(cust_opapi PRIVATE intf_pub ascendcl nnopbase cust_optiling_obj cust_op_proto_obj ascend_opregistry ascend_kernels) + add_dependencies(cust_opapi ascend_opregistry) +endif() + +add_custom_target(optiling_compat ALL + COMMAND ln -sf lib/linux/${CMAKE_SYSTEM_PROCESSOR}/$ + ${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so +) +if(NOT ASCEND_PACK_SHARED_LIBRARY) + install(TARGETS cust_op_proto + LIBRARY DESTINATION packages/vendors/${vendor_name}/op_proto/lib/linux/${CMAKE_SYSTEM_PROCESSOR}) + install(FILES ${ASCEND_AUTOGEN_PATH}/op_proto.h + DESTINATION packages/vendors/${vendor_name}/op_proto/inc) + file(GLOB GROUP_PROTO_HEADERS ${ASCEND_AUTOGEN_PATH}/group_proto/*.h) + if (GROUP_PROTO_HEADERS) + install(FILES ${GROUP_PROTO_HEADERS} + DESTINATION packages/vendors/${vendor_name}/op_proto/inc) + endif() + install(TARGETS cust_optiling + LIBRARY DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling/lib/linux/${CMAKE_SYSTEM_PROCESSOR}) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so + DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling) + install(TARGETS cust_opapi + LIBRARY DESTINATION packages/vendors/${vendor_name}/op_api/lib) + install(FILES ${aclnn_inc} + DESTINATION packages/vendors/${vendor_name}/op_api/include) +else() + file(GLOB group_inc ${ASCEND_AUTOGEN_PATH}/group_proto/*.h) + install(TARGETS cust_opapi + LIBRARY DESTINATION op_api/lib) + install(FILES ${ASCEND_AUTOGEN_PATH}/op_proto.h + DESTINATION op_api/include) + install(FILES ${group_inc} + DESTINATION op_api/include) + install(FILES ${aclnn_inc} + DESTINATION op_api/include) +endif() \ No newline at end of file diff --git a/csrc/custom_ops/kernels/scripts/op_kernel/CMakeLists.txt b/csrc/custom_ops/kernels/scripts/op_kernel/CMakeLists.txt new file mode 100644 index 00000000000..20e88d4bcca --- /dev/null +++ 
b/csrc/custom_ops/kernels/scripts/op_kernel/CMakeLists.txt @@ -0,0 +1,8 @@ +# set custom compile options +if ("${CMAKE_BUILD_TYPE}x" STREQUAL "Debugx") + add_ops_compile_options(ALL OPTIONS -g -O0) +endif() + +add_ops_compile_options(ALL OPTIONS -DASCENDC_DUMP=0 --cce-auto-sync=off) + +add_kernels_compile() \ No newline at end of file diff --git a/csrc/utils/op_host/error_log.h b/csrc/custom_ops/kernels/utils/op_host/error_log.h similarity index 100% rename from csrc/utils/op_host/error_log.h rename to csrc/custom_ops/kernels/utils/op_host/error_log.h diff --git a/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h similarity index 99% rename from csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h index b35e6140183..9ccb5867370 100644 --- a/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h +++ b/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h @@ -12,8 +12,8 @@ #include "kernel_operator.h" #include "kernel_tiling/kernel_tiling.h" -#include "../../../../../../op_kernel/fused_deep_moe_base.h" -#include "../../../../../../op_kernel/fused_deep_moe_tiling.h" +#include "../../../../fused_deep_moe_base.h" +#include "../../../../fused_deep_moe_tiling.h" namespace MoeDistributeCombineImpl { constexpr uint8_t BUFFER_NUM = 2; // multi-buf diff --git a/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h similarity index 99% rename from csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h index cf608bb4083..a2f1febee9b 100644 --- a/csrc/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h +++ b/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h @@ -13,8 +13,8 @@ #include "kernel_operator.h" #include "kernel_tiling/kernel_tiling.h" -#include "../../../../../../op_kernel/fused_deep_moe_base.h" -#include "../../../../../../op_kernel/fused_deep_moe_tiling.h" +#include "../../../../fused_deep_moe_base.h" +#include "../../../../fused_deep_moe_tiling.h" namespace MoeDistributeDispatchImpl { constexpr uint8_t BUFFER_NUM = 2; // 多buf diff --git a/csrc/utils/op_kernel/operator/catlass/act/act.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/act.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/act.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/act.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/arch/arch.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/arch/arch.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/arch/arch.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/arch/arch.hpp diff --git 
a/csrc/utils/op_kernel/operator/catlass/act/arch/cross_core_sync.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/arch/cross_core_sync.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/arch/cross_core_sync.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/arch/cross_core_sync.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/arch/local_tensor_buffer.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/arch/local_tensor_buffer.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/arch/local_tensor_buffer.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/arch/local_tensor_buffer.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/arch/resource.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/arch/resource.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/arch/resource.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/arch/resource.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/coord.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/coord.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/coord.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/coord.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/alignment.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/alignment.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/detail/alignment.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/alignment.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/callback.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/callback.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/detail/callback.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/callback.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/dependent_false.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/dependent_false.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/detail/dependent_false.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/dependent_false.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/macros.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/macros.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/detail/macros.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/macros.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/detail/tag_to_layout.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/tag_to_layout.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/detail/tag_to_layout.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/detail/tag_to_layout.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue.hpp diff 
--git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue_per_token_dequant.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue_per_token_dequant.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue_per_token_dequant.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/block/block_epilogue_per_token_dequant.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/dispatch_policy.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/dispatch_policy.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/dispatch_policy.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/dispatch_policy.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_gm_to_ub.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_gm_to_ub.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_gm_to_ub.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_gm_to_ub.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_ub_to_gm.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_ub_to_gm.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_ub_to_gm.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/copy_ub_to_gm.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_cast.hpp 
b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_cast.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_cast.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_cast.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_copy.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_copy.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_copy.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_copy.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_add.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_add.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_add.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_add.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_mul.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_mul.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_mul.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_mul.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_muls.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_muls.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_muls.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_elemwise_muls.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad_preload_async_with_callback.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad_preload_async_with_callback.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad_preload_async_with_callback.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/block/block_mmad_preload_async_with_callback.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp diff --git 
a/csrc/utils/op_kernel/operator/catlass/act/gemm/dispatch_policy.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/dispatch_policy.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/dispatch_policy.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/dispatch_policy.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/helper.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/helper.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/helper.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/helper.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_l1.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_l1.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_l1.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_l1.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_ub.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_ub.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_ub.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_gm_to_ub.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l0c_to_gm.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l0c_to_gm.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l0c_to_gm.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l0c_to_gm.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0a.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0a.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0a.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0a.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0b.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0b.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0b.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_l1_to_l0b.hpp diff --git 
a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_ub_to_gm.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_ub_to_gm.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/tile/copy_ub_to_gm.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/copy_ub_to_gm.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_copy.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/tile_copy.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_copy.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/tile_copy.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_mmad.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/tile_mmad.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm/tile/tile_mmad.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm/tile/tile_mmad.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemm_coord.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm_coord.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemm_coord.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemm_coord.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/gemv_coord.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemv_coord.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/gemv_coord.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/gemv_coord.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/layout/layout.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/layout/layout.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/layout/layout.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/layout/layout.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/layout/matrix.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/layout/matrix.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/layout/matrix.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/layout/matrix.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/layout/vector.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/layout/vector.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/layout/vector.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/layout/vector.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/act/matrix_coord.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/matrix_coord.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/act/matrix_coord.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/act/matrix_coord.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/tla/int_tuple.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/int_tuple.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/tla/int_tuple.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/int_tuple.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/tla/layout.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/layout.hpp similarity 
index 100% rename from csrc/utils/op_kernel/operator/catlass/tla/layout.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/layout.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/tla/numeric/math.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/numeric/math.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/tla/numeric/math.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/numeric/math.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/tla/tensor.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/tensor.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/tla/tensor.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/tensor.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/tla/tuple.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/tuple.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/tla/tuple.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/tuple.hpp diff --git a/csrc/utils/op_kernel/operator/catlass/tla/type_traits.hpp b/csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/type_traits.hpp similarity index 100% rename from csrc/utils/op_kernel/operator/catlass/tla/type_traits.hpp rename to csrc/custom_ops/kernels/utils/op_kernel/operator/catlass/tla/type_traits.hpp diff --git a/csrc/utils/op_kernel/operator/epilogue/block/block_epilogue.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/block/block_epilogue.h similarity index 100% rename from csrc/utils/op_kernel/operator/epilogue/block/block_epilogue.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/block/block_epilogue.h diff --git a/csrc/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h similarity index 100% rename from csrc/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h diff --git a/csrc/utils/op_kernel/operator/epilogue/dispatch_policy.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/dispatch_policy.h similarity index 100% rename from csrc/utils/op_kernel/operator/epilogue/dispatch_policy.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/dispatch_policy.h diff --git a/csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h similarity index 100% rename from 
csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h diff --git a/csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h similarity index 100% rename from csrc/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h diff --git a/csrc/utils/op_kernel/operator/gemm/block/block_mmad.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/block/block_mmad.h similarity index 100% rename from csrc/utils/op_kernel/operator/gemm/block/block_mmad.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/block/block_mmad.h diff --git a/csrc/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h similarity index 100% rename from csrc/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h diff --git a/csrc/utils/op_kernel/operator/gemm/dispatch_policy.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/dispatch_policy.h similarity index 100% rename from csrc/utils/op_kernel/operator/gemm/dispatch_policy.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/dispatch_policy.h diff --git a/csrc/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h similarity index 99% rename from csrc/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h rename to csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h index 53a4dac75d8..32f42f94255 100644 --- a/csrc/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h +++ b/csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -18,7 +18,7 @@ #include "../../catlass/act/epilogue/tile/tile_swizzle.hpp" #include "../../catlass/act/epilogue/tile/tile_copy.hpp" -#include "../../../../../op_kernel/fused_deep_moe_base.h" +#include "../../../fused_deep_moe_base.h" constexpr uint32_t STATE_OFFSET = 512; constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; diff --git a/csrc/custom_ops/scripts/build.sh b/csrc/custom_ops/scripts/build.sh new file mode 100644 index 00000000000..bf46850d4c7 --- /dev/null +++ b/csrc/custom_ops/scripts/build.sh @@ -0,0 +1,58 @@ +#!/bin/bash +export MODULE_NAME="custom_ops" +export MODULE_SRC_PATH="${SRC_PATH}" +export MODULE_SCRIPTS_PATH="${SCRIPTS_PATH}/" +export MODULE_BUILD_OUT_PATH="${BUILD_OUT_PATH}/${MODULE_NAME}" +IS_EXTRACT=0 +SOC_VERSION="all" +ENABLE_SRC_BUILD=1 + +PrintHelp() { + echo " + ./build.sh custom_ops ... 
+ -x Extract the run package + -c Target SOC VERSION + Support Soc: [ascend910_93, ascend910b4] + -d Enable debug + -r Enable code coverage + " +} + +while getopts "c:xdh" opt; do + case $opt in + c) + SOC_VERSION=$OPTARG + ;; + x) + IS_EXTRACT=1 + ;; + d) + export BUILD_TYPE="Debug" + ;; + h) + PrintHelp + exit 0 + ;; + esac +done + +if [ ! -d "$BUILD_OUT_PATH/${MODULE_NAME}" ]; then + mkdir $BUILD_OUT_PATH/${MODULE_NAME} +fi + +# For now, building the whl package and UTs requires building the CAM operator package and installing it into the environment first +# Skip the operator package build while compiling the whl package and UTs to speed up compilation +if [ $ENABLE_SRC_BUILD -eq 1 ]; then + + if [ ! -d "./build_out/custom_ops/run/" ]; then + mkdir ${MODULE_BUILD_OUT_PATH}/run + fi + if [[ "$SOC_VERSION" == "all" ]]; then + bash $MODULE_SCRIPTS_PATH/compile_ascend_proj.sh $MODULE_SRC_PATH ascend910_93 $IS_EXTRACT $BUILD_TYPE + else + bash $MODULE_SCRIPTS_PATH/compile_ascend_proj.sh $MODULE_SRC_PATH $SOC_VERSION $IS_EXTRACT $BUILD_TYPE + fi + if [ $? -ne 0 ]; then + exit 1 + fi +fi \ No newline at end of file diff --git a/csrc/custom_ops/scripts/compile_ascend_proj.sh b/csrc/custom_ops/scripts/compile_ascend_proj.sh new file mode 100644 index 00000000000..a462ce1f51a --- /dev/null +++ b/csrc/custom_ops/scripts/compile_ascend_proj.sh @@ -0,0 +1,65 @@ +#!/bin/bash +CopyOps() { + local src_dir="$1" # source directory + local dst_dir="$2" # destination directory + + # Make sure op_host and op_kernel exist under the destination directory + mkdir -p "$dst_dir/op_host" "$dst_dir/op_kernel" + + # Iterate over all direct subdirectories of the source directory (including names with spaces) + find "$src_dir" -mindepth 1 -maxdepth 1 -type d -print0 | while IFS= read -r -d '' subdir; do + # Double-check that the subdirectory exists + if [ -d "$subdir" ]; then + # Handle the op_host directory + if [ -d "$subdir/op_host" ]; then + cp -rf "$subdir/op_host/"* "$dst_dir/op_host/" + fi + + # Handle the op_kernel directory + if [ -d "$subdir/op_kernel" ]; then + cp -rf "$subdir/op_kernel/"* "$dst_dir/op_kernel/" + fi + fi + done +} + +# Build the operator project and copy its artifacts to the target location +BuildAscendProj() { + local os_id=$(grep ^ID= /etc/os-release | cut -d= -f2 | tr -d '"') + local arch=$(uname -m) + local soc_version=$2 + local is_extract=$3 + local build_type=$4 + local proj_name="kernels_${soc_version}_proj" + # Override the default operator project name + export OPS_PROJECT_NAME=aclnnInner + # Enter the build directory + cd $1 + + if [ -d "./${proj_name}" ]; then + rm -rf ${proj_name} + fi + echo "msopgen gen -i ./kernels/AddCustom.json -c ai_core-${soc_version} -f pytorch -lan cpp -out ${proj_name}" + msopgen gen -i ./kernels/AddCustom.json -c ai_core-${soc_version} -f pytorch -lan cpp -out ${proj_name} + rm -rf ./${proj_name}/op_host/add_custom* + rm -rf ./${proj_name}/op_kernel/add_custom* + CopyOps "./kernels" "./${proj_name}" + python $SCRIPTS_PATH/custom_ops/set_conf.py ./${proj_name}/CMakePresets.json $build_type True CAM + cp -rf ./kernels/pregen ./${proj_name} + + source $ASCEND_HOME_PATH/bin/setenv.bash + cd ${proj_name} + ./build.sh + # Decide whether to extract the run package based on is_extract + if [ $is_extract -eq 1 ]; then + if [ !
-d "$BUILD_OUT_PATH/custom_ops/extract" ]; then + mkdir -p "$BUILD_OUT_PATH/custom_ops/extract" + fi + mkdir ${BUILD_OUT_PATH}/custom_ops/extract/${soc_version} + build_out/*.run --extract=${BUILD_OUT_PATH}/custom_ops/extract/${soc_version} + else + cp build_out/*.run ${BUILD_OUT_PATH}/custom_ops/run/CAM_${soc_version}_${os_id}_${arch}.run + fi +} + +BuildAscendProj $1 $2 $3 $4 \ No newline at end of file diff --git a/csrc/custom_ops/scripts/set_conf.py b/csrc/custom_ops/scripts/set_conf.py new file mode 100644 index 00000000000..d5cb4f532ec --- /dev/null +++ b/csrc/custom_ops/scripts/set_conf.py @@ -0,0 +1,62 @@ +import json +import sys +import argparse + +def update_json_path(args): + """ + Update configuration items in the JSON file + """ + try: + # read the json file + with open(args.file_path, 'r') as f: + data = json.load(f) + + # Iterate through the first configuration item in configurePresets array (assuming the target is the first one) + configure_preset = data.get('configurePresets', [{}])[0] + cache_variables = configure_preset.get('cacheVariables', {}) + + + # Modify the value of CMAKE_BUILD_TYPE + if 'CMAKE_BUILD_TYPE' in cache_variables: + cache_variables['CMAKE_BUILD_TYPE']['value'] = args.build_type + else: + print("CMAKE_BUILD_TYPE field not found") + sys.exit(1) + + # Modify the value of ENABLE_SOURCE_PACKAGE + if 'ENABLE_SOURCE_PACKAGE' in cache_variables: + cache_variables['ENABLE_SOURCE_PACKAGE']['value'] = args.enable_source + else: + print("ENABLE_SOURCE_PACKAGE field not found") + sys.exit(1) + + # Modify the value of vendor_name + if 'vendor_name' in cache_variables: + cache_variables['vendor_name']['value'] = args.vendor_name + else: + print("vendor_name field not found") + sys.exit(1) + + # write back to JSON file (preserve indentation format) + with open(args.file_path, 'w') as f: + json.dump(data, f, indent=4) + print("Successfully updated parameters") + + except FileNotFoundError: + print(f"File not found: {args.file_path}") + except json.JSONDecodeError: + print(f"JSON format error: {args.file_path}") + except Exception as e: + print(f"An error occurred: {str(e)}") + +if __name__ == "__main__": + # Parse command-line arguments + parser = argparse.ArgumentParser(description="Modify configuration items in CMakePresets.json") + parser.add_argument("file_path", help="Path to the JSON file") + parser.add_argument("build_type", help="Build type (e.g., Debug or Release)") + parser.add_argument("enable_source", help="Enable source package generation (true/false)") + parser.add_argument("vendor_name", help="Specify the custom operator directory name") + + args = parser.parse_args() + + update_json_path(args) \ No newline at end of file diff --git a/csrc/pytorch_npu_helper.hpp b/csrc/pytorch_npu_helper.hpp index 5ce9725d234..d9d32a650b4 100644 --- a/csrc/pytorch_npu_helper.hpp +++ b/csrc/pytorch_npu_helper.hpp @@ -5,8 +5,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -14,13 +14,6 @@ #include #include -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" -#include "torch_npu/csrc/core/npu/NPUStream.h" -#include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/interface/EnvVariables.h" -#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" - #define NPU_NAME_SPACE at_npu::native #define __FILENAME__ (strrchr("/" __FILE__, '/') + 1) @@ -77,19 +70,23 @@ extern thread_local int g_hashOffset; _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED) 
constexpr aclDataType kATenScalarTypeToAclDataTypeTable[static_cast(at::ScalarType::NumOptions) + 1] = { -#define DEFINE_ENUM(_1, n) n, +#define DEFINE_ENUM(_1, n) (n), AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM) #undef DEFINE_ENUM }; #define GET_OP_API_FUNC(apiName) reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName)) -#define MEMCPY_TO_BUF(data_expression, size_expression) \ - if (g_hashOffset + (size_expression) > kHashBufSize) { \ - g_hashOffset = kHashBufMaxSize; \ - return; \ - } \ - memcpy(g_hashBuf + g_hashOffset, data_expression, size_expression); \ +#define MEMCPY_TO_BUF(data_expression, size_expression) \ + if (g_hashOffset + (size_expression) > kHashBufSize) { \ + g_hashOffset = kHashBufMaxSize; \ + return; \ + } \ + int ret = memcpy_s(g_hashBuf + g_hashOffset, kHashBufSize - g_hashOffset, data_expression, size_expression); \ + if (ret != 0) { \ + ASCEND_LOGW("memcpy_s failed, ret = %d\n", ret); \ + return; \ + } \ g_hashOffset += size_expression; inline const char *GetOpApiLibName(void) @@ -141,6 +138,9 @@ inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor) { c10::Scalar expScalar; const at::Tensor *aclInput = &tensor; + if (aclInput == nullptr || aclInput->data_ptr() == nullptr) { + return expScalar; + } if (aclInput->scalar_type() == at::ScalarType::Double) { double value = *(double *)aclInput->data_ptr(); c10::Scalar scalar(value); @@ -208,50 +208,28 @@ inline aclTensor *ConvertType(const at::Tensor &at_tensor) aclDataType acl_data_type = kATenScalarTypeToAclDataTypeTable[static_cast(scalar_data_type)]; TORCH_CHECK(acl_data_type != ACL_DT_UNDEFINED, std::string(c10::toString(scalar_data_type)) + " has not been supported") - c10::SmallVector storageDims; - // if acl_data_type is ACL_STRING, storageDims is empty. auto itemsize = at_tensor.itemsize(); if (itemsize == 0) { AT_ERROR("When ConvertType, tensor item size of cannot be zero."); return nullptr; } - if (acl_data_type != ACL_STRING) { - storageDims.push_back(at_tensor.storage().nbytes() / itemsize); - } - const auto dimNum = at_tensor.sizes().size(); - aclFormat format = ACL_FORMAT_ND; - switch (dimNum) { - case 3: - // 适配matmul_allreduce_add_rmsnorm.py算子入参 - format = ACL_FORMAT_ND; - break; - case 4: - format = ACL_FORMAT_NCHW; - break; - case 5: - format = ACL_FORMAT_NCDHW; - break; - default: - format = ACL_FORMAT_ND + std::vector strides(dimNum, 1); + for (int64_t i = dimNum - 2; i >= 0; i--) { + strides[i] = at_tensor.sizes().data()[i + 1] * strides[i + 1]; } + aclFormat format = ACL_FORMAT_ND; + // Adapt the format for the weight inputs of the fused_deep_moe operator if (acl_data_type == ACL_INT8 && dimNum == 3) { format = ACL_FORMAT_FRACTAL_NZ; } - if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) { - c10::Scalar expScalar = ConvertTensorToScalar(at_tensor); - at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type); - return aclCreateTensor(aclInput.sizes().data(), aclInput.sizes().size(), acl_data_type, - aclInput.strides().data(), aclInput.storage_offset(), format, storageDims.data(), - storageDims.size(), const_cast(aclInput.storage().data())); - } - auto acl_tensor = - aclCreateTensor(at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type, at_tensor.strides().data(), - at_tensor.storage_offset(), format, storageDims.data(), storageDims.size(), + aclCreateTensor(at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type, strides.data(), + 0, format, at_tensor.sizes().data(), at_tensor.sizes().size(), const_cast(at_tensor.storage().data())); + return acl_tensor; } @@ -305,7 +283,8 @@ inline aclIntArray
*ConvertType(const at::IntArrayRef &at_array) return array; } -template inline aclBoolArray *ConvertType(const std::array &value) +template +inline aclBoolArray *ConvertType(const std::array &value) { static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray); if (aclCreateBoolArray == nullptr) { @@ -371,7 +350,8 @@ inline aclDataType ConvertType(const at::ScalarType scalarType) return kATenScalarTypeToAclDataTypeTable[static_cast(scalarType)]; } -template T ConvertType(T value) +template +T ConvertType(T value) { return value; } @@ -384,7 +364,8 @@ auto ConvertToOpApiFunc(const Tuple ¶ms, void *opApiAddr, std::index_sequenc return func; } -template auto ConvertToOpApiFunc(const Tuple ¶ms, void *opApiAddr) +template +auto ConvertToOpApiFunc(const Tuple ¶ms, void *opApiAddr) { static constexpr auto size = std::tuple_size::value; return ConvertToOpApiFunc(params, opApiAddr, std::make_index_sequence{}); @@ -438,44 +419,52 @@ inline void Release(aclTensorList *p) aclDestroyTensorList(p); } -template void Release(T value) +template +void Release(T value) { (void)value; } -template void CallRelease(Tuple t, std::index_sequence) +template +void CallRelease(Tuple t, std::index_sequence) { (void)std::initializer_list{(Release(std::get(t)), 0)...}; } -template void ReleaseConvertTypes(Tuple &t) +template +void ReleaseConvertTypes(Tuple &t) { static constexpr auto size = std::tuple_size::value; CallRelease(t, std::make_index_sequence{}); } -template constexpr auto ConvertTypes(Ts &...args) +template +constexpr auto ConvertTypes(Ts &...args) { return std::make_tuple(ConvertType(args)...); } -template auto call(Function f, Tuple t, std::index_sequence) +template +auto call(Function f, Tuple t, std::index_sequence) { return f(std::get(t)...); } -template auto call(Function f, Tuple t) +template +auto call(Function f, Tuple t) { static constexpr auto size = std::tuple_size::value; return call(f, t, std::make_index_sequence{}); } -template void AddParamToBuf(const std::array &value) +template +void AddParamToBuf(const std::array &value) { MEMCPY_TO_BUF(value.data(), value.size() * sizeof(bool)); } -template void AddParamToBuf(const T &value) +template +void AddParamToBuf(const T &value) { MEMCPY_TO_BUF(&value, sizeof(T)); } @@ -492,7 +481,8 @@ void AddParamToBuf(const at::ScalarType); void AddParamToBuf(const string &); void AddParamToBuf(); -template void AddParamToBuf(const T &arg, Args &...args) +template +void AddParamToBuf(const T &arg, Args &...args) { AddParamToBuf(arg); AddParamToBuf(args...); @@ -503,55 +493,55 @@ typedef int (*InitHugeMemThreadLocal)(void *, bool); typedef void (*UnInitHugeMemThreadLocal)(void *, bool); typedef void (*ReleaseHugeMem)(void *, bool); -#define EXEC_NPU_CMD(aclnn_api, ...) 
\ - do { \ - static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \ - static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \ - static const auto initMemAddr = GetOpApiFuncAddr("InitHugeMemThreadLocal"); \ - static const auto unInitMemAddr = GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \ - static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \ - TORCH_CHECK(getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, #aclnn_api, " or ", \ - #aclnn_api "GetWorkspaceSize", " not in ", GetOpApiLibName(), ", or ", GetOpApiLibName(), \ - "not found."); \ - auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ - uint64_t workspace_size = 0; \ - uint64_t *workspace_size_addr = &workspace_size; \ - aclOpExecutor *executor = nullptr; \ - aclOpExecutor **executor_addr = &executor; \ - InitHugeMemThreadLocal initMemFunc = reinterpret_cast(initMemAddr); \ - UnInitHugeMemThreadLocal unInitMemFunc = reinterpret_cast(unInitMemAddr); \ - if (initMemFunc) { \ - initMemFunc(nullptr, false); \ - } \ - auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \ - static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \ - auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \ - TORCH_CHECK(workspace_status == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \ - void *workspace_addr = nullptr; \ - if (workspace_size != 0) { \ - at::TensorOptions options = at::TensorOptions(torch_npu::utils::get_npu_device_type()); \ - auto workspace_tensor = at::empty({workspace_size}, options.dtype(kByte)); \ - workspace_addr = const_cast(workspace_tensor.storage().data()); \ - } \ - auto acl_call = [converted_params, workspace_addr, workspace_size, acl_stream, executor]() -> int { \ - typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, const aclrtStream); \ - OpApiFunc opApiFunc = reinterpret_cast(opApiFuncAddr); \ - auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \ - TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \ - ReleaseConvertTypes(converted_params); \ - ReleaseHugeMem releaseMemFunc = reinterpret_cast(releaseMemAddr); \ - if (releaseMemFunc) { \ - releaseMemFunc(nullptr, false); \ - } \ - return api_ret; \ - }; \ - at_npu::native::OpCommand cmd; \ - cmd.Name(#aclnn_api); \ - cmd.SetCustomHandler(acl_call); \ - cmd.Run(); \ - if (unInitMemFunc) { \ - unInitMemFunc(nullptr, false); \ - } \ +#define EXEC_NPU_CMD(aclnn_api, ...) 
\ + do { \ + static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \ + static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \ + static const auto initMemAddr = GetOpApiFuncAddr("InitHugeMemThreadLocal"); \ + static const auto unInitMemAddr = GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \ + static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \ + TORCH_CHECK(getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, #aclnn_api, " or ", \ + #aclnn_api "GetWorkspaceSize", " not in ", GetOpApiLibName(), ", or ", GetOpApiLibName(), \ + "not found."); \ + auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ + uint64_t workspace_size = 0; \ + uint64_t *workspace_size_addr = &workspace_size; \ + aclOpExecutor *executor = nullptr; \ + aclOpExecutor **executor_addr = &executor; \ + InitHugeMemThreadLocal initMemFunc = reinterpret_cast(initMemAddr); \ + UnInitHugeMemThreadLocal unInitMemFunc = reinterpret_cast(unInitMemAddr); \ + if (initMemFunc) { \ + initMemFunc(nullptr, false); \ + } \ + auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \ + static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \ + auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \ + TORCH_CHECK(workspace_status == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \ + void *workspace_addr = nullptr; \ + if (workspace_size != 0) { \ + at::TensorOptions options = at::TensorOptions(torch_npu::utils::get_npu_device_type()); \ + auto workspace_tensor = at::empty({static_cast(workspace_size)}, options.dtype(c10::kByte)); \ + workspace_addr = const_cast(workspace_tensor.storage().data()); \ + } \ + auto acl_call = [converted_params, workspace_addr, workspace_size, acl_stream, executor]() -> int { \ + typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, const aclrtStream); \ + OpApiFunc opApiFunc = reinterpret_cast(opApiFuncAddr); \ + auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \ + TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \ + ReleaseConvertTypes(converted_params); \ + ReleaseHugeMem releaseMemFunc = reinterpret_cast(releaseMemAddr); \ + if (releaseMemFunc) { \ + releaseMemFunc(nullptr, false); \ + } \ + return api_ret; \ + }; \ + at_npu::native::OpCommand cmd; \ + cmd.Name(#aclnn_api); \ + cmd.SetCustomHandler(acl_call); \ + cmd.Run(); \ + if (unInitMemFunc) { \ + unInitMemFunc(nullptr, false); \ + } \ } while (false) -#endif // PYTORCH_NPU_HELPER_HPP_ \ No newline at end of file +#endif // PYTORCH_NPU_HELPER_HPP_ diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index bcb3111bf14..d31b4eb9c5c 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -532,7 +532,7 @@ std::tuple fused_deep_moe(const at::Tensor &x, const at: int64_t num_ranks, int64_t rank, int64_t shared_expert_num, int64_t shared_expert_rank_num, int64_t num_experts, int64_t global_bs, - int quant_mode) + int64_t quant_mode) { auto x_shape = x.sizes(); auto experts_shape = expert_ids.sizes(); @@ -547,7 +547,7 @@ std::tuple fused_deep_moe(const at::Tensor &x, const at: EXEC_NPU_CMD(aclnnFusedDeepMoe, // input - x, this->new_topk_idx, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, + x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, static_cast(nullptr), expert_scales_optional, //attr hcom_ep_name, 
num_ranks, rank, num_experts, shared_expert_num, shared_expert_rank_num, quant_mode, @@ -621,10 +621,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " Tensor gmm2_weight, Tensor gmm2_weight_scale," " Tensor expert_scales_optional," " str? hcom_ep_name," - " int num_ranks, int rank," - " int shared_expert_num, int shared_expert_rank_num," - " int num_experts, int global_bs," - " int quant_mode) -> (Tensor output, Tensor ep_recv_count)" + " int64_t num_ranks, int64_t rank," + " int64_t shared_expert_num, int64_t shared_expert_rank_num," + " int64_t num_experts, int64_t global_bs," + " int64_t quant_mode) -> (Tensor output, Tensor ep_recv_count)" ); ops.impl("fused_deep_moe", torch::kPrivateUse1, &vllm_ascend::fused_deep_moe); diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index ab3df89d0ca..0ba450ec3a9 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -78,7 +78,7 @@ std::tuple fused_deep_moe_meta(const at::Tensor &x, cons int64_t num_ranks, int64_t rank, int64_t shared_expert_num, int64_t shared_expert_rank_num, int64_t num_experts, int64_t global_bs, - int quant_mode) + int64_t quant_mode) { auto x_shape = x.sizes(); auto experts_shape = expert_ids.sizes(); From 6deb62dc7f7ee3e6c59e5524a18ea270fa962b17 Mon Sep 17 00:00:00 2001 From: GuoRen868 <1269192170@qq.com> Date: Mon, 10 Nov 2025 19:27:58 +0800 Subject: [PATCH 4/7] build and test --- csrc/torch_binding.cpp | 8 ++++---- setup.py | 25 +++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index d31b4eb9c5c..180044eeb88 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -621,10 +621,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " Tensor gmm2_weight, Tensor gmm2_weight_scale," " Tensor expert_scales_optional," " str? hcom_ep_name," - " int64_t num_ranks, int64_t rank," - " int64_t shared_expert_num, int64_t shared_expert_rank_num," - " int64_t num_experts, int64_t global_bs," - " int64_t quant_mode) -> (Tensor output, Tensor ep_recv_count)" + " int num_ranks, int rank," + " int shared_expert_num, int shared_expert_rank_num," + " int num_experts, int global_bs," + " int quant_mode) -> (Tensor output, Tensor ep_recv_count)" ); ops.impl("fused_deep_moe", torch::kPrivateUse1, &vllm_ascend::fused_deep_moe); diff --git a/setup.py b/setup.py index 5a823e7abc5..2c41bc67459 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ from sysconfig import get_paths from typing import Dict, List -from setuptools import Extension, find_packages, setup +from setuptools import Command, Extension, find_packages, setup from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py from setuptools.command.develop import develop @@ -102,6 +102,24 @@ def run(self): f"Generated _build_info.py with SOC version: {soc_version}") super().run() +class build_and_install_aclnn(Command): + description = "Build and install AclNN by running build_aclnn.sh" + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + try: + print("Running bash build_aclnn.sh ...") + subprocess.check_call(["bash", "csrc/build_aclnn.sh", ROOT_DIR]) + print("build_aclnn.sh executed successfully!") + except subprocess.CalledProcessError as e: + print(f"Error running build_aclnn.sh: {e}") + raise SystemExit(e.returncode) class cmake_build_ext(build_ext): # A dict of extension directories that have been configured.
@@ -290,7 +308,9 @@ def target_name(s: str) -> str: print(f"Copy: {src_path} -> {dst_path}") def run(self): - # First, run the standard build_ext command to compile the extensions + # First, ensure ACLNN custom-ops is built and installed. + # self.run_command("build_aclnn") + # Then, run the standard build_ext command to compile the extensions super().run() @@ -353,6 +373,7 @@ def _read_requirements(filename: str) -> List[str]: cmdclass = { "build_py": custom_build_info, + "build_aclnn": build_and_install_aclnn, "build_ext": cmake_build_ext, "install": custom_install } From bb712ab93e7e9268fa219f01bf47cd986eb927b4 Mon Sep 17 00:00:00 2001 From: GuoRen868 <1269192170@qq.com> Date: Mon, 10 Nov 2025 20:23:29 +0800 Subject: [PATCH 5/7] fixbug --- csrc/torch_binding.cpp | 27 +++++++++++++++------------ csrc/torch_binding_meta.cpp | 10 +++++----- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 180044eeb88..2800302a342 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -527,12 +527,12 @@ std::tuple fused_deep_moe(const at::Tensor &x, const at: const at::Tensor &gmm1_permuted_weight, const at::Tensor &gmm1_permuted_weight_scale, const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale, + const at::Tensor &expert_smooth_scales_optional, const at::Tensor &expert_scales_optional, - c10::optional hcom_ep_name, - int64_t num_ranks, int64_t rank, + c10::string_view hcom_ep_name, + int64_t num_ranks, int64_t rank, int64_t moe_expert_num, int64_t shared_expert_num, int64_t shared_expert_rank_num, - int64_t num_experts, int64_t global_bs, - int64_t quant_mode) + int64_t quant_mode, int64_t global_bs) { auto x_shape = x.sizes(); auto experts_shape = expert_ids.sizes(); @@ -542,15 +542,18 @@ std::tuple fused_deep_moe(const at::Tensor &x, const at: at::Tensor output = at::empty({bs, h}, x.options()); bool is_shared_expert = (rank < shared_expert_rank_num); - int64_t num_local_experts = is_shared_expert ? 1 : num_experts / (num_ranks - shared_expert_rank_num); + int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (num_ranks - shared_expert_rank_num); at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options()); + vector group_ep_chrs(hcom_ep_name.begin(), hcom_ep_name.end()); + group_ep_chrs.push_back('\0'); + char *group_ep_ptr = &group_ep_chrs[0]; EXEC_NPU_CMD(aclnnFusedDeepMoe, // input x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, - gmm2_weight_scale, static_cast(nullptr), expert_scales_optional, + gmm2_weight_scale, expert_smooth_scales_optional, expert_scales_optional, //attr - hcom_ep_name, num_ranks, rank, num_experts, shared_expert_num, shared_expert_rank_num, quant_mode, + group_ep_ptr, num_ranks, rank, moe_expert_num, shared_expert_num, shared_expert_rank_num, quant_mode, global_bs, // output output, ep_recv_count); @@ -619,12 +622,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) "fused_deep_moe(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight," " Tensor gmm1_permuted_weight_scale," " Tensor gmm2_weight, Tensor gmm2_weight_scale," - " Tensor expert_scales_optional," - " str? 
hcom_ep_name," - " int num_ranks, int rank," + " Tensor expert_smooth_scales_optional, Tensor expert_scales_optional," + " str hcom_ep_name," + " int num_ranks, int rank, int moe_expert_num," " int shared_expert_num, int shared_expert_rank_num," - " int num_experts, int global_bs," - " int quant_mode) -> (Tensor output, Tensor ep_recv_count)" + " int quant_mode," + " int global_bs) -> (Tensor output, Tensor ep_recv_count)" ); ops.impl("fused_deep_moe", torch::kPrivateUse1, &vllm_ascend::fused_deep_moe); diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index 0ba450ec3a9..4e76d53a956 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -73,12 +73,12 @@ std::tuple fused_deep_moe_meta(const at::Tensor &x, cons const at::Tensor &gmm1_permuted_weight, const at::Tensor &gmm1_permuted_weight_scale, const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale, + const at::Tensor &expert_smooth_scales_optional, const at::Tensor &expert_scales_optional, - c10::optional hcom_ep_name, - int64_t num_ranks, int64_t rank, + c10::string_view hcom_ep_name, + int64_t num_ranks, int64_t rank, int64_t moe_expert_num, int64_t shared_expert_num, int64_t shared_expert_rank_num, - int64_t num_experts, int64_t global_bs, - int64_t quant_mode) + int64_t quant_mode, int64_t global_bs) { auto x_shape = x.sizes(); auto experts_shape = expert_ids.sizes(); @@ -88,7 +88,7 @@ std::tuple fused_deep_moe_meta(const at::Tensor &x, cons at::Tensor output = at::empty({bs, h}, x.options().device(at::kMeta)); bool is_shared_expert = (rank < shared_expert_rank_num); - int64_t num_local_experts = is_shared_expert ? 1 : num_experts / (num_ranks - shared_expert_rank_num); + int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (num_ranks - shared_expert_rank_num); at::Tensor ep_recv_count = at::empty({num_local_experts * num_ranks}, expert_ids.options().device(at::kMeta)); return {output, ep_recv_count}; From e0b3da3ea735abd7871663cacad3a0c548db6a4c Mon Sep 17 00:00:00 2001 From: GuoRen868 <1269192170@qq.com> Date: Wed, 12 Nov 2025 11:57:21 +0800 Subject: [PATCH 6/7] change name --- csrc/build_aclnn.sh | 4 +-- .../op_host/dispatch_gmm_combine_decode.cpp} | 10 +++--- .../dispatch_gmm_combine_decode_infer.cpp} | 6 ++-- .../dispatch_gmm_combine_decode_tiling.cpp} | 36 +++++++++---------- .../dispatch_gmm_combine_decode.cpp} | 12 +++---- .../op_kernel/dispatch_gmm_combine_decode.h} | 34 +++++++++--------- .../dispatch_gmm_combine_decode_base.h} | 6 ++-- .../dispatch_gmm_combine_decode_tiling.h} | 16 ++++----- .../op_kernel/moe_distribute_base.h | 0 ... 
=> aclnn_dispatch_gmm_combine_decode.cpp} | 12 +++---- ....h => aclnn_dispatch_gmm_combine_decode.h} | 8 ++--- .../op_kernel/a3/cam_moe_distribute_combine.h | 8 ++--- .../a3/cam_moe_distribute_dispatch.h | 8 ++--- ...equant_swiglu_quant_multistage_workspace.h | 6 ++-- csrc/custom_ops/scripts/build.sh | 0 .../custom_ops/scripts/compile_ascend_proj.sh | 4 +-- csrc/pytorch_npu_helper.hpp | 2 +- csrc/torch_binding.cpp | 8 ++--- csrc/torch_binding_meta.cpp | 6 ++-- setup.py | 25 ++----------- 20 files changed, 95 insertions(+), 116 deletions(-) rename csrc/custom_ops/kernels/{fused_deep_moe/op_host/fused_deep_moe.cpp => dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode.cpp} (91%) rename csrc/custom_ops/kernels/{fused_deep_moe/op_host/fused_deep_moe_infer.cpp => dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_infer.cpp} (94%) rename csrc/custom_ops/kernels/{fused_deep_moe/op_host/fused_deep_moe_tiling.cpp => dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp} (91%) rename csrc/custom_ops/kernels/{fused_deep_moe/op_kernel/fused_deep_moe.cpp => dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp} (69%) rename csrc/custom_ops/kernels/{fused_deep_moe/op_kernel/fused_deep_moe.h => dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h} (94%) rename csrc/custom_ops/kernels/{fused_deep_moe/op_kernel/fused_deep_moe_base.h => dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h} (80%) rename csrc/custom_ops/kernels/{fused_deep_moe/op_kernel/fused_deep_moe_tiling.h => dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h} (82%) rename csrc/custom_ops/kernels/{fused_deep_moe => dispatch_gmm_combine_decode}/op_kernel/moe_distribute_base.h (100%) rename csrc/custom_ops/kernels/pregen/aclnn/{aclnn_fused_deep_moe.cpp => aclnn_dispatch_gmm_combine_decode.cpp} (78%) rename csrc/custom_ops/kernels/pregen/aclnn/{aclnn_fused_deep_moe.h => aclnn_dispatch_gmm_combine_decode.h} (74%) mode change 100644 => 100755 csrc/custom_ops/scripts/build.sh diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 02a8946da08..118d55a7668 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -5,5 +5,5 @@ cd custom_ops/ bash build.sh custom_ops -cascend910_93 # install custom ops -# ./output/CANN-custom_ops--linux.x86_64.run -# export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH} +./build_out/custom_ops/run/CANN_ascend910_93_ubuntu_aarch64.run --install-path=/usr/local/Ascend/ascend-toolkit/latest/opp/ +source /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/bin/set_env.bash diff --git a/csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe.cpp b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode.cpp similarity index 91% rename from csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe.cpp rename to csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode.cpp index 50cd95f9ce0..197d6120af3 100644 --- a/csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe.cpp +++ b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode.cpp @@ -1,18 +1,18 @@ /* * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
- * Description: FusedDeepMoe operator definition file + * Description: DispatchGmmCombineDecode operator definition file * Author: WANG Qiankun * Create: 2025-07-19 * Note: - * History: 2025-07-19 create FusedDeepMoe operator definition file + * History: 2025-07-19 create DispatchGmmCombineDecode operator definition file */ #include "register/op_def_registry.h" namespace ops { -class FusedDeepMoe : public OpDef +class DispatchGmmCombineDecode : public OpDef { public: - explicit FusedDeepMoe(const char *name) : OpDef(name) + explicit DispatchGmmCombineDecode(const char *name) : OpDef(name) { this->Input("x") .ParamType(REQUIRED) @@ -78,5 +78,5 @@ class FusedDeepMoe : public OpDef } }; -OP_ADD(FusedDeepMoe); +OP_ADD(DispatchGmmCombineDecode); } // namespace ops diff --git a/csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_infer.cpp b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_infer.cpp similarity index 94% rename from csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_infer.cpp rename to csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_infer.cpp index 1391b054393..07c84ddb260 100644 --- a/csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_infer.cpp +++ b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_infer.cpp @@ -1,10 +1,10 @@ /* * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. - * Description: FusedDeepMoe tiling function implementation file + * Description: DispatchGmmCombineDecode tiling function implementation file * Author: Guo Ren * Create: 2025-07-22 * Note: - * History: 2025-07-13 create FusedDeepMoe infer function file + * History: 2025-07-13 create DispatchGmmCombineDecode infer function file */ #include @@ -89,5 +89,5 @@ static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) return ge::GRAPH_SUCCESS; } -IMPL_OP(FusedDeepMoe).InferShape(InferShape).InferDataType(InferDataType); +IMPL_OP(DispatchGmmCombineDecode).InferShape(InferShape).InferDataType(InferDataType); } // namespace ge diff --git a/csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp similarity index 91% rename from csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp rename to csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp index c52ff4976a5..062bd05dd51 100644 --- a/csrc/custom_ops/kernels/fused_deep_moe/op_host/fused_deep_moe_tiling.cpp +++ b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp @@ -1,10 +1,10 @@ /* * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
- * Description: FusedDeepMoe tiling function implementation file + * Description: DispatchGmmCombineDecode tiling function implementation file * Author: WANG Qiankun * Create: 2025-07-19 * Note: - * History: 2025-07-19 create FusedDeepMoe tiling function implementation file + * History: 2025-07-19 create DispatchGmmCombineDecode tiling function implementation file */ #include #include @@ -13,7 +13,7 @@ #include "error_log.h" #include "graph/utils/type_utils.h" #include "register/op_def_registry.h" -#include "../op_kernel/fused_deep_moe_tiling.h" +#include "../op_kernel/dispatch_gmm_combine_decode_tiling.h" #include "tiling/platform/platform_ascendc.h" #include "tiling/hccl/hccl_tiling.h" @@ -67,7 +67,7 @@ static size_t CeilUp(size_t x, size_t y) } static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName, - FusedDeepMoeTilingData &tilingData) + DispatchGmmCombineDecodeTilingData &tilingData) { uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; @@ -127,7 +127,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char return ge::GRAPH_SUCCESS; } -static ge::graphStatus CheckData(const char *nodeName, FusedDeepMoeTilingData &tilingData) +static ge::graphStatus CheckData(const char *nodeName, DispatchGmmCombineDecodeTilingData &tilingData) { uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs; OP_TILING_CHECK(batchSize < MIN_BATCH_SIZE, OP_LOGE(nodeName, "batchSize(bs) must >= %d.", MIN_BATCH_SIZE), @@ -162,7 +162,7 @@ static ge::graphStatus CheckData(const char *nodeName, FusedDeepMoeTilingData &t } static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, - FusedDeepMoeTilingData &tilingData, std::string &groupEp) + DispatchGmmCombineDecodeTilingData &tilingData, std::string &groupEp) { auto attrs = context->GetAttrs(); OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); @@ -209,10 +209,10 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con return ge::GRAPH_SUCCESS; } -static void SetHcommCfg(const gert::TilingContext *context, FusedDeepMoeTilingData *tiling, const std::string groupEp) +static void SetHcommCfg(const gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tiling, const std::string groupEp) { const char *nodeName = context->GetNodeName(); - OP_LOGD(nodeName, "FusedDeepMoe groupEp = %s", groupEp.c_str()); + OP_LOGD(nodeName, "DispatchGmmCombineDecode groupEp = %s", groupEp.c_str()); uint32_t opType = OP_TYPE_ALL_TO_ALL; std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise"; std::string algConfigAllGatherStr = "AllGather=level0:ring"; @@ -223,7 +223,7 @@ static void SetHcommCfg(const gert::TilingContext *context, FusedDeepMoeTilingDa } static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName, - FusedDeepMoeTilingData &tilingData) + DispatchGmmCombineDecodeTilingData &tilingData) { size_t *workSpaces = context->GetWorkspaceSizes(1); OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); @@ -263,10 +263,10 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no return ge::GRAPH_SUCCESS; } -static ge::graphStatus FusedDeepMoeTilingFuncImpl(gert::TilingContext *context) +static ge::graphStatus 
DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContext *context) { const char *nodeName = context->GetNodeName(); - FusedDeepMoeTilingData *tilingData = context->GetTilingData(); + DispatchGmmCombineDecodeTilingData *tilingData = context->GetTilingData(); OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); std::string groupEp = ""; @@ -312,20 +312,20 @@ static ge::graphStatus FusedDeepMoeTilingFuncImpl(gert::TilingContext *context) return ge::GRAPH_SUCCESS; } -static ge::graphStatus FusedDeepMoeTilingFunc(gert::TilingContext *context) +static ge::graphStatus DispatchGmmCombineDecodeTilingFunc(gert::TilingContext *context) { - ge::graphStatus ret = FusedDeepMoeTilingFuncImpl(context); + ge::graphStatus ret = DispatchGmmCombineDecodeTilingFuncImpl(context); return ret; } -struct FusedDeepMoeCompileInfo {}; -ge::graphStatus TilingParseForFusedDeepMoe(gert::TilingParseContext *context) +struct DispatchGmmCombineDecodeCompileInfo {}; +ge::graphStatus TilingParseForDispatchGmmCombineDecode(gert::TilingParseContext *context) { (void)context; return ge::GRAPH_SUCCESS; } -IMPL_OP_OPTILING(FusedDeepMoe) - .Tiling(FusedDeepMoeTilingFunc) - .TilingParse(TilingParseForFusedDeepMoe); +IMPL_OP_OPTILING(DispatchGmmCombineDecode) + .Tiling(DispatchGmmCombineDecodeTilingFunc) + .TilingParse(TilingParseForDispatchGmmCombineDecode); } // namespace optiling diff --git a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.cpp b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp similarity index 69% rename from csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.cpp rename to csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp index 8d25ddb6d44..89f88e4d6d0 100644 --- a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.cpp +++ b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp @@ -1,16 +1,16 @@ /* * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
- * Description: FusedDeepMoe operator kernel function implementation file + * Description: DispatchGmmCombineDecode operator kernel function implementation file * Author: WANG Qiankun * Create: 2025-07-19 * Note: - * History: 2025-07-19 create FusedDeepMoe operator kernel function implementation file + * History: 2025-07-19 create DispatchGmmCombineDecode operator kernel function implementation file */ -#include "fused_deep_moe.h" +#include "dispatch_gmm_combine_decode.h" #include #include "lib/matmul_intf.h" -extern "C" __global__ __aicore__ void fused_deep_moe( +extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( // input GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, @@ -21,11 +21,11 @@ extern "C" __global__ __aicore__ void fused_deep_moe( { icache_preload(8); // New output recvCount - REGISTER_TILING_DEFAULT(FusedDeepMoeTilingData); + REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData); KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V GET_TILING_DATA(tiling_data, tiling); if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) { - FusedDeepMoe op; + DispatchGmmCombineDecode op; op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data); op.Process(); diff --git a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.h b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h similarity index 94% rename from csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.h rename to csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h index 03ea1800a3b..40f0d35cbca 100644 --- a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe.h +++ b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h @@ -1,13 +1,13 @@ /* * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
- * Description: FusedDeepMoe operator kernel function header file, for a3 + * Description: DispatchGmmCombineDecode operator kernel function header file, for a3 * Author: WANG Qiankun * Create: 2025-07-19 * Note: - * History: 2025-07-19 create FusedDeepMoe operator kernel function header file, for a3 + * History: 2025-07-19 create DispatchGmmCombineDecode operator kernel function header file, for a3 */ -#ifndef FUSED_DEEP_MOE_H -#define FUSED_DEEP_MOE_H +#ifndef DISPATCH_GMM_COMBINE_DECODE_H +#define DISPATCH_GMM_COMBINE_DECODE_H #include "lib/matmul_intf.h" #include @@ -29,8 +29,8 @@ #include "operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h" -#include "fused_deep_moe_tiling.h" -#include "fused_deep_moe_base.h" +#include "dispatch_gmm_combine_decode_tiling.h" +#include "dispatch_gmm_combine_decode_base.h" #define ENABLE_GMM2_COMBINE @@ -235,10 +235,10 @@ ACT_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGr } template -class FusedDeepMoe +class DispatchGmmCombineDecode { public: - __aicore__ inline FusedDeepMoe(){}; + __aicore__ inline DispatchGmmCombineDecode(){}; __aicore__ inline void Init( // input GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, @@ -246,7 +246,7 @@ class FusedDeepMoe // output GM_ADDR output, GM_ADDR outputRecvCount, // system - GM_ADDR workspaceGM, AscendC::TPipe *pipe, const FusedDeepMoeTilingData *tilingData); + GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); __aicore__ inline void Process(); private: @@ -285,18 +285,18 @@ class FusedDeepMoe AscendC::TPipe *tpipe_{nullptr}; __gm__ HcclOpResParam *winContext_{nullptr}; - const FusedDeepMoeTilingData *tilingData_; + const DispatchGmmCombineDecodeTilingData *tilingData_; }; template -__aicore__ inline void FusedDeepMoe::Init( +__aicore__ inline void DispatchGmmCombineDecode::Init( // input GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, // output GM_ADDR output, GM_ADDR outputRecvCount, // system - GM_ADDR workspaceGM, AscendC::TPipe *pipe, const FusedDeepMoeTilingData *tilingData) + GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) { tpipe_ = pipe; blockDim_ = AscendC::GetBlockNum(); @@ -341,15 +341,15 @@ __aicore__ inline void FusedDeepMoe::Init( } template -__aicore__ inline void FusedDeepMoe::Process() +__aicore__ inline void DispatchGmmCombineDecode::Process() { #ifdef ENABLE_GMM2_COMBINE if (g_coreType == AscendC::AIV) { - ((FusedDeepMoeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = get_block_num(); + ((DispatchGmmCombineDecodeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = get_block_num(); if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - ((FusedDeepMoeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = get_block_num(); + ((DispatchGmmCombineDecodeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = get_block_num(); } else { - ((FusedDeepMoeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = + ((DispatchGmmCombineDecodeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = get_block_num() * get_subblockdim(); } } @@ -444,4 +444,4 @@ __aicore__ inline void FusedDeepMoe::Process() layoutOutput, gmWorkspace, &combiner); #endif } -#endif // 
FUSED_DEEP_MOE_H +#endif // DISPATCH_GMM_COMBINE_DECODE_H diff --git a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_base.h b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h similarity index 80% rename from csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_base.h rename to csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h index d09d4894adb..1d59fc58d25 100644 --- a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_base.h +++ b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h @@ -6,12 +6,12 @@ * Note: * History: 2025-07-19 Create a definition file for a distribution group related structure */ -#ifndef FUSED_DEEP_MOE_BASE_H -#define FUSED_DEEP_MOE_BASE_H +#ifndef DISPATCH_GMM_COMBINE_DECODE_BASE_H +#define DISPATCH_GMM_COMBINE_DECODE_BASE_H #include "moe_distribute_base.h" #define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG #define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG -#endif // FUSED_DEEP_MOE_BASE_H +#endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H diff --git a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h similarity index 82% rename from csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h rename to csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h index a4899debceb..55aae88a659 100644 --- a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/fused_deep_moe_tiling.h +++ b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h @@ -1,18 +1,18 @@ /* * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
- * Description: FusedDeepMoe tilingData definition file + * Description: DispatchGmmCombineDecode tilingData definition file * Author: WANG Qiankun * Create: 2025-07-19 * Note: - * History: 2025-07-19 create FusedDeepMoe tilingData definition file + * History: 2025-07-19 create DispatchGmmCombineDecode tilingData definition file */ -#ifndef FUSED_DEEP_MOE_TILING_H -#define FUSED_DEEP_MOE_TILING_H +#ifndef DISPATCH_GMM_COMBINE_DECODE_TILING_H +#define DISPATCH_GMM_COMBINE_DECODE_TILING_H #include "kernel_tiling/kernel_tiling.h" -struct FusedDeepMoeInfo { +struct DispatchGmmCombineDecodeInfo { uint32_t epRankSize; // epRankSize uint32_t epRankId; // epRankId uint32_t moeExpertNum; // moe expert number @@ -31,10 +31,10 @@ struct FusedDeepMoeInfo { uint64_t gmm1HLen; }; -struct FusedDeepMoeTilingData { +struct DispatchGmmCombineDecodeTilingData { Mc2InitTiling mc2InitTiling; Mc2CcTiling mc2CcTiling; - FusedDeepMoeInfo disGmmDeqSwigluQuantGmmDeqComInfo; + DispatchGmmCombineDecodeInfo disGmmDeqSwigluQuantGmmDeqComInfo; }; constexpr uint32_t GM_ALIGN_BYTE = 512; @@ -70,4 +70,4 @@ constexpr uint32_t WORKSPACE_STAGES = 4; constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0); -#endif // FUSED_DEEP_MOE_TILING_H +#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H diff --git a/csrc/custom_ops/kernels/fused_deep_moe/op_kernel/moe_distribute_base.h b/csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/moe_distribute_base.h similarity index 100% rename from csrc/custom_ops/kernels/fused_deep_moe/op_kernel/moe_distribute_base.h rename to csrc/custom_ops/kernels/dispatch_gmm_combine_decode/op_kernel/moe_distribute_base.h diff --git a/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.cpp b/csrc/custom_ops/kernels/pregen/aclnn/aclnn_dispatch_gmm_combine_decode.cpp similarity index 78% rename from csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.cpp rename to csrc/custom_ops/kernels/pregen/aclnn/aclnn_dispatch_gmm_combine_decode.cpp index be88b904544..2a0ea4c02b0 100644 --- a/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.cpp +++ b/csrc/custom_ops/kernels/pregen/aclnn/aclnn_dispatch_gmm_combine_decode.cpp @@ -1,8 +1,8 @@ #include #include "graph/types.h" #include "aclnn/opdev/platform.h" -#include "aclnn_fused_deep_moe.h" -#include "aclnnInner_fused_deep_moe.h" +#include "aclnn_dispatch_gmm_combine_decode.h" +#include "aclnnInner_dispatch_gmm_combine_decode.h" enum NnopbaseHcclServerType { NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, @@ -15,7 +15,7 @@ extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, extern "C" { #endif -aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize( +aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *x, const aclTensor *expertIds, const aclTensor *gmm1PermutedWeight, @@ -37,13 +37,13 @@ aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize( uint64_t *workspaceSize, aclOpExecutor **executor) { - return aclnnInnerFusedDeepMoeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, + return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs, output, epRecvCount, workspaceSize, executor); } -aclnnStatus aclnnFusedDeepMoe( +aclnnStatus aclnnDispatchGmmCombineDecode( void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, @@ -56,7 +56,7 @@ aclnnStatus 
aclnnFusedDeepMoe( NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); } } - return aclnnInnerFusedDeepMoe(workspace, workspaceSize, executor, stream); + return aclnnInnerDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream); } #ifdef __cplusplus diff --git a/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.h b/csrc/custom_ops/kernels/pregen/aclnn/aclnn_dispatch_gmm_combine_decode.h similarity index 74% rename from csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.h rename to csrc/custom_ops/kernels/pregen/aclnn/aclnn_dispatch_gmm_combine_decode.h index 435ec98e50a..6b916d1d0c5 100644 --- a/csrc/custom_ops/kernels/pregen/aclnn/aclnn_fused_deep_moe.h +++ b/csrc/custom_ops/kernels/pregen/aclnn/aclnn_dispatch_gmm_combine_decode.h @@ -1,5 +1,5 @@ -#ifndef FUSED_DEEP_MOE -#define FUSED_DEEP_MOE +#ifndef DISPATCH_GMM_COMBINE_DECODE +#define DISPATCH_GMM_COMBINE_DECODE #include "aclnn/acl_meta.h" @@ -7,7 +7,7 @@ extern "C" { #endif -__attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoeGetWorkspaceSize( +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( const aclTensor *x, const aclTensor *expertIds, const aclTensor *gmm1PermutedWeight, @@ -29,7 +29,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoeGetWorkspace uint64_t *workspaceSize, aclOpExecutor **executor); -__attribute__((visibility("default"))) aclnnStatus aclnnFusedDeepMoe( +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode( void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, diff --git a/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h index 9ccb5867370..1898fcdae47 100644 --- a/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h +++ b/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h @@ -12,8 +12,8 @@ #include "kernel_operator.h" #include "kernel_tiling/kernel_tiling.h" -#include "../../../../fused_deep_moe_base.h" -#include "../../../../fused_deep_moe_tiling.h" +#include "../../../../dispatch_gmm_combine_decode_base.h" +#include "../../../../dispatch_gmm_combine_decode_tiling.h" namespace MoeDistributeCombineImpl { constexpr uint8_t BUFFER_NUM = 2; // multi-buf @@ -62,7 +62,7 @@ class CamMoeDistributeCombine __aicore__ inline CamMoeDistributeCombine(){}; __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, - const FusedDeepMoeTilingData *tilingData); + const DispatchGmmCombineDecodeTilingData *tilingData); __aicore__ inline void Process(); __aicore__ inline void AllToAllSend(); __aicore__ inline void ReducePermute(); @@ -231,7 +231,7 @@ class CamMoeDistributeCombine template __aicore__ inline void CamMoeDistributeCombine::Init( GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, - GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const FusedDeepMoeTilingData *tilingData) + GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) { tpipe_ = pipe; coreIdx_ = GetBlockIdx(); diff --git 
a/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h index a2f1febee9b..ac4140894af 100644 --- a/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h +++ b/csrc/custom_ops/kernels/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h @@ -13,8 +13,8 @@ #include "kernel_operator.h" #include "kernel_tiling/kernel_tiling.h" -#include "../../../../fused_deep_moe_base.h" -#include "../../../../fused_deep_moe_tiling.h" +#include "../../../../dispatch_gmm_combine_decode_base.h" +#include "../../../../dispatch_gmm_combine_decode_tiling.h" namespace MoeDistributeDispatchImpl { constexpr uint8_t BUFFER_NUM = 2; // 多buf @@ -63,7 +63,7 @@ class CamMoeDistributeDispatch __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut, - GM_ADDR workspaceGM, TPipe *pipe, const FusedDeepMoeTilingData *tilingData); + GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); __aicore__ inline void Process(); private: @@ -221,7 +221,7 @@ template __aicore__ inline void CamMoeDistributeDispatch::Init( GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut, - GM_ADDR workspaceGM, TPipe *pipe, const FusedDeepMoeTilingData *tilingData) + GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) { tpipe_ = pipe; aivId_ = GetBlockIdx(); diff --git a/csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h index 32f42f94255..18d30eb79c3 100644 --- a/csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h +++ b/csrc/custom_ops/kernels/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -1,10 +1,10 @@ /* * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
- * Description: FusedDeepMoe operator kernel function implementation file + * Description: DispatchGmmCombineDecode operator kernel function implementation file * Author: WANG Qiankun * Create: 2025-07-19 * Note: - * History: 2025-07-19 create FusedDeepMoe operator kernel function implementation file + * History: 2025-07-19 create DispatchGmmCombineDecode operator kernel function implementation file */ #pragma once @@ -18,7 +18,7 @@ #include "../../catlass/act/epilogue/tile/tile_swizzle.hpp" #include "../../catlass/act/epilogue/tile/tile_copy.hpp" -#include "../../../fused_deep_moe_base.h" +#include "../../../dispatch_gmm_combine_decode_base.h" constexpr uint32_t STATE_OFFSET = 512; constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; diff --git a/csrc/custom_ops/scripts/build.sh b/csrc/custom_ops/scripts/build.sh old mode 100644 new mode 100755 diff --git a/csrc/custom_ops/scripts/compile_ascend_proj.sh b/csrc/custom_ops/scripts/compile_ascend_proj.sh index a462ce1f51a..a3af866dc26 100644 --- a/csrc/custom_ops/scripts/compile_ascend_proj.sh +++ b/csrc/custom_ops/scripts/compile_ascend_proj.sh @@ -44,7 +44,7 @@ BuildAscendProj() { rm -rf ./${proj_name}/op_host/add_custom* rm -rf ./${proj_name}/op_kernel/add_custom* CopyOps "./kernels" "./${proj_name}" - python $SCRIPTS_PATH/custom_ops/set_conf.py ./${proj_name}/CMakePresets.json $build_type True CAM + python $SCRIPTS_PATH/set_conf.py ./${proj_name}/CMakePresets.json $build_type True CAM cp -rf ./kernels/pregen ./${proj_name} source $ASCEND_HOME_PATH/bin/setenv.bash @@ -58,7 +58,7 @@ BuildAscendProj() { mkdir ${BUILD_OUT_PATH}/custom_ops/extract/${soc_version} build_out/*.run --extract=${BUILD_OUT_PATH}/custom_ops/extract/${soc_version} else - cp build_out/*.run ${BUILD_OUT_PATH}/custom_ops/run/CAM_${soc_version}_${os_id}_${arch}.run + cp build_out/*.run ${BUILD_OUT_PATH}/custom_ops/run/CANN_${soc_version}_${os_id}_${arch}.run fi } diff --git a/csrc/pytorch_npu_helper.hpp b/csrc/pytorch_npu_helper.hpp index d9d32a650b4..ea627a94008 100644 --- a/csrc/pytorch_npu_helper.hpp +++ b/csrc/pytorch_npu_helper.hpp @@ -220,7 +220,7 @@ inline aclTensor *ConvertType(const at::Tensor &at_tensor) } aclFormat format = ACL_FORMAT_ND; - // 适配fused_deep_moe算子的weight入参 + // Adapt the weight input of the dispatch_gmm_combine_decode operator if (acl_data_type == ACL_INT8 && dimNum == 3) { format = ACL_FORMAT_FRACTAL_NZ; } diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 2800302a342..add3a8a1e33 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -523,7 +523,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic return y_out; } -std::tuple fused_deep_moe(const at::Tensor &x, const at::Tensor &expert_ids, +std::tuple dispatch_gmm_combine_decode(const at::Tensor &x, const at::Tensor &expert_ids, const at::Tensor &gmm1_permuted_weight, const at::Tensor &gmm1_permuted_weight_scale, const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale, @@ -548,7 +548,7 @@ std::tuple fused_deep_moe(const at::Tensor &x, const at: vector group_ep_chrs(hcom_ep_name.begin(), hcom_ep_name.end()); group_ep_chrs.push_back('\0'); char *group_ep_ptr = &group_ep_chrs[0]; - EXEC_NPU_CMD(aclnnFusedDeepMoe, + EXEC_NPU_CMD(aclnnDispatchGmmCombineDecode, // input x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, expert_smooth_scales_optional, expert_scales_optional, @@ -619,7 +619,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.def( - "fused_deep_moe(Tensor x, Tensor expert_ids, Tensor
gmm1_permuted_weight," + "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight," " Tensor gmm1_permuted_weight_scale," " Tensor gmm2_weight, Tensor gmm2_weight_scale," " Tensor expert_smooth_scales_optional, Tensor expert_scales_optional," @@ -630,5 +630,5 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " int global_bs) -> (Tensor output, Tensor ep_recv_count)" ); - ops.impl("fused_deep_moe", torch::kPrivateUse1, &vllm_ascend::fused_deep_moe); + ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index 4e76d53a956..9dfd7b989d9 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -69,7 +69,7 @@ std::tuple get_masked_input_and_mask_meta( return {masked_input, mask}; } -std::tuple fused_deep_moe_meta(const at::Tensor &x, const at::Tensor &expert_ids, +std::tuple dispatch_gmm_combine_decode_meta(const at::Tensor &x, const at::Tensor &expert_ids, const at::Tensor &gmm1_permuted_weight, const at::Tensor &gmm1_permuted_weight_scale, const at::Tensor &gmm2_weight, const at::Tensor &gmm2_weight_scale, @@ -157,7 +157,7 @@ namespace { ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta); // MLA preprocess ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess); - // Masked fused_deep_moe_meta meta implementation - ops.impl("fused_deep_moe", &vllm_ascend::meta::fused_deep_moe_meta); + // Masked dispatch_gmm_combine_decode_meta meta implementation + ops.impl("dispatch_gmm_combine_decode", &vllm_ascend::meta::dispatch_gmm_combine_decode_meta); } } diff --git a/setup.py b/setup.py index 2c41bc67459..5a823e7abc5 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ from sysconfig import get_paths from typing import Dict, List -from setuptools import Command, Extension, find_packages, setup +from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py from setuptools.command.develop import develop @@ -102,24 +102,6 @@ def run(self): f"Generated _build_info.py with SOC version: {soc_version}") super().run() -class build_and_install_aclnn(Command): - description = "Build and install AclNN by running build_aclnn.sh" - user_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - try: - print("Running bash build_aclnn.sh ...") - subprocess.check_call(["bash", "csrc/build_aclnn.sh", ROOT_DIR]) - print("buid_aclnn.sh executed successfully!") - except subprocess.CalledProcessError as e: - print(f"Error running build_aclnn.sh: {e}") - raise SystemExit(e.returncode) class cmake_build_ext(build_ext): # A dict of extension directories that have been configured. @@ -308,9 +290,7 @@ def target_name(s: str) -> str: print(f"Copy: {src_path} -> {dst_path}") def run(self): - # First, ensure ACLNN custom-ops is built and installed. 
- # self.run_command("build_aclnn") - # Then, run the standard build_ext command to compile the extensions + # First, run the standard build_ext command to compile the extensions super().run() @@ -373,7 +353,6 @@ def _read_requirements(filename: str) -> List[str]: cmdclass = { "build_py": custom_build_info, - "build_aclnn": build_and_install_aclnn, "build_ext": cmake_build_ext, "install": custom_install } From 49eca002f16876641d2a79f7cb634ab9096e704f Mon Sep 17 00:00:00 2001 From: GuoRen868 <1269192170@qq.com> Date: Wed, 12 Nov 2025 17:53:33 +0800 Subject: [PATCH 7/7] uttest --- .../ops/test_fused_deep_moe_accuracy.py | 303 ++++++++++++++++++ 1 file changed, 303 insertions(+) create mode 100644 tests/e2e/nightly/ops/test_fused_deep_moe_accuracy.py diff --git a/tests/e2e/nightly/ops/test_fused_deep_moe_accuracy.py b/tests/e2e/nightly/ops/test_fused_deep_moe_accuracy.py new file mode 100644 index 00000000000..05c9c8e8296 --- /dev/null +++ b/tests/e2e/nightly/ops/test_fused_deep_moe_accuracy.py @@ -0,0 +1,303 @@ +import os +import sys +import numpy as np +import torch +import torch_npu +import torch.distributed as dist +import torch.multiprocessing as mp +from pathlib import Path +from vllm_ascend.utils import enable_custom_op + +torch_npu.npu.config.allow_internal_format = True +use_graph = False +test_bfloat16 = True +enable_dynamic_bs = False +if use_graph: + import torchair + from torchair.configs.compiler_config import CompilerConfig + torch_npu.npu.set_compile_mode(jit_compile=True) + config = CompilerConfig() + npu_backend = torchair.get_npu_backend(compiler_config=config) + +enable_custom_op() + +TP = 1 +print(f"{len(sys.argv)= }, {sys.argv= }\n{use_graph= }, {test_bfloat16= }, {enable_dynamic_bs= }") +assert len(sys.argv) == 7, "expected args: rank_size, share_expert_rank_num, moe_expert_num, bs, name, loop_cnt" +ep_world_size = int(sys.argv[1]) +SHARE_RANK_NUM = int(sys.argv[2]) +MOE_RANK_NUM = ep_world_size - SHARE_RANK_NUM +MOE_EXPERT_NUM = int(sys.argv[3]) +MOE_EXPERT_NUM_PER_RANK = MOE_EXPERT_NUM // MOE_RANK_NUM +RANK_BS = int(sys.argv[4]) +LOG_NAME = str(sys.argv[5]) +loop_times = int(sys.argv[6]) +node_num = 1 + +SHARE_EXPERT_NUM = SHARE_RANK_NUM +DISPATCH_QUANT = True +H = 7168 +K = 8 +GMM1_INPUT = H +GMM1_HIDDEN = 4096 +GMM2_INPUT = GMM1_HIDDEN // 2 +GMM2_HIDDEN = H + +global_rank_id = 0 +ep_hcomm_info = None +ep_hcomm_info_small = None +commArgs = None +tp_hcomm_info = None +device_id = None + +def redirect_output(log_file_path): + log_path = Path(LOG_NAME) / log_file_path + log_path.parent.mkdir(parents=True, exist_ok=True) + f = open(LOG_NAME + "/" + log_file_path, "w") + os.dup2(f.fileno(), sys.stdout.fileno()) + os.dup2(f.fileno(), sys.stderr.fileno()) + return f + +def permute_weight(w: torch.Tensor, tile_n): + *dims, n = w.shape + order = list(range(len(dims))) + [-2, -3, -1] + return w.reshape(*dims, 2, n // tile_n, tile_n // 2).permute(order).reshape(*dims, n).contiguous() + +def output_to_file(rank_id): + # return True + return not (rank_id in [0, SHARE_RANK_NUM]) + +class SmallOps(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, smooth_scales, expert_scales): + outputs = torch_npu.npu_moe_distribute_dispatch_v2( + x = x, + expert_ids = expert_ids, + expert_scales = expert_scales, + group_ep = ep_hcomm_info_small, + ep_world_size = ep_world_size, + ep_rank_id = global_rank_id, + moe_expert_num = MOE_EXPERT_NUM, + group_tp = tp_hcomm_info, +
tp_world_size = 1, + tp_rank_id = 0, + expert_shard_type = 0, + shared_expert_num = 1, + shared_expert_rank_num = SHARE_RANK_NUM, + quant_mode = 2 if DISPATCH_QUANT else 0, + global_bs = RANK_BS * ep_world_size, + expert_token_nums_type = 1, # 0 = prefix sums, 1 = per-expert counts + ) + expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs + output_dtype = torch.bfloat16 if test_bfloat16 else torch.half + + y1_int32 = torch_npu.npu_grouped_matmul( + x=[expand_x], + weight=[gmm1_weight], + split_item=3, + group_list_type=1, # defaults to 0, meaning prefix-sum form + group_type=0, # 0 = group along the m axis + group_list=expert_token_nums, + output_dtype=torch.int32)[0] + y1, y1_scale = torch_npu.npu_dequant_swiglu_quant( + x=y1_int32, + weight_scale=gmm1_weight_scale.to(torch.float32), + activation_scale=dynamic_scales, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=expert_token_nums, + activate_left=True, + quant_mode=1, + ) + y2 = torch_npu.npu_grouped_matmul( + x=[y1], + weight=[gmm2_weight], + scale=[gmm2_weight_scale], + per_token_scale=[y1_scale], + split_item=2, + group_list_type=1, + group_type=0, + group_list=expert_token_nums, + output_dtype=output_dtype)[0] + combine_output = torch_npu.npu_moe_distribute_combine_v2( + expand_x = y2, + expert_ids = expert_ids, + assist_info_for_combine = assist_info_for_combine, + ep_send_counts = ep_send_counts, + expert_scales = expert_scales, + group_ep = ep_hcomm_info_small, + ep_world_size = ep_world_size, + ep_rank_id = global_rank_id, + moe_expert_num = MOE_EXPERT_NUM, + tp_send_counts = tp_send_counts, + expand_scales = expand_scales, + group_tp = tp_hcomm_info, + tp_world_size = 1, + tp_rank_id = 0, + expert_shard_type = 0, + shared_expert_num = 1, + shared_expert_rank_num = SHARE_RANK_NUM, + global_bs = RANK_BS * ep_world_size + ) + return combine_output + +class FusionOp(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, smooth_scales, expert_scales): + # print(f"{expert_scales=} {expert_scales.dtype=}") + output = torch.ops._C_ascend.dispatch_gmm_combine_decode(x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, + gmm2_weight_scale, smooth_scales, expert_scales, ep_hcomm_info, ep_world_size, global_rank_id, MOE_EXPERT_NUM, + 1, SHARE_RANK_NUM, 0, RANK_BS * ep_world_size) + return output + + +def generate_datas(): + if enable_dynamic_bs: + actual_bs = torch.randint(1, RANK_BS, [1]).item() + print(f"rank-{global_rank_id}: {actual_bs=}") + else: + actual_bs = RANK_BS + local_expert_num = 1 if global_rank_id < SHARE_RANK_NUM else MOE_EXPERT_NUM_PER_RANK + x = torch.rand([actual_bs, H]).half() + x = x * 10 - 5 + expert_ids = [i % MOE_EXPERT_NUM for i in range(global_rank_id * RANK_BS * K, global_rank_id * RANK_BS * K + actual_bs * K)] + # expert_ids = [(i + global_rank_id) % MOE_EXPERT_NUM for i in range(RANK_BS * K)] + expert_ids = torch.Tensor(expert_ids).to(torch.int32).view(actual_bs, K) + if global_rank_id < SHARE_RANK_NUM: + gmm1_weight = torch.ones([local_expert_num, GMM1_INPUT, GMM1_HIDDEN]).to(torch.int8) * 4 + gmm2_weight = torch.ones([local_expert_num, GMM2_INPUT, GMM2_HIDDEN]).to(torch.int8) * 4 + gmm1_weight[:,:,::2] = gmm1_weight[:,:,::2] * -1 + gmm2_weight[:,:,::2] = gmm2_weight[:,:,::2] * -1 + gmm1_weight_scale = torch.ones([local_expert_num, GMM1_HIDDEN]) * 0.0015 + gmm2_weight_scale = torch.ones([local_expert_num, GMM2_HIDDEN]) * 0.0015 + else: + gmm1_weight = torch.randint(-16,
16, [local_expert_num, GMM1_INPUT, GMM1_HIDDEN]).to(torch.int8) + gmm2_weight = torch.randint(-16, 16, [local_expert_num, GMM2_INPUT, GMM2_HIDDEN]).to(torch.int8) + gmm1_weight_scale = torch.rand([local_expert_num, GMM1_HIDDEN]) * 0.003 + 0.0015 + gmm2_weight_scale = torch.rand([local_expert_num, GMM2_HIDDEN]) * 0.003 + 0.0015 + expert_scales = torch.rand(actual_bs, K) + if test_bfloat16: + x = x.bfloat16() + gmm1_weight_scale = gmm1_weight_scale.bfloat16() + gmm2_weight_scale = gmm2_weight_scale.bfloat16() + else: + x = x.half() + return x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, None, expert_scales + +def test_small_op(x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, smooth_scales, expert_scales): + small_op = SmallOps().npu() + # if use_graph: + # small_op = torch.compile(small_op, backend=npu_backend) + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, 29) + gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, 29) + for _ in range(1, loop_times + 1): + output = small_op(x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, smooth_scales, expert_scales) + return output + +def test_fused_op(x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, smooth_scales, expert_scales): + fused_op = FusionOp().npu() + if use_graph: + fused_op = torch.compile(fused_op, backend=npu_backend) + gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\ + .view(-1, 2, 32, 64, 7168).transpose(1,2).contiguous()\ + .view(-1, 4096, 7168).transpose(1,2).contiguous() + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, 2) + gmm1_weight.add_(0) + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, 29) + + gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128) + gmm2_weight = torch_npu.npu_format_cast(gmm2_weight.transpose(1, 2).contiguous(), 29) + + if test_bfloat16: + gmm1_weight_scale = gmm1_weight_scale.float() + gmm2_weight_scale = gmm2_weight_scale.float() + + smooth_scales = torch.ones([RANK_BS]).float().npu() if smooth_scales is None else smooth_scales + for _ in range(1, loop_times + 1): + # print(f"iter: {_} / {loop_times}") + output = fused_op(x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, smooth_scales, expert_scales) + torch_npu.npu.synchronize(device_id) + print("fused op run end") + # only the first output tensor is returned + return output[0] + +def test(): + tensor_datas = [data.npu() if data is not None else None for data in generate_datas()] + + small_op_datas = [data.clone().detach() if data is not None else None for data in tensor_datas] + small_op_output = test_small_op(*small_op_datas) + print(f"{small_op_output= }\n {small_op_output.abs().mean()=}, {small_op_output.abs().max()=}") + + fused_op_datas = [data.clone().detach() if data is not None else None for data in tensor_datas] + fused_op_output = test_fused_op(*fused_op_datas) + print(f"{fused_op_output= }\n {fused_op_output.abs().mean()=}, {fused_op_output.abs().max()=}") + + diff = (small_op_output - fused_op_output).abs() + print(f"[info-{global_rank_id}] fused deep moe: {diff.max()= }, {diff.mean()= }") + +def test_diff_data(): + diff_test_time = 10 + for test_time in range(diff_test_time): + tensor_datas = [data.npu() if data is not None else None for data in generate_datas()] + tensor_datas[1] = (tensor_datas[1] + test_time * 3) % MOE_EXPERT_NUM + print(f"{tensor_datas[1]=}") + + small_op_datas = [data.clone().detach() if data is not None else None for data in tensor_datas] + small_op_output =
test_small_op(*small_op_datas) + # print(f"{small_op_output= }\n {small_op_output.abs().mean()=}, {small_op_output.abs().max()=}") + + fused_op_datas = [data.clone().detach() if data is not None else None for data in tensor_datas] + fused_op_output = test_fused_op(*fused_op_datas) + # print(f"{fused_op_output= }\n {fused_op_output.abs().mean()=}, {fused_op_output.abs().max()=}") + + # small_op_datas = [data.clone().detach() if data is not None else None for data in tensor_datas] + # small_op_output = test_small_op(*small_op_datas) + # # print(f"{small_op_output= }\n {small_op_output.abs().mean()=}, {small_op_output.abs().max()=}") + + diff = (small_op_output - fused_op_output).abs() + error_max = diff.max().item() > 1.0 + error_mean = diff.mean().item() > 1.0 + print(f"[info-{global_rank_id}] test:{test_time+1}/{diff_test_time}, {diff.max()= }, {diff.mean()= }, {error_max=}, {error_mean= }") + +def worker(rank, ep_world_size): + if output_to_file(rank): + log_file = redirect_output(f"log_test_accuracy_rank_{rank}.txt") + global global_rank_id, ep_hcomm_info, ep_hcomm_info_small, tp_hcomm_info, device_id + global_rank_id = rank + device_id = rank % 16 + torch_npu.npu.set_device(device_id) + + # 1. Initialize the distributed environment + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" # any free port works + dist.init_process_group(backend="hccl", rank=rank, world_size=ep_world_size) + + print(f"[info-{rank}] start ep comm init...") + ep_ranks_list = list(np.arange(0, ep_world_size)) + print(f"[info-{rank}] ep rank list:", ep_ranks_list) + ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list) + ep_group_small = dist.new_group(backend="hccl", ranks=ep_ranks_list) + tp_group = dist.new_group(backend="hccl", ranks=[rank]) + + ep_hcomm_info = ep_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank) + ep_hcomm_info_small = ep_group_small._get_backend(torch.device("npu")).get_hccl_comm_name(rank) + tp_hcomm_info = tp_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank) + + torch_npu.npu.synchronize(device_id) + print(f"[info-{rank}] ep group: {ep_group}, ep_hcomm_info:{type(ep_hcomm_info)}") + + test() + # test_diff_data() + # 5. Destroy the process group + torch_npu.npu.synchronize(device_id) + dist.destroy_process_group() + if output_to_file(rank): + log_file.close() + +if __name__ == "__main__": + mp.spawn(worker, args=(ep_world_size,), nprocs=ep_world_size, join=True)