Skip to content

Commit 7c40ab1

Browse files
authored
Add alloc_extend_kernel (#196)
1 parent 3085bab commit 7c40ab1

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
from sgl_kernel_npu.utils.triton_utils import get_device_properties
5+
6+
7+
@triton.jit
def alloc_extend_kernel(
    pre_lens_ptr,
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
    max_num_extend_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr = 2048,
):
    """Write the output indices for the tokens added to one batch entry.

    Each program handles the batch entry ``pid = tl.program_id(0)``. The
    entry grows from ``pre_lens[pid]`` to ``seq_lens[pid]`` tokens, and the
    indices for those new tokens are written to a contiguous slice of
    ``out_indices`` in three parts:

    1. the remainder of the page in which ``pre_len`` ends, continuing from
       ``last_loc + 1``;
    2. whole new pages, whose page ids are read from ``free_page_ptr``;
    3. a trailing partially filled new page, also read from
       ``free_page_ptr``.

    The start of the slice in ``out_indices`` and the first page consumed
    from ``free_page_ptr`` are the sums of the extend lengths / new page
    counts of all entries with a smaller index, so programs write disjoint
    ranges.

    Args:
        pre_lens_ptr: Pointer to per-entry lengths before the extension.
        seq_lens_ptr: Pointer to per-entry lengths after the extension.
        last_loc_ptr: Pointer to a per-entry value; part 1 writes
            ``last_loc + 1 + k``. Presumably the index of the entry's last
            existing token -- confirm against the caller.
        free_page_ptr: Pointer to page ids, consumed in order.
        out_indices: Pointer to the output index buffer.
        bs_upper: Compile-time length of the vector used to load the length
            arrays. It must exceed every launched ``pid`` for the prefix
            sums to include the entry itself (``tl.arange`` also expects a
            power of two).
        page_size: Compile-time number of indices per page. It is also used
            as a ``tl.arange`` length.
        max_num_extend_tokens: Compile-time bound on the number of part-2
            indices a single entry may need; it determines the trip count
            of the part-2 loop.
        BLOCK_SIZE: Number of part-2 indices handled per loop iteration.
    """
    pid = tl.program_id(0)

    # Load the lengths of entries 0..pid. Masked-out lanes must be exactly
    # zero because the whole vector is reduced with tl.sum below; without
    # `other`, Triton leaves the contents of those lanes unspecified.
    load_offset = tl.arange(0, bs_upper)
    prefix_mask = load_offset <= pid
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=prefix_mask, other=0)
    pre_lens = tl.load(pre_lens_ptr + load_offset, mask=prefix_mask, other=0)
    extend_lens = seq_lens - pre_lens

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = tl.load(pre_lens_ptr + pid)
    extend_len = seq_len - pre_len

    # Sum over entries 0..pid minus this entry's own share gives the offset
    # at which this entry starts writing in out_indices.
    sum_extend_lens = tl.sum(extend_lens)
    output_start_loc = sum_extend_lens - extend_len

    # Same scheme for pages: number of new pages needed by entries 0..pid,
    # minus this entry's own, is the first free_page slot this entry uses.
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # pre_len rounded up to a whole number of pages; reused by parts 1 and 2.
    pre_len_pages = (pre_len + page_size - 1) // page_size
    pre_len_page_end = pre_len_pages * page_size

    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - pre_len_pages
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Part 1: fill the old partial page
    last_loc = tl.load(last_loc_ptr + pid)
    num_part1 = min(seq_len, pre_len_page_end) - pre_len
    offset_one_page = tl.arange(0, page_size)
    tl.store(
        out_indices + output_start_loc + offset_one_page,
        last_loc + 1 + offset_one_page,
        mask=offset_one_page < num_part1,
    )
    if pre_len + num_part1 == seq_len:
        return

    # Part 2: fill the new full pages
    num_part2 = seq_len // page_size * page_size - pre_len_page_end

    # Process the part-2 range in BLOCK_SIZE chunks; lanes at or beyond
    # num_part2 are masked for both the load and the store.
    num_loop = tl.cdiv(max_num_extend_tokens, BLOCK_SIZE)
    blk_offset = tl.arange(0, BLOCK_SIZE)
    for i in range(num_loop):
        offset_many_page = blk_offset + i * BLOCK_SIZE
        part2_mask = offset_many_page < num_part2
        page_start = tl.load(
            free_page_ptr + new_page_start_loc + offset_many_page // page_size,
            mask=part2_mask,
        )
        tl.store(
            out_indices + output_start_loc + num_part1 + offset_many_page,
            page_start * page_size + offset_many_page % page_size,
            mask=part2_mask,
        )

    if pre_len + num_part1 + num_part2 == seq_len:
        return

    # Part 3: fill the new partial page
    num_part3 = seq_len - seq_len // page_size * page_size
    # The partial page is the last of the new pages assigned to this entry.
    start_loc = tl.load(
        free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
        start_loc * page_size + offset_one_page,
        mask=offset_one_page < num_part3,
    )

0 commit comments

Comments
 (0)