Skip to content

Commit ed8a96f

Browse files
committed
fix
1 parent 373ce3d commit ed8a96f

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

csrc/lightning_indexer/op_host/lightning_indexer.cpp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace sglang::LIHost {
1515

1616
using namespace ge_helper;
1717
constexpr uint32_t MAX_CAPTURE_NUM = 1024;
18+
constexpr uint32_t MAX_DECODE_BS = 512;
1819
// npu tensor max size
1920
constexpr int SIZE = 8;
2021
constexpr int DIM_0 = 0;
@@ -117,26 +118,33 @@ HOST_API at::Tensor lightning_indexer(const at::Tensor &query, const at::Tensor
117118
uint32_t tilingSize = sizeof(LITilingData);
118119
auto blockDim = tilingData.usedCoreNum;
119120
auto bs = tilingData.bSize;
120-
uint64_t mapKey = tilingData.tilingKey;
121-
mapKey = (mapKey << 32) | bs;
122-
//std::cout << "mapKey is " << mapKey << std::endl;
123-
124-
static auto globalTilingData = at::empty({tilingSize * MAX_CAPTURE_NUM},
125-
at::TensorOptions().dtype(at::kByte).device(query.options().device()));
126-
if (captureMap.find(mapKey) == captureMap.end()) {
127-
// std::cout << "step in, mapKey is " << mapKey << std::endl;
128-
TORCH_CHECK(actualCaptureNum < MAX_CAPTURE_NUM, "lightning_indexer captureNum overflow")
129-
captureMap[mapKey] = actualCaptureNum;
130-
aclrtMemcpy(globalTilingData.data_ptr<uint8_t>() + actualCaptureNum * tilingSize, tilingSize, &tilingData,
121+
at::Tensor tilingTensor;
122+
123+
if (bs > MAX_DECODE_BS) {
124+
static auto preillTilingBuffer = at::empty({tilingSize}, at::TensorOptions().dtype(at::kByte).device(query.options().device()));
125+
aclrtMemcpy(preillTilingBuffer.data_ptr<uint8_t>(), tilingSize, &tilingData,
131126
tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
132-
actualCaptureNum++;
127+
tilingTensor = at::from_blob(preillTilingBuffer.data_ptr<uint8_t>(), tilingSize, at::kByte);
128+
} else {
129+
uint64_t mapKey = tilingData.tilingKey;
130+
mapKey = (mapKey << 32) | bs;
131+
132+
static auto globalTilingBuffer = at::empty({tilingSize * MAX_CAPTURE_NUM},
133+
at::TensorOptions().dtype(at::kByte).device(query.options().device()));
134+
if (captureMap.find(mapKey) == captureMap.end()) {
135+
// std::cout << "step in, mapKey is " << mapKey << std::endl;
136+
TORCH_CHECK(actualCaptureNum < MAX_CAPTURE_NUM, "lightning_indexer captureNum overflow")
137+
captureMap[mapKey] = actualCaptureNum;
138+
aclrtMemcpy(globalTilingBuffer.data_ptr<uint8_t>() + actualCaptureNum * tilingSize, tilingSize, &tilingData,
139+
tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
140+
actualCaptureNum++;
141+
}
142+
tilingTensor =
143+
at::from_blob(globalTilingBuffer.data_ptr<uint8_t>() + (tilingSize * captureMap[mapKey]), tilingSize, at::kByte);
133144
}
134-
at::Tensor tilingTensor =
135-
at::from_blob(globalTilingData.data_ptr<uint8_t>() + (tilingSize * captureMap[mapKey]), tilingSize, at::kByte);
136145

137146
size_t userWorkspaceSize = *context->GetWorkspaceSizes(1);
138147
workspace = at::empty({userWorkspaceSize}, at::TensorOptions().dtype(at::kByte).device(query.options().device()));
139-
//std::cout << "6" << std::endl;
140148
EXEC_KERNEL_CMD(lightning_indexer, blockDim, query, key, weights, actual_seq_lengths_query, actual_seq_lengths_key,
141149
block_table, sparse_indices, workspace, tilingTensor);
142150
return sparse_indices;

0 commit comments

Comments
 (0)