|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | +from sgl_kernel_npu.utils.triton_utils import get_device_properties |
| 5 | + |
| 6 | + |
@triton.jit
def alloc_extend_kernel(
    pre_lens_ptr,
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
    max_num_extend_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr = 2048,
):
    # Write the token slot indices for one request's extension into
    # `out_indices`. Each program instance (`pid`) handles entry `pid` of
    # `pre_lens_ptr` / `seq_lens_ptr` / `last_loc_ptr`.
    #
    # Arguments:
    #   pre_lens_ptr:  per-request length before the extension.
    #   seq_lens_ptr:  per-request length after the extension.
    #   last_loc_ptr:  per-request index that precedes the first slot written
    #                  in Part 1 (Part 1 writes last_loc + 1, last_loc + 2, ...).
    #   free_page_ptr: page ids consumed in order for newly needed pages.
    #   out_indices:   flat output; request `pid` writes `seq_len - pre_len`
    #                  entries starting at the sum of the extend lengths of
    #                  requests 0..pid-1.
    #   bs_upper:      length of the vector used for the prefix loads.
    #                  NOTE(review): tl.arange needs a power-of-two extent, so
    #                  this is presumably next_power_of_2(batch size) - confirm
    #                  at the call site.
    #   page_size:     number of token slots per page.
    #   max_num_extend_tokens: upper bound used to size the Part 2 loop.
    #   BLOCK_SIZE:    number of slots processed per Part 2 loop iteration.
    pid = tl.program_id(0)

    # Load entries 0..pid only. `other=0` pins the masked-off lanes to zero so
    # they cannot leak unspecified values into the prefix sums below.
    load_offset = tl.arange(0, bs_upper)
    prefix_mask = load_offset <= pid
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=prefix_mask, other=0)
    pre_lens = tl.load(pre_lens_ptr + load_offset, mask=prefix_mask, other=0)
    extend_lens = seq_lens - pre_lens

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = tl.load(pre_lens_ptr + pid)
    extend_len = seq_len - pre_len

    # Inclusive prefix sum minus this request's own share = exclusive prefix,
    # i.e. where this request starts writing in `out_indices`.
    output_start_loc = tl.sum(extend_lens) - extend_len

    # Same exclusive-prefix trick for the number of newly needed pages, which
    # gives this request's first entry in `free_page_ptr`.
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    num_new_pages_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    new_page_start_loc = tl.sum(num_new_pages) - num_new_pages_self

    # Part 1: fill the old partial page
    # Tokens up to the next page boundary (or up to seq_len, whichever comes
    # first) continue right after `last_loc`.
    last_loc = tl.load(last_loc_ptr + pid)
    num_part1 = (
        min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
    )
    offset_one_page = tl.arange(0, page_size)
    tl.store(
        out_indices + output_start_loc + offset_one_page,
        last_loc + 1 + offset_one_page,
        mask=offset_one_page < num_part1,
    )
    if pre_len + num_part1 == seq_len:
        return

    # Part 2: fill the new full pages
    # Slots between the first page boundary at/after pre_len and the last page
    # boundary at/before seq_len.
    num_part2 = (
        seq_len // page_size * page_size
        - (pre_len + page_size - 1) // page_size * page_size
    )

    num_loop = tl.cdiv(max_num_extend_tokens, BLOCK_SIZE)
    blk_offset = tl.arange(0, BLOCK_SIZE)
    for i in range(num_loop):
        offset_many_page = blk_offset + i * BLOCK_SIZE
        part2_mask = offset_many_page < num_part2
        # Masked-off lanes of `page_start` are never stored (same mask below),
        # so no `other` value is needed here.
        page_start = tl.load(
            free_page_ptr + new_page_start_loc + offset_many_page // page_size,
            mask=part2_mask,
        )
        tl.store(
            out_indices + output_start_loc + num_part1 + offset_many_page,
            page_start * page_size + offset_many_page % page_size,
            mask=part2_mask,
        )

    if pre_len + num_part1 + num_part2 == seq_len:
        return

    # Part 3: fill the new partial page
    # The remaining tail lives at the start of this request's last new page.
    num_part3 = seq_len - seq_len // page_size * page_size
    start_loc = tl.load(free_page_ptr + new_page_start_loc + num_new_pages_self - 1)
    tl.store(
        out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
        start_loc * page_size + offset_one_page,
        mask=offset_one_page < num_part3,
    )
0 commit comments