Skip to content

Commit 8c683a5

Browse files
authored
Merge branch 'main' into a2_layour
2 parents 981bb22 + dbb3cc2 commit 8c683a5

File tree

35 files changed

+2362
-143
lines changed

35 files changed

+2362
-143
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ else ()
2727
add_compile_options(-g -rdynamic)
2828
endif ()
2929

30+
if(DEFINED ENV{DEBUG_MODE})
31+
if("$ENV{DEBUG_MODE}" STREQUAL "ON")
32+
add_compile_definitions(DEBUG_MODE)
33+
message(STATUS "Debug logging enabled from environment")
34+
endif()
35+
endif()
36+
3037
set(PROJECT_OP_SRC_BASE ${PROJECT_SOURCE_DIR}/csrc)
3138
set(PROJECT_BUILD_PATH ${PROJECT_SOURCE_DIR}/build)
3239
set(PROJECT_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/output)

README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,34 @@
11
# sgl-kernel-npu
2+
23
SGLang kernel library for NPU
34
For contribution guidelines, refer to the [Contribution Guide](docs/developer_guide/contribution_guide.md).
5+
6+
## Quick start
7+
8+
DeepEP-Ascend: Ascend Implementation of DeepEP. [README](https://github.com/sgl-project/sgl-kernel-npu/blob/main/python/deep_ep/README.md)
9+
10+
SGL-Kernel-NPU: Other SGLang Kernels for Ascend NPU. [README](https://github.com/sgl-project/sgl-kernel-npu/blob/main/python/sgl_kernel_npu/README.md)
11+
12+
## DeepEP-Ascend Performance
13+
14+
### Normal kernels with pure HCCS
15+
16+
We test normal kernels on the A3 384 SuperPOD, following the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-8 experts, INT8 dispatching and BF16 combining).
17+
18+
| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
19+
| --------- | ------------ | -------------------- | ----------- | -------------------- |
20+
| Intranode | 8 | 146 GB/s (HCCS) | 8 | 125 GB/s (HCCS) |
21+
| Intranode | 16 | 107 GB/s (HCCS) | 16 | 103 GB/s (HCCS) |
22+
| Intranode | 32 | 102 GB/s (HCCS) | 32 | 95 GB/s (HCCS) |
23+
| Intranode | 64 | 81 GB/s (HCCS) | 64 | 91 GB/s (HCCS) |
24+
| Intranode | 128 | 57 GB/s (HCCS) | 128 | 81 GB/s (HCCS) |
25+
26+
### Low-latency kernels with pure HCCS
27+
28+
We test low-latency kernels on the A3 384 SuperPOD, following a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, INT8 dispatching and BF16 combining).
29+
30+
| Dispatch #EP | Latency | Bandwidth | Combine #EP | Latency | Bandwidth |
31+
| ------------ | ------- | -------------- | ----------- | ------- | --------------- |
32+
| 8 | 132 us | 58 GB/s (HCCS) | 8 | 126 us | 116 GB/s (HCCS) |
33+
| 16 | 139 us | 55 GB/s (HCCS) | 16 | 135 us | 109 GB/s (HCCS) |
34+
| 32 | 153 us | 49 GB/s (HCCS) | 32 | 151 us | 97 GB/s (HCCS) |

build.sh

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ ONLY_BUILD_DEEPEP_ADAPTER_MODULE="OFF"
99
ONLY_BUILD_DEEPEP_KERNELs_MODULE="OFF"
1010
ONLY_BUILD_MEMORY_SAVER_MODULE="OFF"
1111

12-
while getopts ":a:h" opt; do
12+
DEBUG_MODE="OFF"
13+
14+
while getopts ":a:hd" opt; do
1315
case ${opt} in
1416
a )
1517
BUILD_DEEPEP_MODULE="OFF"
@@ -41,6 +43,9 @@ while getopts ":a:h" opt; do
4143
;;
4244
esac
4345
;;
46+
d )
47+
DEBUG_MODE="ON"
48+
;;
4449
h )
4550
echo "Use './build.sh' build all modules."
4651
echo "Use './build.sh -a <target>' to build specific parts of the project."
@@ -67,6 +72,9 @@ done
6772

6873
shift $((OPTIND -1))
6974

75+
76+
export DEBUG_MODE=$DEBUG_MODE
77+
7078
SOC_VERSION="${1:-Ascend910_9382}"
7179

7280
if [ -n "$ASCEND_HOME_PATH" ]; then

contrib/torch_memory_saver/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ bash build.sh -a memory-saver
8585
2. Pip install the `.whl` file into your Python environment
8686

8787
```bash
88-
pip install output/deep_ep*.whl
88+
pip install output/torch_memory_saver*.whl
8989
```
9090
## Test
9191
You can use this command for local testing:

contrib/torch_memory_saver/python/torch_memory_saver/hooks/mode_preload.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
from contextlib import contextmanager
44

5+
import torch
56
from torch_memory_saver.hooks.base import HookUtilBase
67
from torch_memory_saver.utils import get_binary_path_from_package
78

@@ -23,11 +24,17 @@ def get_path_binary(self):
2324
@contextmanager
2425
def configure_subprocess():
2526
"""Configure environment variables for subprocesses. Only needed for hook_mode=preload."""
26-
with _change_env(
27-
"LD_PRELOAD",
28-
str(get_binary_path_from_package("torch_memory_saver_hook_mode_preload")),
29-
):
27+
# Currently, torch_memory_saver does not support preload for npu, therefore LD_PRELOAD interception is not implemented.
28+
if hasattr(torch, "npu") and torch.npu.is_available():
3029
yield
30+
return
31+
32+
else:
33+
with _change_env(
34+
"LD_PRELOAD",
35+
str(get_binary_path_from_package("torch_memory_saver_hook_mode_preload")),
36+
):
37+
yield
3138

3239

3340
@contextmanager

csrc/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ FILE(GLOB OP_SRCS
99
${PROJECT_OP_SRC_BASE}/alloc_extend/op_host/alloc_extend_tiling.cpp
1010
${PROJECT_OP_SRC_BASE}/assign_cache_op/op_host/assign_cache.cpp
1111
${PROJECT_OP_SRC_BASE}/mla_preprocess/op_host/mla_preprocess.cpp
12+
${PROJECT_OP_SRC_BASE}/batch_matmul_transpose/op_host/batch_matmul_transpose.cpp
13+
${PROJECT_OP_SRC_BASE}/batch_matmul_transpose/op_host/tiling/tiling_data.cpp
1214
)
1315

1416
# set the so name
@@ -19,6 +21,7 @@ ascendc_library(no_workspace_kernel STATIC
1921
${PROJECT_OP_SRC_BASE}/helloworld/op_kernel/kernel_helloworld.cpp
2022
${PROJECT_OP_SRC_BASE}/cache_location_assign/op_kernel/cache_loc_assign_kernel.cpp
2123
${PROJECT_OP_SRC_BASE}/assign_cache_op/op_kernel/assign_cache_op.cpp
24+
${PROJECT_OP_SRC_BASE}/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
2225
)
2326

2427
ascendc_library(workspace_kernel STATIC
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#include <iostream>
2+
#include <string>
3+
#include "acl/acl.h"
4+
#include "kernel_tiling/kernel_tiling.h"
5+
#include "tiling/platform/platform_ascendc.h"
6+
#include "tiling/tiling_data.h"
7+
#include "defines.h"
8+
#include "torch_helper.h"
9+
#include "common_tiling.h"
10+
#include "aclrtlaunch_batch_matmul_transpose.h"
11+
12+
namespace sglang {
13+
namespace npu_kernel {
14+
using namespace pp_matmul;
15+
16+
std::unordered_map<c10::string_view, uint16_t> quantModeMap = {
17+
{"per_channel_symm", 0},
18+
{"per_channel_asymm", 1},
19+
{"per_token_symm", 2},
20+
};
21+
22+
std::unordered_map<c10::string_view, uint16_t> formatModeMap = {
23+
{"ND", 0},
24+
{"NZ", 1},
25+
};
26+
27+
std::unordered_map<c10::ScalarType, TensorDType> atType2tensorDType = {
28+
{at::ScalarType::BFloat16, TensorDType::TENSOR_DTYPE_BF16},
29+
{at::ScalarType::Half, TensorDType::TENSOR_DTYPE_FLOAT16}};
30+
31+
// batch size -> memory index
32+
constexpr uint32_t MAX_CAPTURE_NUM = 1024;
33+
34+
template <typename MapType>
35+
inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
36+
const char *mode_name)
37+
{
38+
std::string modeStr(mode_name);
39+
c10::string_view mode_str = mode_opt.value_or(default_mode);
40+
auto it = mode_map.find(mode_str);
41+
// if input mode is unsupported, use default value
42+
TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
43+
return it->second;
44+
}
45+
46+
HOST_API void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
47+
c10::optional<c10::string_view> format_mode,
48+
c10::optional<c10::string_view> quant_mode)
49+
{
50+
auto tensorAShape = tensor_a.sizes();
51+
auto tensorBShape = tensor_b.sizes();
52+
auto tensorCShape = tensor_c.sizes();
53+
uint32_t n;
54+
uint32_t block_dim;
55+
HardwareInfo hwInfo;
56+
std::map<c10::ScalarType, float> dTypeMap = {{at::ScalarType::Half, 2.0}, {at::ScalarType::BFloat16, 2.0}};
57+
58+
at::ScalarType aType = tensor_a.scalar_type();
59+
at::ScalarType bType = tensor_b.scalar_type();
60+
at::ScalarType cType = tensor_c.scalar_type();
61+
TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same");
62+
TORCH_CHECK((aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half),
63+
"tensor type only support half or bf16");
64+
65+
TensorFormat formatMode = static_cast<TensorFormat>(GetModeVal(formatModeMap, format_mode, "ND", "format_mode"));
66+
MatMul::QuantMode quantMode =
67+
static_cast<MatMul::QuantMode>(GetModeVal(quantModeMap, quant_mode, "per_channel_symm", "quant_mode"));
68+
69+
TORCH_CHECK(tensorAShape.size() == 3, "batch size is not same between srcTensor and dstTensor");
70+
if (formatMode == TensorFormat::TENSOR_FORMAT_ND) {
71+
TORCH_CHECK(tensorBShape.size() == 3, "tensor shape should be dim3 in ND format");
72+
TORCH_CHECK(tensorAShape[2] == tensorBShape[1], "tensor shape is wrong");
73+
n = tensorBShape[2];
74+
} else {
75+
TORCH_CHECK(tensorBShape.size() == 4, "tensor shape should be dim4 in nz format");
76+
TORCH_CHECK(tensorAShape[2] == tensorBShape[2], "tensor shape is wrong");
77+
n = tensorBShape[1] * tensorBShape[3];
78+
}
79+
TORCH_CHECK(tensorAShape[1] == tensorBShape[0], "tensor shape is wrong");
80+
81+
OpShape opShape = {.batchSize = static_cast<uint32_t>(tensorAShape[1]),
82+
.m = static_cast<uint32_t>(tensorAShape[0]),
83+
.k = static_cast<uint32_t>(tensorAShape[2]),
84+
.n = n};
85+
PpMatmulTilingData matmulTilingData = {
86+
.opShape = opShape,
87+
};
88+
auto dType = atType2tensorDType[aType];
89+
MatMulInfo mmInfo = {.batchSize = opShape.batchSize,
90+
.m = opShape.m,
91+
.k = opShape.k,
92+
.n = opShape.n,
93+
.dtypeA = dType,
94+
.dtypeB = dType,
95+
.dtypeC = dType,
96+
.formatB = formatMode,
97+
.mmType = MatMul::MatMulType::MATMUL_EIN_SUM,
98+
.inDtype = dTypeMap[aType],
99+
.outDtype = dTypeMap[cType],
100+
.quantMode = quantMode};
101+
GetPpMatmulTiling(mmInfo, hwInfo, block_dim, matmulTilingData);
102+
host_utils::PpMatmulTilingCheck(matmulTilingData);
103+
104+
// tiling
105+
int32_t batchIdx = opShape.m - 1;
106+
uint32_t tilingSize = sizeof(PpMatmulTilingData);
107+
static auto global_tiling_data = at::empty(
108+
{tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
109+
if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) {
110+
aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, &matmulTilingData,
111+
tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
112+
} else {
113+
// Handle the case where batchIdx is out of range
114+
TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
115+
}
116+
at::Tensor tiling_tensor =
117+
at::from_blob(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, at::kByte);
118+
119+
EXEC_KERNEL_CMD(batch_matmul_transpose, block_dim, tensor_a, tensor_b, tensor_c, tiling_tensor);
120+
}
121+
122+
} // namespace npu_kernel
123+
124+
} // namespace sglang

0 commit comments

Comments
 (0)