
Commit 8c2dbd7

[Feat] Lightning indexer op & GE helper engineering

1 parent: d56b60b
24 files changed: +4418 −22 lines

csrc/CMakeLists.txt

Lines changed: 7 additions & 0 deletions

```diff
@@ -1,5 +1,6 @@
 # set the library output dir to the python dir for wheel package build
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/python/sgl_kernel_npu/sgl_kernel_npu/lib)
+set(ASCEND_INCLUDE_DIR ${ASCEND_HOME_PATH}/aarch64-linux/include)
 
 # host side files
 FILE(GLOB OP_SRCS
@@ -17,6 +18,8 @@ FILE(GLOB OP_SRCS
     ${PROJECT_OP_SRC_BASE}/lora/op_host/bgmv_shrink.cpp
     ${PROJECT_OP_SRC_BASE}/lora/op_host/sgmv_expand.cpp
     ${PROJECT_OP_SRC_BASE}/lora/op_host/sgmv_shrink.cpp
+    ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/lightning_indexer.cpp
+    ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/tiling/lightning_indexer_tiling.cpp
 )
 
 # set the so name
@@ -38,6 +41,7 @@ ascendc_library(workspace_kernel STATIC
     ${PROJECT_OP_SRC_BASE}/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
     ${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp
     ${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp
+    ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_kernel/lightning_indexer_kernel.cpp
 )
 
 ascendc_compile_definitions(workspace_kernel PRIVATE
@@ -71,4 +75,7 @@ target_include_directories(${OP_PLUGIN_NAME} PRIVATE
     ${TORCH_DIR}/include
     ${TORCH_DIR}/include/torch/csrc/api/include
     ${TORCH_NPU_DIR}/include
+    ${ASCEND_INCLUDE_DIR}/external
+    ${ASCEND_INCLUDE_DIR}/experiment/platform
+    ${ASCEND_INCLUDE_DIR}/experiment/runtime
 )
```

csrc/alloc_extend/op_host/alloc_extend_tiling.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -34,7 +34,7 @@ at::Tensor get_tiling(int32_t &block_dim, int32_t &workspace_size, const int64_t
     tiling_data->used_core_num = block_dim;
     tiling_data->total_extend_tokens = total_extend_tokens;
 
-    auto tiling_tensor = TorchNpuHepler::CopyTensorHostToDevice(tiling_buffer);
+    auto tiling_tensor = TorchNpuHelper::CopyTensorHostToDevice(tiling_buffer);
     return tiling_tensor;
 }
```

csrc/assign_cache_op/op_host/assign_cache.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -23,7 +23,7 @@ HOST_API at::Tensor GetTilingTensor(CustomAssignTilingData &tilingData, size_t t
 {
     auto buffer = at::empty({static_cast<int64_t>(tilingSize)}, at::kByte);
     tilingData.SetToBuffer(buffer.data_ptr<uint8_t>(), tilingSize);
-    auto tilingTensor = TorchNpuHepler::CopyTensorHostToDevice(buffer);
+    auto tilingTensor = TorchNpuHelper::CopyTensorHostToDevice(buffer);
     return tilingTensor;
 }
 
@@ -57,7 +57,7 @@ HOST_API bool assign_cache_op(at::Tensor &dstTensor, const at::Tensor &srcTensor
     at::Tensor tiling = GetTilingTensor(tilingData, sizeof(tilingData));
 
     auto sync = at::zeros({syncWorkspaceSize, 1}, at::kByte);
-    auto syncDevice = TorchNpuHepler::CopyTensorHostToDevice(sync);
+    auto syncDevice = TorchNpuHelper::CopyTensorHostToDevice(sync);
     EXEC_KERNEL_CMD(assign_cache_op, blockDim, dstTensor, srcTensor, dstStartIdx, dstEndIdx, srcStartIdx, srcEndIdx,
                     syncDevice, tiling);
     return true;
```

csrc/build_tree/op_host/build_tree.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -42,7 +42,7 @@ at::Tensor get_tiling(int32_t &block_dim, int32_t &workspace_size, int32_t batch
     tiling_data->big_core_tile_num = (batch_size + block_dim - 1) / block_dim;
     tiling_data->small_core_tile_num = batch_size / block_dim;
 
-    auto tiling_tensor = TorchNpuHepler::CopyTensorHostToDevice(tiling_buffer);
+    auto tiling_tensor = TorchNpuHelper::CopyTensorHostToDevice(tiling_buffer);
     return tiling_tensor;
 }
```

csrc/cache_location_assign/op_host/cache_loc_assign.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -70,7 +70,7 @@ at::Tensor getTiling(const at::Tensor &reqPoolIndices, uint64_t rowSize, uint64_
         throw std::invalid_argument("Batch size is too large, buffer is not enough to do calculate");
     }
 
-    auto tilingTensor = TorchNpuHepler::CopyTensorHostToDevice(tilingBuffer);
+    auto tilingTensor = TorchNpuHelper::CopyTensorHostToDevice(tilingBuffer);
     return tilingTensor;
 }
```

csrc/lightning_indexer/README.md

Lines changed: 71 additions & 0 deletions
# torch.ops.npu.lightning_indexer

## Product Support Status

| Product | Supported |
| ------- | :-------: |
| Atlas A3 Inference Product Series | √ |

## Function Description

`LightningIndexer` computes the Top-$k$ context positions for each token. For a token's Index Query $Q_{index}\in\mathbb{R}^{g\times d}$, the context Index Key $K_{index}\in\mathbb{R}^{S_{k}\times d}$, and weights $W\in\mathbb{R}^{g\times 1}$, where $g$ is the group size for GQA, $d$ is the dimension of each head, and $S_{k}$ is the context length, `LightningIndexer` computes:

$$
\text{Top-}k\left\{[1]_{1\times g}@\left[(W@[1]_{1\times S_{k}})\odot\text{ReLU}\left(Q_{index}@K_{index}^T\right)\right]\right\}
$$
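
For intuition, here is a minimal PyTorch sketch of this formula for a single token. It mirrors the math only, not the NPU kernel; the function name and signature are illustrative:

```python
import torch

def lightning_indexer_reference(q_index: torch.Tensor, k_index: torch.Tensor,
                                w: torch.Tensor, k: int) -> torch.Tensor:
    """Top-k positions for one token.

    q_index: [g, d] index query, k_index: [S_k, d] context index keys,
    w: [g, 1] per-head weights, k: number of positions to keep.
    """
    scores = torch.relu(q_index @ k_index.T)   # ReLU(Q_index @ K_index^T) -> [g, S_k]
    weighted = w * scores                      # (W @ [1]_{1xS_k}) ⊙ ...  (w broadcast over S_k)
    per_pos = weighted.sum(dim=0)              # [1]_{1xg} @ ...  (sum over the g heads) -> [S_k]
    return torch.topk(per_pos, k).indices      # indices of the k highest-scoring positions
```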

## Function Prototype

```
torch.ops.npu.lightning_indexer(query, key, weights, actual_seq_lengths_query=None, actual_seq_lengths_key=None, block_table=None, layout_query='BSND', layout_key='BSND', sparse_count=2048, sparse_mode=3) -> Tensor
```

## Parameter Description

>**Note:**<br>
>
>- Dimension meanings for the query, key, and weights parameters: B (Batch Size) is the batch size of the input samples, S (Sequence Length) is the sequence length of the input samples, H (Head Size) is the size of the hidden layer, N (Head Num) is the number of attention heads, D (Head Dim) is the smallest unit dimension of the hidden layer, satisfying D = H/N, and T is the cumulative sum of sequence lengths over all batches.
>- S1 and N1 denote the S and N dimensions of the query shape; S2 and N2 denote the S and N dimensions of the key shape.

- **query** (`Tensor`): Required. Non-contiguous tensors are not supported. Data layout: ND. Supported data types: `bfloat16` and `float16`.

- **key** (`Tensor`): Required. Non-contiguous tensors are not supported. Data layout: ND. Supported data types: `bfloat16` and `float16`. When layout_key is 'PA_BSND', the shape is [block_count, block_size, N2, D], where block_count is the total number of blocks in PageAttention and block_size is the number of tokens in one block.

- **weights** (`Tensor`): Required. Non-contiguous tensors are not supported. Data layout: ND. Supported data types: `bfloat16` and `float16`. Supported input shapes: [B,S1,N1] and [T,N1].

- **\***: Parameters before it are positional and must be provided in order (required parameters); parameters after it are keyword arguments, position-independent and optional (default values are used when they are not provided).

- **actual_seq_lengths_query** (`Tensor`): Optional. The number of valid tokens of `query` in each batch. Supported data type: `int32`. If the sequence length is not specified, None can be passed, meaning it equals the S dimension of `query`'s shape.
  - The number of valid tokens in each batch must not exceed the S dimension of `query`. Supports a 1D tensor of length B. When `query`'s layout is 'TND', this parameter is required, and its element count is used as the B value. Each element is the cumulative token count of the current batch and all previous batches (a prefix sum), so a later element must be >= the element before it (for example, batches with 3 and 5 valid tokens give [3, 8]). Negative values are not allowed.

- **actual_seq_lengths_key** (`Tensor`): Optional. The number of valid tokens of `key` in each batch. Supported data type: `int32`. If the sequence length is not specified, None can be passed, meaning it equals the S dimension of `key`'s shape. Supports a 1D tensor of length B.

- **block_table** (`Tensor`): Optional. The block mapping table used for KV storage in PageAttention. Data layout: ND. Supported data type: `int32`.
  - In PageAttention scenarios, block_table must be 2D, with the first dimension equal to B and the second dimension not less than maxBlockNumPerSeq (the maximum number of blocks implied by actual_seq_lengths_key for any batch); see the sketch after this parameter list.

- **layout_query** (`str`): Optional. The data layout of `query`. Currently supported: 'BSND', 'TND'. Default: "BSND".

- **layout_key** (`str`): Optional. The data layout of `key`. Currently supported: 'PA_BSND', 'BSND', 'TND'. Default: "BSND". In non-PageAttention scenarios, this value should match **layout_query**.

- **sparse_count** (`int`): Optional. The number of blocks to retain in the top-k phase. Supported values: 1-2048. Data type: `int32`.

- **sparse_mode** (`int`): Optional. The sparse mode. Supported values: 0 and 3. Data type: `int32`.
  - When sparse_mode is 0, the defaultMask mode is used.
  - When sparse_mode is 3, the rightDownCausal mask is used, i.e. the lower-triangular region anchored at the bottom-right vertex.
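
To make block_table and maxBlockNumPerSeq concrete, here is a toy sketch referenced above. It is not an API of this op; it assumes a simple contiguous allocation policy and illustrative names:

```python
import torch

def build_block_table(seq_lens, block_size):
    """Allocate blocks contiguously per sequence and return the
    [B, maxBlockNumPerSeq] int32 mapping used in PageAttention scenarios."""
    need = [(s + block_size - 1) // block_size for s in seq_lens]  # blocks per batch
    max_blocks_per_seq = max(need)                                 # maxBlockNumPerSeq
    table = torch.zeros(len(seq_lens), max_blocks_per_seq, dtype=torch.int32)
    next_free = 0
    for row, n in enumerate(need):
        table[row, :n] = torch.arange(next_free, next_free + n, dtype=torch.int32)
        next_free += n
    return table

# actual_seq_lengths_key = [300, 128] with block_size = 128 needs 3 and 1 blocks:
# tensor([[0, 1, 2],
#         [3, 0, 0]], dtype=torch.int32)
```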

## Return Value Description

- **out** (`Tensor`): The output of the formula above. Supported data type: `int32`. Data layout: ND.

## Constraints

- This interface supports inference scenarios.
- This interface supports graph mode.
- When used with PyTorch, the versions of the CANN packages and the PyTorch packages must be compatible.
- The N dimension of query must be 64; the N dimension of key must be 1.
- The D dimensions of query and key must both be 128.
- The data types of query, key, and weights must be consistent.
- block_size must be a multiple of 16, with a maximum of 1024.

## Usage Example

- See details in [test_lightning_indexer.py](../../tests/python/sgl_kernel_npu/test_lightning_indexer.py)
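
For orientation only, a hedged call sketch consistent with the prototype and constraints above; the shapes are invented and the linked test remains the authoritative example:

```python
import torch
import torch_npu  # assumes an NPU environment with sgl_kernel_npu's ops registered

B, S1, S2, N1, N2, D = 1, 1, 4096, 64, 1, 128   # N1=64, N2=1, D=128 per the constraints
dtype = torch.bfloat16

query = torch.randn(B, S1, N1, D, dtype=dtype).npu()
key = torch.randn(B, S2, N2, D, dtype=dtype).npu()
weights = torch.randn(B, S1, N1, dtype=dtype).npu()
actual_seq_lengths_key = torch.tensor([S2], dtype=torch.int32).npu()

out = torch.ops.npu.lightning_indexer(
    query, key, weights,
    actual_seq_lengths_key=actual_seq_lengths_key,
    layout_query="BSND", layout_key="BSND",
    sparse_count=2048, sparse_mode=3)
print(out.shape)  # [B, S1, N2, sparse_count], dtype int32
```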
csrc/lightning_indexer/op_host/lightning_indexer.cpp

Lines changed: 171 additions & 0 deletions

```cpp
#include <any>
#include <cstdio>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include "acl/acl.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/lightning_indexer_tiling.h"
#include "defines.h"
#include "torch_helper.h"
#include "ge_helper.h"
#include "common_tiling.h"
#include "lightning_indexer_def.h"
#include "common.h"
#include "aclrtlaunch_lightning_indexer.h"

namespace sglang::LIHost {

using namespace ge_helper;
constexpr uint32_t MAX_CAPTURE_NUM = 1024;
constexpr uint32_t MAX_DECODE_BS = 512;
// npu tensor max rank
constexpr int SIZE = 8;
constexpr int DIM_0 = 0;
constexpr int DIM_1 = 1;
constexpr int DIM_2 = 2;
constexpr int DIM_3 = 3;

// namespace-scope global state: maps the hash of a tiling configuration to its
// slot in the global tiling buffer
uint32_t actualCaptureNum = 0;
static std::unordered_map<uint64_t, uint32_t> captureMap;
// at::Tensor workspace;

inline at::Tensor ConstructLightningIndexerOutputTensor(const at::Tensor &query, const at::Tensor &key,
                                                        const c10::optional<at::Tensor> &actual_seq_lengths_query,
                                                        int64_t sparse_count, std::string query_layout_str,
                                                        std::string key_layout_str)
{
    at::SmallVector<int64_t, SIZE> outputSize;
    for (size_t i = 0; i < query.sizes().size(); i++) {
        TORCH_CHECK(query.size(i) > 0,
                    "All values within query's shape should be greater than 0, but shape[", i, "] is ", query.size(i));
    }
    TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);

    if (query_layout_str == "BSND") {
        outputSize = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count};
    } else {
        // TND query: the N dim of key sits at index 1 for TND and at index 2 otherwise
        int n_dim_index = (key_layout_str == "TND") ? DIM_1 : DIM_2;
        outputSize = {query.size(DIM_0), key.size(n_dim_index), sparse_count};
    }
    return at::empty(outputSize, query.options().dtype(at::kInt));
}
}  // namespace sglang::LIHost

namespace sglang {
namespace npu_kernel {
HOST_API at::Tensor lightning_indexer(const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
                                      const c10::optional<at::Tensor> &actual_seq_lengths_query,
                                      const c10::optional<at::Tensor> &actual_seq_lengths_key,
                                      const c10::optional<at::Tensor> &block_table,
                                      c10::optional<c10::string_view> layout_query,
                                      c10::optional<c10::string_view> layout_key, c10::optional<int64_t> sparse_count,
                                      c10::optional<int64_t> sparse_mode)
{
    using namespace LIHost;
    LightningIndexer indexer("lightning_indexer");
    auto context = std::make_shared<TilingContext>("lightning_indexer");
    TORCH_CHECK(context != nullptr, "TilingContext is null");

    // Start from the op definition's default attributes, then apply caller overrides
    std::string layoutQuery(indexer.GetAttr(ATTR_QUERY_LAYOUT_INDEX).GetString());
    std::string layoutKey(indexer.GetAttr(ATTR_KEY_LAYOUT_INDEX).GetString());
    int64_t sparseCount = std::any_cast<int32_t>(indexer.GetAttr(ATTR_SPARSE_COUNT_INDEX).GetValue());

    if (layout_query.has_value()) {
        layoutQuery = std::string(layout_query.value());
        indexer.SetAttrStr("layout_query", layoutQuery);
    }
    if (layout_key.has_value()) {
        layoutKey = std::string(layout_key.value());
        indexer.SetAttrStr("layout_key", layoutKey);
    }
    if (sparse_count.has_value()) {
        sparseCount = sparse_count.value();
        indexer.SetAttrAny("sparse_count", static_cast<int32_t>(sparseCount));
    }
    if (sparse_mode.has_value()) {
        indexer.SetAttrAny("sparse_mode", static_cast<int32_t>(sparse_mode.value()));
    }

    at::Tensor sparse_indices = ConstructLightningIndexerOutputTensor(query, key, actual_seq_lengths_query,
                                                                      sparseCount, layoutQuery, layoutKey);

    auto qScalarType = query.scalar_type();

    // Missing optional inputs are replaced by 1-element placeholders so the kernel
    // launch always receives a valid tensor
    at::Tensor actualSeqLengthsQuery =
        actual_seq_lengths_query.has_value()
            ? actual_seq_lengths_query.value()
            : at::empty({1}, at::TensorOptions().dtype(qScalarType).device(query.options().device()));

    at::Tensor actualSeqLengthsKey =
        actual_seq_lengths_key.has_value()
            ? actual_seq_lengths_key.value()
            : at::empty({1}, at::TensorOptions().dtype(qScalarType).device(query.options().device()));

    at::Tensor blockTable =
        block_table.has_value()
            ? block_table.value()
            : at::empty({1}, at::TensorOptions().dtype(qScalarType).device(query.options().device()));

    indexer.SetToContext(context, qScalarType);
    context->RegisterTensor(query, true);
    context->RegisterTensor(key, true);
    context->RegisterTensor(weights, true);
    context->RegisterTensor(actual_seq_lengths_query, true);
    context->RegisterTensor(actual_seq_lengths_key, true);
    context->RegisterTensor(block_table, true);
    context->RegisterTensor(sparse_indices, false);

    LITilingInfo liInfo;
    LIInfoParser liInfoParser(context.get());
    TORCH_CHECK(liInfoParser.ParseAndCheck(liInfo) == ge::GRAPH_SUCCESS, "lightning_indexer ParseAndCheck failed");

    LightningIndexerTiling liTiling(context.get());
    liTiling.DoTiling(&liInfo);
    const auto &tilingData = liTiling.GetTilingData();

    uint32_t tilingSize = sizeof(LITilingData);
    auto blockDim = tilingData.usedCoreNum;
    auto bs = tilingData.bSize;
    at::Tensor tilingTensor;

    // Tiling data is cached by a hash of its shape-defining fields so that graph
    // capture/replay sees a stable device address for each distinct configuration
    auto tup =
        std::make_tuple(tilingData.bSize, tilingData.n2Size, tilingData.gSize, tilingData.s1Size, tilingData.s2Size,
                        tilingData.blockSize, tilingData.maxBlockNumPerBatch, tilingData.tilingKey);
    auto hashValue = host_utils::TupleHasher::Hash(tup);

    // One device buffer holds up to MAX_CAPTURE_NUM tiling blobs at fixed offsets
    static auto globalTilingBuffer = at::empty({tilingSize * MAX_CAPTURE_NUM},
                                               at::TensorOptions().dtype(at::kByte).device(query.options().device()));

    if (actualCaptureNum >= MAX_CAPTURE_NUM) {
        // Cache slots exhausted: fall back to a single static buffer rewritten per call
        static auto prefillTilingBuffer =
            at::empty({tilingSize}, at::TensorOptions().dtype(at::kByte).device(query.options().device()));
        aclrtMemcpy(prefillTilingBuffer.data_ptr<uint8_t>(), tilingSize, &tilingData, tilingSize,
                    ACL_MEMCPY_HOST_TO_DEVICE);
        tilingTensor = at::from_blob(prefillTilingBuffer.data_ptr<uint8_t>(), tilingSize, at::kByte);
    } else if (captureMap.find(hashValue) != captureMap.end()) {
        // Decode replay phase and part of cached prefill tiling data got from globalTilingBuffer
        tilingTensor = at::from_blob(globalTilingBuffer.data_ptr<uint8_t>() + (tilingSize * captureMap[hashValue]),
                                     tilingSize, at::kByte);
    } else {
        // Captured tiling cached here
        captureMap[hashValue] = actualCaptureNum;
        aclrtMemcpy(globalTilingBuffer.data_ptr<uint8_t>() + actualCaptureNum * tilingSize, tilingSize, &tilingData,
                    tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
        actualCaptureNum++;
        tilingTensor = at::from_blob(globalTilingBuffer.data_ptr<uint8_t>() + (tilingSize * captureMap[hashValue]),
                                     tilingSize, at::kByte);
    }

    size_t workspaceSize = context->GetWorkspaceSize();
    auto workspace = at::empty({static_cast<int64_t>(workspaceSize)},
                               at::TensorOptions().dtype(at::kByte).device(query.options().device()));
    EXEC_KERNEL_CMD(lightning_indexer, blockDim, query, key, weights, actualSeqLengthsQuery, actualSeqLengthsKey,
                    blockTable, sparse_indices, workspace, tilingTensor);
    return sparse_indices;
}
}  // namespace npu_kernel
}  // namespace sglang
```
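
The tiling-buffer branches above implement a small capture-friendly cache: each distinct tiling configuration is keyed by a hash of its shape-defining fields and placed at a fixed offset in one preallocated device buffer, so a replayed graph keeps a stable address. A minimal Python sketch of the same idea, with illustrative names and a plain CPU tensor standing in for device memory:

```python
import torch

MAX_CAPTURE_NUM = 1024  # mirrors the constant in the C++ above

class TilingCacheSketch:
    """Store each distinct tiling blob at a fixed offset keyed by its fields."""

    def __init__(self, tiling_size: int):
        self.tiling_size = tiling_size
        self.slots = {}  # hash of tiling fields -> slot index
        self.buffer = torch.empty(tiling_size * MAX_CAPTURE_NUM, dtype=torch.uint8)

    def get(self, fields: tuple, blob: torch.Tensor) -> torch.Tensor:
        key = hash(fields)
        if key not in self.slots:
            if len(self.slots) >= MAX_CAPTURE_NUM:
                # cache full: the real op falls back to one static scratch buffer
                return blob
            self.slots[key] = len(self.slots)
            start = self.slots[key] * self.tiling_size
            self.buffer[start:start + self.tiling_size] = blob  # host->device copy in the real op
        start = self.slots[key] * self.tiling_size
        return self.buffer[start:start + self.tiling_size]      # stable storage for replay
```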
