@@ -15,6 +15,7 @@ namespace sglang::LIHost {
1515
1616using namespace ge_helper ;
1717constexpr uint32_t MAX_CAPTURE_NUM = 1024 ;
18+ constexpr uint32_t MAX_DECODE_BS = 512 ;
1819// npu tensor max size
1920constexpr int SIZE = 8 ;
2021constexpr int DIM_0 = 0 ;
@@ -117,26 +118,33 @@ HOST_API at::Tensor lightning_indexer(const at::Tensor &query, const at::Tensor
117118 uint32_t tilingSize = sizeof (LITilingData);
118119 auto blockDim = tilingData.usedCoreNum ;
119120 auto bs = tilingData.bSize ;
120- uint64_t mapKey = tilingData.tilingKey ;
121- mapKey = (mapKey << 32 ) | bs;
122- // std::cout << "mapKey is " << mapKey << std::endl;
123-
124- static auto globalTilingData = at::empty ({tilingSize * MAX_CAPTURE_NUM},
125- at::TensorOptions ().dtype (at::kByte ).device (query.options ().device ()));
126- if (captureMap.find (mapKey) == captureMap.end ()) {
127- // std::cout << "step in, mapKey is " << mapKey << std::endl;
128- TORCH_CHECK (actualCaptureNum < MAX_CAPTURE_NUM, " lightning_indexer captureNum overflow" )
129- captureMap[mapKey] = actualCaptureNum;
130- aclrtMemcpy (globalTilingData.data_ptr <uint8_t >() + actualCaptureNum * tilingSize, tilingSize, &tilingData,
121+ at::Tensor tilingTensor;
122+
123+ if (bs > MAX_DECODE_BS) {
124+ static auto preillTilingBuffer = at::empty ({tilingSize}, at::TensorOptions ().dtype (at::kByte ).device (query.options ().device ()));
125+ aclrtMemcpy (preillTilingBuffer.data_ptr <uint8_t >(), tilingSize, &tilingData,
131126 tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
132- actualCaptureNum++;
127+ tilingTensor = at::from_blob (preillTilingBuffer.data_ptr <uint8_t >(), tilingSize, at::kByte );
128+ } else {
129+ uint64_t mapKey = tilingData.tilingKey ;
130+ mapKey = (mapKey << 32 ) | bs;
131+
132+ static auto globalTilingBuffer = at::empty ({tilingSize * MAX_CAPTURE_NUM},
133+ at::TensorOptions ().dtype (at::kByte ).device (query.options ().device ()));
134+ if (captureMap.find (mapKey) == captureMap.end ()) {
135+ // std::cout << "step in, mapKey is " << mapKey << std::endl;
136+ TORCH_CHECK (actualCaptureNum < MAX_CAPTURE_NUM, " lightning_indexer captureNum overflow" )
137+ captureMap[mapKey] = actualCaptureNum;
138+ aclrtMemcpy (globalTilingBuffer.data_ptr <uint8_t >() + actualCaptureNum * tilingSize, tilingSize, &tilingData,
139+ tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
140+ actualCaptureNum++;
141+ }
142+ tilingTensor =
143+ at::from_blob (globalTilingBuffer.data_ptr <uint8_t >() + (tilingSize * captureMap[mapKey]), tilingSize, at::kByte );
133144 }
134- at::Tensor tilingTensor =
135- at::from_blob (globalTilingData.data_ptr <uint8_t >() + (tilingSize * captureMap[mapKey]), tilingSize, at::kByte );
136145
137146 size_t userWorkspaceSize = *context->GetWorkspaceSizes (1 );
138147 workspace = at::empty ({userWorkspaceSize}, at::TensorOptions ().dtype (at::kByte ).device (query.options ().device ()));
139- // std::cout << "6" << std::endl;
140148 EXEC_KERNEL_CMD (lightning_indexer, blockDim, query, key, weights, actual_seq_lengths_query, actual_seq_lengths_key,
141149 block_table, sparse_indices, workspace, tilingTensor);
142150 return sparse_indices;
0 commit comments