7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -6,6 +6,7 @@ option(BUILD_TESTS "build test or not" OFF)

option(BUILD_DEEPEP_MODULE "build deepep" ON)
option(BUILD_KERNELS_MODULE "build kernels" ON)
option(BUILD_CATLASS_MODULE "build catlass ops within kernels" OFF)

set(CMAKE_CXX_STANDARD 17)
#set(CMAKE_VERBOSE_MAKEFILE ON)
@@ -17,6 +18,12 @@ endif()
add_compile_options(-Wno-unused-parameter -Wno-unused-function -Wunused-value -Wcast-align)
add_compile_options(-Wcast-qual -Winvalid-pch -Wwrite-strings -Wsign-compare -Wextra)

if (BUILD_CATLASS_MODULE)
add_compile_definitions(BUILD_CATLASS_MODULE)
set(CATLASS_DIR "${PROJECT_SOURCE_DIR}/3rdparty/catlass") # specify your catlass path here
message(STATUS "[CATLASS] ${CATLASS_DIR}")
endif ()

if (${CMAKE_BUILD_TYPE} MATCHES "RELEASE")
add_compile_options(-O3)
add_compile_options(-fvisibility=hidden -fvisibility-inlines-hidden)
19 changes: 18 additions & 1 deletion csrc/CMakeLists.txt
@@ -18,6 +18,11 @@ FILE(GLOB OP_SRCS
${PROJECT_OP_SRC_BASE}/lora/op_host/sgmv_expand.cpp
${PROJECT_OP_SRC_BASE}/lora/op_host/sgmv_shrink.cpp
)
if(BUILD_CATLASS_MODULE)
list(APPEND OP_SRCS
${PROJECT_OP_SRC_BASE}/catlass/op_host/catlass_matmul_basic.cpp
)
endif()

# set the so name
set(OP_PLUGIN_NAME sgl_kernel_npu)
@@ -34,11 +39,23 @@ ascendc_library(no_workspace_kernel STATIC
${PROJECT_OP_SRC_BASE}/lora/op_kernel/sgmv_shrink_kernel.cpp
)

ascendc_library(workspace_kernel STATIC
# kernel-side source files that require a workspace
set(WORKSPACE_KERNEL_SRCS
${PROJECT_OP_SRC_BASE}/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp
${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp
)
if(BUILD_CATLASS_MODULE)
list(APPEND WORKSPACE_KERNEL_SRCS
${PROJECT_OP_SRC_BASE}/catlass/op_kernel/catlass_matmul_basic_kernel.cpp
)
endif()
ascendc_library(workspace_kernel STATIC ${WORKSPACE_KERNEL_SRCS})
if(BUILD_CATLASS_MODULE)
ascendc_include_directories(workspace_kernel PRIVATE
${CATLASS_DIR}/include
)
endif()

ascendc_compile_definitions(workspace_kernel PRIVATE
-DHAVE_WORKSPACE
78 changes: 78 additions & 0 deletions csrc/catlass/README.md
@@ -0,0 +1,78 @@
# torch.ops.npu.catlass_matmul_basic


## Function Description

`catlass_matmul_basic` is a matrix multiplication kernel implemented with the CATLASS template library.

Refs: [CATLASS](https://gitcode.com/cann/catlass)
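
Functionally the op computes a plain matrix product `C = A @ B` with matching dtypes. A minimal reference sketch of the intended semantics (pure PyTorch, not the NPU kernel itself):

```python
import torch

def catlass_matmul_basic_ref(input_a: torch.Tensor, input_b: torch.Tensor, output_c: torch.Tensor) -> None:
    # Reference semantics only: the real op dispatches the CATLASS kernel on the NPU.
    output_c.copy_(input_a @ input_b)
```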


## Interface Prototype

### Python Binding Definition
```python
import sgl_kernel_npu

torch.ops.npu.catlass_matmul_basic(
input_a: torch.Tensor, # bf16/fp16/fp32, [m, k]
input_b: torch.Tensor, # bf16/fp16/fp32, [k, n]
output_c: torch.Tensor, # bf16/fp16/fp32, [m, n]
format_mode: str = None # string "ND"/"NZ"
) -> None
```

### Kernel Definition
```C++
extern "C" __global__ __aicore__ void catlass_matmul_basic(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC,
GM_ADDR gmWorkspace, GM_ADDR gmTiling)
```

## Parameter Description

| Parameter Name | DataType         | Description                                      |
|:---------------|:-----------------|:-------------------------------------------------|
| `input_a`      | `torch.Tensor`   | input left tensor with shape (m, k)              |
| `input_b`      | `torch.Tensor`   | input right tensor with shape (k, n)             |
| `format_mode`  | `str` (optional) | weight format, `"ND"` or `"NZ"`; default `"ND"`  |


## Output Description

| Parameter Name | DataType       | Description                     |
|:---------------|:---------------|:--------------------------------|
| `output_c`     | `torch.Tensor` | output tensor with shape (m, n) |
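
Note that `output_c` is an out-parameter: the caller allocates a tensor of shape `(m, n)` with a dtype matching the inputs, and the op fills it in place (the schema returns `()`). A minimal call sketch, assuming an NPU device is available; the shapes below are placeholders:

```python
import torch
import torch_npu
import sgl_kernel_npu

a = torch.rand(64, 32, dtype=torch.float16, device="npu")
b = torch.rand(32, 48, dtype=torch.float16, device="npu")
c = torch.empty((64, 48), dtype=torch.float16, device="npu")  # caller-allocated output

torch.ops.npu.catlass_matmul_basic(a, b, c)        # format_mode defaults to "ND"
torch.ops.npu.catlass_matmul_basic(a, b, c, "ND")  # explicit format_mode
```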


## Constraints

`format_mode = "NZ"` is not implemented yet; the host-side check currently accepts only `"ND"`.

## Example

```python
import sgl_kernel_npu
import torch
import torch_npu

device = torch.device("npu:0")

dtypes = [torch.float16, torch.bfloat16, torch.float32]
m, k, n = 128, 256, 256

for dtype in dtypes:
    a = torch.rand(m, k, dtype=dtype, device=device)
    b = torch.rand(k, n, dtype=dtype, device=device)
    res = torch.empty((m, n), dtype=dtype, device=device)

    torch.ops.npu.catlass_matmul_basic(a, b, res)
```
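
To sanity-check the result against the PyTorch reference, the following lines can be appended inside the loop above (the comparison is illustrative and not part of the op):

```python
    # optional: compare against the PyTorch reference result
    ref = (a.float() @ b.float()).to(dtype)
    max_err = (res.float() - ref.float()).abs().max().item()
    print(f"{dtype}: max abs error = {max_err}")
```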
93 changes: 93 additions & 0 deletions csrc/catlass/op_host/catlass_matmul_basic.cpp
@@ -0,0 +1,93 @@
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>

#include "defines.h"
#include "tiling/platform/platform_ascendc.h"
#include "torch_helper.h"
#include "catlass_matmul_tiling.h"
#include "aclrtlaunch_catlass_matmul_basic.h"

namespace sglang {
namespace npu_kernel {

constexpr uint32_t PADDING_BYTE = 32U;

std::map<c10::ScalarType, DataFormatMode> dTypeMap = {{at::ScalarType::Half, DataFormatMode::FP16},
{at::ScalarType::BFloat16, DataFormatMode::BF16},
{at::ScalarType::Float, DataFormatMode::FP32}};

std::unordered_map<c10::string_view, uint16_t> weightFormatMap = {{"ND", WeightFormatMode::WEIGHT_ND},
{"NZ", WeightFormatMode::WEIGHT_NZ}};

template <typename MapType>
inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
const char *mode_name)
{
std::string modeStr(mode_name);
c10::string_view mode_str = mode_opt.value_or(default_mode);
auto it = mode_map.find(mode_str);
// fall back to the default mode when none is provided; unsupported values fail the check below
TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
return it->second;
}

at::Tensor get_tiling(int32_t &m, int32_t &n, int32_t k, int64_t weight_format_mode, int64_t data_format_mode,
uint32_t &blockDim)
{
auto ascendc_platform = platform_ascendc::PlatformAscendCManager::GetInstance();
blockDim = static_cast<uint32_t>(ascendc_platform->GetCoreNumAiv());

// align to 32 bytes
int32_t tiling_size = (sizeof(KernelCatlassMatmulTilingData) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE;
auto tiling_buffer = at::empty({tiling_size}, at::TensorOptions().dtype(at::kByte).device(at::kCPU));

KernelCatlassMatmulTilingData *tiling_data =
reinterpret_cast<KernelCatlassMatmulTilingData *>(tiling_buffer.data_ptr());
tiling_data->m = m;
tiling_data->n = n;
tiling_data->k = k;
tiling_data->weight_format_mode = weight_format_mode;
tiling_data->data_format_mode = data_format_mode;

auto tiling_tensor = TorchNpuHepler::CopyTensorHostToDevice(tiling_buffer);
return tiling_tensor;
}

HOST_API void catlass_matmul_basic(const at::Tensor &input_a, const at::Tensor &input_b, at::Tensor &output_c,
c10::optional<c10::string_view> format_mode)
{
// validate input tensors
at::ScalarType aType = input_a.scalar_type();
at::ScalarType bType = input_b.scalar_type();
at::ScalarType cType = output_c.scalar_type();
TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same");
TORCH_CHECK(
(aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half) || (aType == at::ScalarType::Float),
"tensor type only support half / bf16 / fp32");

auto formatMode = static_cast<WeightFormatMode>(GetModeVal(weightFormatMap, format_mode, "ND", "format_mode"));
TORCH_CHECK(formatMode == WeightFormatMode::WEIGHT_ND, "current ops only support weightFormat ND");

int32_t m = input_a.size(0);
int32_t k = input_a.size(1);
int32_t n = input_b.size(1);
TORCH_CHECK(input_b.size(0) == k, "input k dim shape mismatch");

uint32_t blockDim;
auto tiling_tensor = get_tiling(m, n, k, formatMode, dTypeMap[aType], blockDim);

// launch the kernel function via torch
auto workspace_tensor = at::empty({1}, at::TensorOptions().dtype(at::kByte).device(input_a.options().device()));
EXEC_KERNEL_CMD(catlass_matmul_basic, blockDim, input_a, input_b, output_c, workspace_tensor, tiling_tensor);
}

} // namespace npu_kernel
} // namespace sglang
35 changes: 35 additions & 0 deletions csrc/catlass/op_host/catlass_matmul_tiling.h
@@ -0,0 +1,35 @@
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef KERNEL_CATLASS_MATMUL_TILING_H
#define KERNEL_CATLASS_MATMUL_TILING_H

#include <cstdint>

namespace sglang {
namespace npu_kernel {

typedef enum { WEIGHT_ND = 0, WEIGHT_NZ = 1 } WeightFormatMode;

typedef enum { BF16 = 0, FP16 = 1, FP32 = 2 } DataFormatMode;

struct KernelCatlassMatmulTilingData {
int32_t m;
int32_t n;
int32_t k;

int64_t weight_format_mode = WEIGHT_ND;
int64_t data_format_mode = BF16;
};

} // namespace npu_kernel
} // namespace sglang

#endif // KERNEL_CATLASS_MATMUL_TILING_H
94 changes: 94 additions & 0 deletions csrc/catlass/op_kernel/catlass_matmul_basic_kernel.cpp
@@ -0,0 +1,94 @@
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/* include file of ascendc */
#include "kernel_operator.h"
#include "../op_host/catlass_matmul_tiling.h"
/* include file of catlass */
#include "catlass/gemm/kernel/basic_matmul.hpp"

#include "catlass/arch/arch.hpp"
#include "catlass/catlass.hpp"
#include "catlass/gemm/block/block_mmad.hpp"
#include "catlass/gemm/block/block_swizzle.hpp"
#include "catlass/gemm/dispatch_policy.hpp"
#include "catlass/gemm/gemm_type.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/status.hpp"

using namespace Catlass;

extern "C" __global__ __aicore__ void catlass_matmul_basic(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, GM_ADDR gmWorkspace,
GM_ADDR gmTiling)
{
using ArchTag = Arch::AtlasA2;
using DispatchPolicy = Gemm::MmadAtlasA2Pingpong<true>;
// tile shapes for 2-byte and 4-byte element sizes
using L1TileShape_2B = GemmShape<128, 256, 256>;
using L0TileShape_2B = GemmShape<128, 256, 64>;
using L1TileShape_4B = GemmShape<128, 128, 256>;
using L0TileShape_4B = GemmShape<128, 128, 64>;
using BlockEpilogue = void;
// Swizzle offset is 3 and direction is 0.
using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<3, 0>;

/* init catlass template 1. fp16 no_weight_nz */
using BlockMmad_case1 =
Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape_2B, L0TileShape_2B, Gemm::GemmType<half, layout::RowMajor>,
Gemm::GemmType<half, layout::RowMajor>, Gemm::GemmType<half, layout::RowMajor>>;
using MatmulKernel_fp16_no_nz = Gemm::Kernel::BasicMatmul<BlockMmad_case1, BlockEpilogue, BlockScheduler>;
/* init catlass template 2. bf16 no_weight_nz */
using BlockMmad_case2 =
Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape_2B, L0TileShape_2B, Gemm::GemmType<__bf16, layout::RowMajor>,
Gemm::GemmType<__bf16, layout::RowMajor>, Gemm::GemmType<__bf16, layout::RowMajor>>;
using MatmulKernel_bf16_no_nz = Gemm::Kernel::BasicMatmul<BlockMmad_case2, BlockEpilogue, BlockScheduler>;
/* init catlass template 3. fp32 no_weight_nz */
using BlockMmad_case3 =
Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape_4B, L0TileShape_4B, Gemm::GemmType<float, layout::RowMajor>,
Gemm::GemmType<float, layout::RowMajor>, Gemm::GemmType<float, layout::RowMajor>>;
using MatmulKernel_fp32_no_nz = Gemm::Kernel::BasicMatmul<BlockMmad_case3, BlockEpilogue, BlockScheduler>;

auto tiling_data = reinterpret_cast<__gm__ sglang::npu_kernel::KernelCatlassMatmulTilingData *>(gmTiling);
uint32_t m = tiling_data->m;
uint32_t n = tiling_data->n;
uint32_t k = tiling_data->k;

layout::RowMajor layoutA{m, k};
layout::RowMajor layoutB{k, n};
layout::RowMajor layoutC{m, n};

/* init catlass instance and run */
GemmCoord problemShape{m, n, k};
if (tiling_data->data_format_mode == sglang::npu_kernel::DataFormatMode::BF16) {
MatmulKernel_bf16_no_nz::Arguments arguments{problemShape, gmA, gmB, gmC};

typename MatmulKernel_bf16_no_nz::Params params{problemShape, gmA, layoutA, gmB, layoutB, gmC, layoutC};

MatmulKernel_bf16_no_nz matmul_kernel;
matmul_kernel(params);
} else if (tiling_data->data_format_mode == sglang::npu_kernel::DataFormatMode::FP16) {
MatmulKernel_fp16_no_nz::Arguments arguments{problemShape, gmA, gmB, gmC};

typename MatmulKernel_fp16_no_nz::Params params{problemShape, gmA, layoutA, gmB, layoutB, gmC, layoutC};

MatmulKernel_fp16_no_nz matmul_kernel;
matmul_kernel(params);
} else if (tiling_data->data_format_mode == sglang::npu_kernel::DataFormatMode::FP32) {
MatmulKernel_fp32_no_nz::Arguments arguments{problemShape, gmA, gmB, gmC};

typename MatmulKernel_fp32_no_nz::Params params{problemShape, gmA, layoutA, gmB, layoutB, gmC, layoutC};

MatmulKernel_fp32_no_nz matmul_kernel;
matmul_kernel(params);
} else {
// TODO: use device error check process
return;
}
}
8 changes: 8 additions & 0 deletions csrc/pytorch_extensions.cpp
@@ -78,6 +78,10 @@ TORCH_LIBRARY_FRAGMENT(npu, m)

m.def(
"sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");

#ifdef BUILD_CATLASS_MODULE
m.def("catlass_matmul_basic(Tensor tensor_a, Tensor tensor_b, Tensor(a!) tensor_c, str? format_mode=None) -> ()");
#endif
}
} // namespace

@@ -109,5 +113,9 @@ TORCH_LIBRARY_IMPL(npu, PrivateUse1, m)
m.impl("sgmv_expand", TORCH_FN(sglang::npu_kernel::sgmv_expand));

m.impl("sgmv_shrink", TORCH_FN(sglang::npu_kernel::sgmv_shrink));

#ifdef BUILD_CATLASS_MODULE
m.impl("catlass_matmul_basic", TORCH_FN(sglang::npu_kernel::catlass_matmul_basic));
#endif
}
} // namespace
5 changes: 5 additions & 0 deletions include/sgl_kenel_npu_ops.h
@@ -89,6 +89,11 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight,
void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices,
at::Tensor &seq_len, at::Tensor &y, double scale);

#ifdef BUILD_CATLASS_MODULE
void catlass_matmul_basic(const at::Tensor &tensor_a,
const at::Tensor &tensor_b, at::Tensor &tensor_c,
c10::optional<c10::string_view> format_mode);
#endif
} // namespace npu_kernel

} // namespace sglang