diff --git a/CMakeLists.txt b/CMakeLists.txt
index 36d4a9a2..df301608 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -6,6 +6,7 @@ option(BUILD_TESTS "build test or not" OFF)
 option(BUILD_DEEPEP_MODULE "build deepep" ON)
 option(BUILD_KERNELS_MODULE "build kernels" ON)
+option(BUILD_CATLASS_MODULE "build catlass ops within kernels" OFF)
 
 set(CMAKE_CXX_STANDARD 17)
 #set(CMAKE_VERBOSE_MAKEFILE ON)
@@ -17,6 +18,12 @@ endif()
 add_compile_options(-hno-unused-parameter -lno-unused-function -Wunused-value -Wcast-align)
 add_compile_options(-Wcast-qual -Winvalid-pch -Wwrite-strings -Wsign-compare -Wextra)
 
+if (BUILD_CATLASS_MODULE)
+    add_compile_definitions(BUILD_CATLASS_MODULE)
+    set(CATLASS_DIR "${PROJECT_SOURCE_DIR}/3rdparty/catlass") # specify your catlass path here
+    message(STATUS "[CATLASS] ${CATLASS_DIR}")
+endif ()
+
 if (${CMAKE_BUILD_TYPE} MATCHES "RELEASE")
     add_compile_options(-O3)
     add_compile_options(-fvisibility=hidden -fvisibility-inlines-hidden)
diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index 9646dc70..75d30a5a 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -18,6 +18,11 @@ FILE(GLOB OP_SRCS
     ${PROJECT_OP_SRC_BASE}/lora/op_host/sgmv_expand.cpp
     ${PROJECT_OP_SRC_BASE}/lora/op_host/sgmv_shrink.cpp
 )
+if(BUILD_CATLASS_MODULE)
+    list(APPEND OP_SRCS
+        ${PROJECT_OP_SRC_BASE}/catlass/op_host/catlass_matmul_basic.cpp
+    )
+endif()
 
 # set the so name
 set(OP_PLUGIN_NAME sgl_kernel_npu)
@@ -34,11 +39,23 @@ ascendc_library(no_workspace_kernel STATIC
     ${PROJECT_OP_SRC_BASE}/lora/op_kernel/sgmv_shrink_kernel.cpp
 )
 
-ascendc_library(workspace_kernel STATIC
+# kernel side files with workspace
+set(WORKSPACE_KERNEL_SRCS
     ${PROJECT_OP_SRC_BASE}/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
     ${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp
     ${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp
 )
+if(BUILD_CATLASS_MODULE)
+    list(APPEND WORKSPACE_KERNEL_SRCS
+        ${PROJECT_OP_SRC_BASE}/catlass/op_kernel/catlass_matmul_basic_kernel.cpp
+    )
+endif()
+ascendc_library(workspace_kernel STATIC ${WORKSPACE_KERNEL_SRCS})
+if(BUILD_CATLASS_MODULE)
+    ascendc_include_directories(workspace_kernel PRIVATE
+        ${CATLASS_DIR}/include
+    )
+endif()
 
 ascendc_compile_definitions(workspace_kernel PRIVATE
     -DHAVE_WORKSPACE
diff --git a/csrc/catlass/README.md b/csrc/catlass/README.md
new file mode 100644
index 00000000..b2e76aad
--- /dev/null
+++ b/csrc/catlass/README.md
@@ -0,0 +1,78 @@
+# torch.ops.npu.catlass_matmul_basic
+
+
+## Function Description | 功能描述
+
+### English:
+`catlass_matmul_basic` is a matrix multiplication kernel implemented with the CATLASS template library.
+
+### 中文:
+这是调用catlass模板库实现的矩阵乘法运算 `catlass_matmul_basic` 内核方法
+
+参考/Refs: [CATLASS](https://gitcode.com/cann/catlass)
+
+
+## Interface Prototype | 接口原型
+
+### Python Binding Definition
+```python
+import sgl_kernel_npu
+
+torch.ops.npu.catlass_matmul_basic(
+    input_a: torch.Tensor,   # bf16/fp16/fp32, [m, k]
+    input_b: torch.Tensor,   # bf16/fp16/fp32, [k, n]
+    output_c: torch.Tensor,  # bf16/fp16/fp32, [m, n]
+    format_mode: str = None  # string "ND"/"NZ"
+) -> None
+```
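+
+For example, a minimal call that passes the optional `format_mode` argument explicitly (a sketch;
+shapes are illustrative and only `"ND"` is currently implemented, see Constraints below):
+
+```python
+import sgl_kernel_npu
+import torch
+import torch_npu
+
+a = torch.randn(128, 256, dtype=torch.float16, device="npu")
+b = torch.randn(256, 512, dtype=torch.float16, device="npu")
+c = torch.empty((128, 512), dtype=torch.float16, device="npu")
+
+# format_mode defaults to "ND"; "NZ" is not implemented yet
+torch.ops.npu.catlass_matmul_basic(a, b, c, "ND")
+```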
+
+### Kernel Definition | 核函数定义
+```C++
+extern "C" __global__ __aicore__ void catlass_matmul_basic(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC,
+                                                           GM_ADDR gmWorkspace, GM_ADDR gmTiling)
+```
+
+## Parameter Description | 参数说明
+
+| Parameter Name (参数名称) | DataType (数据类型) | Description | 说明 |
+|:----------------------|:----------------|:-------------------------------------|:--------------|
+| `input_a` | `torch.Tensor` | input left tensor with shape (m, k) | 左输入矩阵,(m,k)大小 |
+| `input_b` | `torch.Tensor` | input right tensor with shape (k, n) | 右输入矩阵,(k,n)大小 |
+| `format_mode` | `[optional]string` | weight format ND/NZ, default ND | 权重格式ND/NZ, 默认为 ND |
+
+
+## Output Description | 输出说明
+
+| Parameter Name (参数名称) | DataType (数据类型) | Description | 说明 |
+|:-----------------------|:----------------|:--------------------------------|:--------------|
+| `output_c` | `torch.Tensor` | output tensor with shape (m, n) | 输出矩阵,(m, n)大小 |
+
+
+## Constraints | 约束说明
+
+### English:
+`format_mode = "NZ"` is not implemented
+
+### 中文:
+`format_mode = "NZ"` 暂未实现
+
+## Example | 调用示例
+
+```python
+import sgl_kernel_npu
+import torch
+import torch_npu
+
+dtypes = [torch.float16, torch.bfloat16, torch.float32]
+m, k, n = 128, 256, 256
+
+for dtype in dtypes:
+    a = torch.rand(m, k, dtype=dtype, device="npu")
+    b = torch.rand(k, n, dtype=dtype, device="npu")
+    res = torch.empty((m, n), dtype=dtype, device="npu")
+
+    torch.ops.npu.catlass_matmul_basic(a, b, res)
+```
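+
+## Verification | 结果校验
+
+A quick numerical sanity check against the PyTorch reference, mirroring the unit test shipped with this op
+(a minimal sketch; the tolerances below are illustrative, the test uses dtype-dependent thresholds):
+
+```python
+import sgl_kernel_npu
+import torch
+import torch_npu
+
+m, k, n = 128, 512, 128
+a = torch.randn(m, k, dtype=torch.float16, device="npu")
+b = torch.randn(k, n, dtype=torch.float16, device="npu")
+c = torch.empty((m, n), dtype=torch.float16, device="npu")
+
+torch.ops.npu.catlass_matmul_basic(a, b, c)
+ref = torch.matmul(a, b)
+
+# compare in fp32 on the host so the check is independent of NPU accumulation precision
+print(torch.allclose(c.float().cpu(), ref.float().cpu(), rtol=1e-2, atol=1e-2))
+```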
diff --git a/csrc/catlass/op_host/catlass_matmul_basic.cpp b/csrc/catlass/op_host/catlass_matmul_basic.cpp
new file mode 100644
index 00000000..e364c736
--- /dev/null
+++ b/csrc/catlass/op_host/catlass_matmul_basic.cpp
@@ -0,0 +1,93 @@
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <map>
+#include <unordered_map>
+
+#include "defines.h"
+#include "tiling/platform/platform_ascendc.h"
+#include "torch_helper.h"
+#include "catlass_matmul_tiling.h"
+#include "aclrtlaunch_catlass_matmul_basic.h"
+
+namespace sglang {
+namespace npu_kernel {
+
+constexpr uint32_t PADDING_BYTE = 32U;
+
+std::map<at::ScalarType, DataFormatMode> dTypeMap = {{at::ScalarType::Half, DataFormatMode::FP16},
+                                                     {at::ScalarType::BFloat16, DataFormatMode::BF16},
+                                                     {at::ScalarType::Float, DataFormatMode::FP32}};
+
+std::unordered_map<c10::string_view, WeightFormatMode> weightFormatMap = {{"ND", WeightFormatMode::WEIGHT_ND},
+                                                                          {"NZ", WeightFormatMode::WEIGHT_NZ}};
+
+template <typename MapType>
+inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
+                      const char *mode_name)
+{
+    std::string modeStr(mode_name);
+    c10::string_view mode_str = mode_opt.value_or(default_mode);
+    auto it = mode_map.find(mode_str);
+    // reject unsupported mode values
+    TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
+    return it->second;
+}
+
+at::Tensor get_tiling(int32_t &m, int32_t &n, int32_t k, int64_t weight_format_mode, int64_t data_format_mode,
+                      uint32_t &blockDim)
+{
+    auto ascendc_platform = platform_ascendc::PlatformAscendCManager::GetInstance();
+    blockDim = static_cast<uint32_t>(ascendc_platform->GetCoreNumAiv());
+
+    // align to 32 bytes
+    int32_t tiling_size = (sizeof(KernelCatlassMatmulTilingData) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE;
+    auto tiling_buffer = at::empty({tiling_size}, at::TensorOptions().dtype(at::kByte).device(at::kCPU));
+
+    // fill the tiling struct on the host, then copy the buffer to device memory
+    KernelCatlassMatmulTilingData *tiling_data =
+        reinterpret_cast<KernelCatlassMatmulTilingData *>(tiling_buffer.data_ptr());
+    tiling_data->m = m;
+    tiling_data->n = n;
+    tiling_data->k = k;
+    tiling_data->weight_format_mode = weight_format_mode;
+    tiling_data->data_format_mode = data_format_mode;
+
+    auto tiling_tensor = TorchNpuHepler::CopyTensorHostToDevice(tiling_buffer);
+    return tiling_tensor;
+}
+
+HOST_API void catlass_matmul_basic(const at::Tensor &input_a, const at::Tensor &input_b, at::Tensor &output_c,
+                                   c10::optional<c10::string_view> format_mode)
+{
+    // ops valid check
+    at::ScalarType aType = input_a.scalar_type();
+    at::ScalarType bType = input_b.scalar_type();
+    at::ScalarType cType = output_c.scalar_type();
+    TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same");
+    TORCH_CHECK(
+        (aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half) || (aType == at::ScalarType::Float),
+        "tensor type only support half / bf16 / fp32");
+
+    auto formatMode = static_cast<WeightFormatMode>(GetModeVal(weightFormatMap, format_mode, "ND", "format_mode"));
+    TORCH_CHECK(formatMode == WeightFormatMode::WEIGHT_ND, "current ops only support weightFormat ND");
+
+    int32_t m = input_a.size(0);
+    int32_t k = input_a.size(1);
+    int32_t n = input_b.size(1);
+    TORCH_CHECK(input_b.size(0) == k, "input k dim shape mismatch");
+
+    uint32_t blockDim;
+    auto tiling_tensor = get_tiling(m, n, k, formatMode, dTypeMap[aType], blockDim);
+
+    // launch the kernel function via torch
+    auto workspace_tensor = at::empty({1}, at::TensorOptions().dtype(at::kByte).device(input_a.options().device()));
+    EXEC_KERNEL_CMD(catlass_matmul_basic, blockDim, input_a, input_b, output_c, workspace_tensor, tiling_tensor);
+}
+
+} // namespace npu_kernel
+} // namespace sglang
diff --git a/csrc/catlass/op_host/catlass_matmul_tiling.h b/csrc/catlass/op_host/catlass_matmul_tiling.h
new file mode 100644
index 00000000..1a8b9691
--- /dev/null
+++ b/csrc/catlass/op_host/catlass_matmul_tiling.h
@@ -0,0 +1,35 @@
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KERNEL_CATLASS_MATMUL_TILING_H
+#define KERNEL_CATLASS_MATMUL_TILING_H
+
+#include <cstdint>
+
+namespace sglang {
+namespace npu_kernel {
+
+typedef enum { WEIGHT_ND = 0, WEIGHT_NZ = 1 } WeightFormatMode;
+
+typedef enum { BF16 = 0, FP16 = 1, FP32 = 2 } DataFormatMode;
+
+struct KernelCatlassMatmulTilingData {
+    int32_t m;
+    int32_t n;
+    int32_t k;
+
+    int64_t weight_format_mode = WEIGHT_ND;
+    int64_t data_format_mode = BF16;
+};
+
+} // namespace npu_kernel
+} // namespace sglang
+
+#endif // KERNEL_CATLASS_MATMUL_TILING_H
diff --git a/csrc/catlass/op_kernel/catlass_matmul_basic_kernel.cpp b/csrc/catlass/op_kernel/catlass_matmul_basic_kernel.cpp
new file mode 100644
index 00000000..34cada81
--- /dev/null
+++ b/csrc/catlass/op_kernel/catlass_matmul_basic_kernel.cpp
@@ -0,0 +1,94 @@
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/* include file of ascendc */
+#include "kernel_operator.h"
+#include "../op_host/catlass_matmul_tiling.h"
+/* include file of catlass */
+#include "catlass/gemm/kernel/basic_matmul.hpp"
+
+#include "catlass/arch/arch.hpp"
+#include "catlass/catlass.hpp"
+#include "catlass/gemm/block/block_mmad.hpp"
+#include "catlass/gemm/block/block_swizzle.hpp"
+#include "catlass/gemm/dispatch_policy.hpp"
+#include "catlass/gemm/gemm_type.hpp"
+#include "catlass/layout/layout.hpp"
+#include "catlass/status.hpp"
+
+using namespace Catlass;
+
+extern "C" __global__ __aicore__ void catlass_matmul_basic(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, GM_ADDR gmWorkspace,
+                                                           GM_ADDR gmTiling)
+{
+    using ArchTag = Arch::AtlasA2;
+    using DispatchPolicy = Gemm::MmadAtlasA2Pingpong<true>;
+    // tile shape for different element size
+    using L1TileShape_2B = GemmShape<128, 256, 256>;
+    using L0TileShape_2B = GemmShape<128, 256, 64>;
+    using L1TileShape_4B = GemmShape<128, 128, 256>;
+    using L0TileShape_4B = GemmShape<128, 128, 64>;
+    using BlockEpilogue = void;
+    // Swizzle offset is 3 and direction is 0.
+    using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<3, 0>;
+
+    /* init catlass template 1. fp16 no_weight_nz */
+    using BlockMmad_case1 =
+        Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape_2B, L0TileShape_2B, Gemm::GemmType<half, layout::RowMajor>,
+                               Gemm::GemmType<half, layout::RowMajor>, Gemm::GemmType<half, layout::RowMajor>>;
+    using MatmulKernel_fp16_no_nz = Gemm::Kernel::BasicMatmul<BlockMmad_case1, BlockEpilogue, BlockScheduler>;
+    /* init catlass template 2. bf16 no_weight_nz */
+    using BlockMmad_case2 =
+        Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape_2B, L0TileShape_2B, Gemm::GemmType<__bf16, layout::RowMajor>,
+                               Gemm::GemmType<__bf16, layout::RowMajor>, Gemm::GemmType<__bf16, layout::RowMajor>>;
+    using MatmulKernel_bf16_no_nz = Gemm::Kernel::BasicMatmul<BlockMmad_case2, BlockEpilogue, BlockScheduler>;
+    /* init catlass template 3. fp32 no_weight_nz */
+    using BlockMmad_case3 =
+        Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape_4B, L0TileShape_4B, Gemm::GemmType<float, layout::RowMajor>,
+                               Gemm::GemmType<float, layout::RowMajor>, Gemm::GemmType<float, layout::RowMajor>>;
+    using MatmulKernel_fp32_no_nz = Gemm::Kernel::BasicMatmul<BlockMmad_case3, BlockEpilogue, BlockScheduler>;
+
+    auto tiling_data = reinterpret_cast<__gm__ sglang::npu_kernel::KernelCatlassMatmulTilingData *>(gmTiling);
+    uint32_t m = tiling_data->m;
+    uint32_t n = tiling_data->n;
+    uint32_t k = tiling_data->k;
+
+    layout::RowMajor layoutA{m, k};
+    layout::RowMajor layoutB{k, n};
+    layout::RowMajor layoutC{m, n};
+
+    /* init catlass instance and run: pick the template instantiation matching the tiling's data type */
+    GemmCoord problemShape{m, n, k};
+    if (tiling_data->data_format_mode == sglang::npu_kernel::DataFormatMode::BF16) {
+        MatmulKernel_bf16_no_nz::Arguments arguments{problemShape, gmA, gmB, gmC};
+
+        typename MatmulKernel_bf16_no_nz::Params params{problemShape, gmA, layoutA, gmB, layoutB, gmC, layoutC};
+
+        MatmulKernel_bf16_no_nz matmul_kernel;
+        matmul_kernel(params);
+    } else if (tiling_data->data_format_mode == sglang::npu_kernel::DataFormatMode::FP16) {
+        MatmulKernel_fp16_no_nz::Arguments arguments{problemShape, gmA, gmB, gmC};
+
+        typename MatmulKernel_fp16_no_nz::Params params{problemShape, gmA, layoutA, gmB, layoutB, gmC, layoutC};
+
+        MatmulKernel_fp16_no_nz matmul_kernel;
+        matmul_kernel(params);
+    } else if (tiling_data->data_format_mode == sglang::npu_kernel::DataFormatMode::FP32) {
+        MatmulKernel_fp32_no_nz::Arguments arguments{problemShape, gmA, gmB, gmC};
+
+        typename MatmulKernel_fp32_no_nz::Params params{problemShape, gmA, layoutA, gmB, layoutB, gmC, layoutC};
+
+        MatmulKernel_fp32_no_nz matmul_kernel;
+        matmul_kernel(params);
+    } else {
+        // TODO: use device error check process
+        return;
+    }
+}
diff --git a/csrc/pytorch_extensions.cpp b/csrc/pytorch_extensions.cpp
index 1411a907..fdd7e413 100644
--- a/csrc/pytorch_extensions.cpp
+++ b/csrc/pytorch_extensions.cpp
@@ -78,6 +78,10 @@ TORCH_LIBRARY_FRAGMENT(npu, m)
     m.def(
         "sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
+
+#ifdef BUILD_CATLASS_MODULE
+    m.def("catlass_matmul_basic(Tensor tensor_a, Tensor tensor_b, Tensor(a!) tensor_c, str? format_mode=None) -> ()");
+#endif
 }
 }  // namespace
 
@@ -109,5 +113,9 @@ TORCH_LIBRARY_IMPL(npu, PrivateUse1, m)
     m.impl("sgmv_expand", TORCH_FN(sglang::npu_kernel::sgmv_expand));
     m.impl("sgmv_shrink", TORCH_FN(sglang::npu_kernel::sgmv_shrink));
+
+#ifdef BUILD_CATLASS_MODULE
+    m.impl("catlass_matmul_basic", TORCH_FN(sglang::npu_kernel::catlass_matmul_basic));
+#endif
 }
 }  // namespace
diff --git a/include/sgl_kenel_npu_ops.h b/include/sgl_kenel_npu_ops.h
index 753fa100..9c75e3bf 100644
--- a/include/sgl_kenel_npu_ops.h
+++ b/include/sgl_kenel_npu_ops.h
@@ -89,6 +89,11 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight,
 void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len, at::Tensor &y,
                  double scale);
 
+#ifdef BUILD_CATLASS_MODULE
+void catlass_matmul_basic(const at::Tensor &tensor_a,
+                          const at::Tensor &tensor_b, at::Tensor &tensor_c,
+                          c10::optional<c10::string_view> format_mode);
+#endif
 }  // namespace npu_kernel
 }  // namespace sglang
diff --git a/tests/python/sgl_kernel_npu/test_catlass_matmul_basic.py b/tests/python/sgl_kernel_npu/test_catlass_matmul_basic.py
new file mode 100644
index 00000000..58ee445a
--- /dev/null
+++ b/tests/python/sgl_kernel_npu/test_catlass_matmul_basic.py
@@ -0,0 +1,135 @@
+import random
+import time
+import unittest
+
+import numpy as np
+import sgl_kernel_npu
+import torch
+import torch_npu
+
+torch.set_printoptions(threshold=float("inf"))
+
+
+class TestMatrixMultiplication(unittest.TestCase):
+
+    def compute_golden(self, a, b, res1):
+        """Compute reference result (golden)"""
+        torch.matmul(a, b, out=res1)
+
+    def assert_tensors_almost_equal(self, actual, expected, dtype):
+        """Check if two tensors are approximately equal (considering floating point errors)"""
+        self.assertEqual(actual.shape, expected.shape, "Shape mismatch")
+
+        # Check for NaN
+        self.assertFalse(torch.isnan(actual).any(), "Actual result contains NaN")
+        self.assertFalse(torch.isnan(expected).any(), "Expected result contains NaN")
+
+        # Check for Inf
+        self.assertFalse(torch.isinf(actual).any(), "Actual result contains Inf")
+        self.assertFalse(torch.isinf(expected).any(), "Expected result contains Inf")
+
+        # Set different tolerances based on data type
+        if dtype == torch.float16:
+            rtol, atol = 5e-4, 5e-4
+        else:  # bfloat16 / float32
+            rtol, atol = 1e-3, 1e-3
+
+        # Compare values
+        diff = torch.abs(actual - expected)
+        max_diff = diff.max().item()
+        max_expected = torch.abs(expected).max().item()
+
+        # Check relative and absolute errors
+        if max_expected > 0:
+            relative_diff = max_diff / max_expected
+            self.assertLessEqual(
+                relative_diff,
+                rtol,
+                f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}",
+            )
+
+        self.assertLessEqual(
+            max_diff, atol, f"Absolute error too large: {max_diff} > {atol}"
+        )
+
+    def test_boundary_conditions(self):
+        """Test boundary conditions"""
+        test_cases = [
+            # (m, k, n)
+            (1, 1, 1),  # Minimum size
+            (10, 1, 1),  # k=1, n=1
+            (1, 1, 10),  # m=1, k=1
+            (5, 1, 5),  # k=1
+            (2, 2, 1),  # n=1
+            (1, 1, 100),  # Flat case
+            (100, 100, 1),  # Flat case
+            (3, 4, 5),  # Random small size
+            (20, 30, 40),  # Medium size
+            (128, 512, 128),  # target case
+            (160, 512, 128),
+        ]
+
+        dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+        for dtype in dtypes:
+            for m, k, n in test_cases:
+                with self.subTest(dtype=dtype, shape=f"({m}, {k}, {n})"):
+                    a = torch.randn(m, k, dtype=dtype, device="npu")
+                    b_tensor = torch.randn(k, n, dtype=dtype, device="npu")
+                    res1 = torch.empty((m, n), dtype=dtype, device="npu")
+                    res2 = torch.empty((m, n), dtype=dtype, device="npu")
+
+                    self.compute_golden(a, b_tensor, res1)
+                    torch.ops.npu.catlass_matmul_basic(a, b_tensor, res2)
+                    self.assert_tensors_almost_equal(res1, res2, dtype)
+
+    def test_random_shapes(self):
+        """Test randomly generated shapes"""
+        num_tests = 1
+        dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+        for dtype in dtypes:
+            for _ in range(num_tests):
+                # Generate reasonable random sizes
+                m = random.randint(1, 500)
+                k = random.randint(1, 500)
+                n = random.randint(1, 500)
+
+                with self.subTest(dtype=dtype, shape=f"Random ({m}, {k}, {n})"):
+                    a = torch.randn(m, k, dtype=dtype, device="npu")
+                    b_tensor = torch.randn(k, n, dtype=dtype, device="npu")
+                    res1 = torch.empty((m, n), dtype=dtype, device="npu")
+                    res2 = torch.empty((m, n), dtype=dtype, device="npu")
+
+                    self.compute_golden(a, b_tensor, res1)
+                    torch.ops.npu.catlass_matmul_basic(a, b_tensor, res2)
+                    self.assert_tensors_almost_equal(res1, res2, dtype)
+
+    def test_zero_values(self):
+        """Test zero input values"""
+        dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        m, k, n = 4, 3, 2
+
+        for dtype in dtypes:
+            with self.subTest(dtype=dtype):
+                a = torch.zeros(m, k, dtype=dtype, device="npu")
+                b_tensor = torch.zeros(k, n, dtype=dtype, device="npu")
+                res1 = torch.empty((m, n), dtype=dtype, device="npu")
+                res2 = torch.empty((m, n), dtype=dtype, device="npu")
+
+                self.compute_golden(a, b_tensor, res1)
+                torch.ops.npu.catlass_matmul_basic(a, b_tensor, res2)
+                self.assert_tensors_almost_equal(res1, res2, dtype)
+                self.assertTrue(torch.all(res2 == 0))
+
+
+if __name__ == "__main__":
+    try:
+        catlass_ops = torch.ops.npu.catlass_matmul_basic
+    except Exception as e:
+        print(
+            "To use the catlass ops in sgl-kernel-npu, build with BUILD_CATLASS_MODULE=ON in CMake"
+        )
+        raise e
+
+    unittest.main(verbosity=2)