From d873614b8f01add5ab1e2f1e485e84ac9632dae6 Mon Sep 17 00:00:00 2001 From: gongshaotian Date: Wed, 22 Oct 2025 21:05:45 +0800 Subject: [PATCH 1/2] Support CUDAGraph Padding + MTP --- .../gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index 0d5cd88180f..6a7f5c4db32 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -627,6 +627,7 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; From 0661ebe60016e1f05b38da187c23147986dd2cfe Mon Sep 17 00:00:00 2001 From: gongshaotian Date: Thu, 23 Oct 2025 11:20:23 +0800 Subject: [PATCH 2/2] support orther write cache kernel --- .../speculate_write_cache_with_rope_impl.cuh | 163 ++++++++++-------- 1 file changed, 95 insertions(+), 68 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index 6a7f5c4db32..56026157a17 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -27,7 +27,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, // head_size // 2] T* __restrict__ q_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] @@ -68,10 +68,13 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; const int half_head_size = head_size / 2; - for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) { + for (int global_hi = global_warp_idx; global_hi < all_head_dim; + global_hi += all_warp_num) { int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize; const int token_id = linear_index / hidden_size; + const int ori_bi = batch_id_per_token[token_id]; + if (ori_bi == -1) return; // NOTE(gongshaotian): For CUDAGraph padding if (seq_lens_decoder[ori_bi] == 0) continue; const int bias = linear_index % hidden_size; const int hi = bias / head_size; // q + k + v @@ -84,7 +87,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - return ; // NOTE(gongshaotian): For CUDAGraph padding + return; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -102,7 +105,8 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( if (hi < num_heads + gqa_group_size) { // q k rope const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2; - uint32_t new_emb_idx = rope_3d ? 
emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + uint32_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); } @@ -136,20 +140,23 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( } if (hi < (num_heads + gqa_group_size)) { WelfordWarpAllReduce(thread_m2, &warp_m2); - float row_variance = - max(warp_m2 / head_size, 0.0f); + float row_variance = max(warp_m2 / head_size, 0.0f); float row_inv_var = Rsqrt(row_variance + rms_norm_eps); if (hi < num_heads) { - Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec); - #pragma unroll + Load(&q_norm_weight[threadIdx.x * VecSize], + &q_norm_vec); +#pragma unroll for (int i = 0; i < VecSize; i++) { - bias_vec[i] = static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); + bias_vec[i] = + static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); } } else { - Load(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec); - #pragma unroll + Load(&k_norm_weight[threadIdx.x * VecSize], + &k_norm_vec); +#pragma unroll for (int i = 0; i < VecSize; i++) { - bias_vec[i] = static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); + bias_vec[i] = + static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); } } } @@ -179,7 +186,7 @@ __global__ void append_clear_cache_int8_block( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] const int* __restrict__ seq_lens, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_encoder, // [bsz] @@ -197,6 +204,7 @@ __global__ void append_clear_cache_int8_block( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -245,7 +253,6 @@ __global__ void append_clear_cache_int8_block( } } - template __global__ void append_clear_cache_int4_block( uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, @@ -253,7 +260,7 @@ __global__ void append_clear_cache_int4_block( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] const int* __restrict__ seq_lens, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_encoder, // [bsz] @@ -271,6 +278,7 @@ __global__ void append_clear_cache_int4_block( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -331,7 +339,7 @@ __global__ void append_speculate_cache_rope_kernel( T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, // head_size // 2] T* __restrict__ q_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] @@ 
-370,6 +378,8 @@ __global__ void append_speculate_cache_rope_kernel( linear_index += step) { const int token_id = linear_index / hidden_size; const int ori_bi = batch_id_per_token[token_id]; + if (ori_bi == -1) return; // NOTE(gongshaotian): For CUDAGraph padding + if (seq_lens_decoder[ori_bi] == 0) continue; const int bias = linear_index % hidden_size; const int hi = bias / head_size; // q + k + v @@ -382,7 +392,7 @@ __global__ void append_speculate_cache_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - return ; // NOTE(gongshaotian): For CUDAGraph padding + return; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -400,7 +410,8 @@ __global__ void append_speculate_cache_rope_kernel( if (hi < num_heads + gqa_group_size) { // q k rope const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2; - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); } @@ -458,7 +469,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, // head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] @@ -497,6 +508,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( linear_index += step) { const int token_id = linear_index / half_hidden_size; const int ori_bi = batch_id_per_token[token_id]; + if (ori_bi == -1) return; // NOTE(gongshaotian): For CUDAGraph padding if (seq_lens_decoder[ori_bi] == 0) continue; const int bias = linear_index % half_hidden_size; const int hi = bias / half_head_size; // q + k + v @@ -509,7 +521,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - return ; // NOTE(gongshaotian): For CUDAGraph padding + return; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -531,7 +543,8 @@ __global__ void append_speculate_cache_neox_rope_kernel( if (hi < num_heads + gqa_group_size) { // q k rope const int64_t emb_idx = write_seq_id * head_size + h_bias; - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx; + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); } @@ -591,14 +604,14 @@ template __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( - const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 * + const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 * // gqa_group_size, head_size] uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -627,7 +640,7 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; - if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -645,10 +658,12 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( if (head_idx < num_heads) { cache_offset = 0; } else if (head_idx < num_heads + 2 * gqa_group_size) { - cache_offset = block_idx * gqa_group_size * block_size + (head_idx - num_heads) % gqa_group_size * block_size + block_offset; + cache_offset = block_idx * gqa_group_size * block_size + + (head_idx - num_heads) % gqa_group_size * block_size + + block_offset; } - T *cache_k_scale_now = cache_k_scale + cache_offset; - T *cache_v_scale_now = cache_v_scale + cache_offset; + T* cache_k_scale_now = cache_k_scale + cache_offset; + T* cache_v_scale_now = cache_v_scale + cache_offset; float thread_m2 = 0.0f; float warp_m2 = 0.0f; @@ -676,7 +691,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( // q rope const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll @@ -689,22 +705,20 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( float tmp1 = input_left * cos_tmp - input_right * sin_tmp; float tmp2 = input_right * cos_tmp + input_left * sin_tmp; thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; - bias_vec[2 * i] = - static_cast(tmp1); - bias_vec[2 * i + 1] = - static_cast(tmp2); + bias_vec[2 * i] = static_cast(tmp1); + bias_vec[2 * i + 1] = static_cast(tmp2); } // qk norm if (q_norm_weight) { WelfordWarpAllReduce(thread_m2, &warp_m2); - float row_variance = - max(warp_m2 / HeadDim, 0.0f); + float row_variance = max(warp_m2 / HeadDim, 0.0f); float row_inv_var = Rsqrt(row_variance + rms_norm_eps); LoadOutScaleT q_norm_vec; Load(&q_norm_weight[lane_id * VecSize], &q_norm_vec); - #pragma unroll +#pragma unroll for (int i = 0; i < VecSize; i++) { - bias_vec[i] = static_cast(static_cast(bias_vec[i]) * row_inv_var * q_norm_vec[i]); + bias_vec[i] = static_cast(static_cast(bias_vec[i]) * + row_inv_var * q_norm_vec[i]); } } Store(bias_vec, &qkv_out_now[bias_idx]); @@ -740,7 +754,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( const int v_head_idx = head_idx - num_heads - gqa_group_size; if (head_idx < num_heads + gqa_group_size) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec1); Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); Load(&sin_emb[new_emb_idx], &sin_emb_vec1); @@ -755,10 +770,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( float tmp1 = input_left * cos_tmp - input_right * sin_tmp; float tmp2 = input_right * cos_tmp + input_left * sin_tmp; thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; - bias_vec1[0] = - static_cast(tmp1); - bias_vec1[1] = - static_cast(tmp2); + bias_vec1[0] = static_cast(tmp1); + bias_vec1[1] = static_cast(tmp2); } else { bias_vec1[0] = static_cast(input_left); bias_vec1[1] = static_cast(input_right); @@ -772,10 +785,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( float tmp1 = input_left * cos_tmp - input_right * sin_tmp; float tmp2 = input_right * cos_tmp + input_left * sin_tmp; thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; - bias_vec2[0] = - static_cast(tmp1); - bias_vec2[1] = - static_cast(tmp2); + bias_vec2[0] = static_cast(tmp1); + bias_vec2[1] = static_cast(tmp2); } else { bias_vec2[0] = static_cast(input_left); bias_vec2[1] = static_cast(input_right); @@ -784,16 +795,18 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( if (head_idx < num_heads + gqa_group_size) { LoadOutScaleT k_norm_vec1, k_norm_vec2; Load(&k_norm_weight[head_bias], &k_norm_vec1); - Load(&k_norm_weight[head_bias + 8], &k_norm_vec2); + Load(&k_norm_weight[head_bias + 8], + &k_norm_vec2); // qk norm WelfordWarpAllReduce(thread_m2, &warp_m2); - float row_variance = - max(warp_m2 / HeadDim, 0.0f); + float row_variance = max(warp_m2 / HeadDim, 0.0f); float row_inv_var = Rsqrt(row_variance + rms_norm_eps); for (int i = 0; i < HALF_K_VEC_SIZE; i++) { - bias_vec1[i] = static_cast(static_cast(bias_vec1[i]) * row_inv_var * k_norm_vec1[i]); - bias_vec2[i] = static_cast(static_cast(bias_vec2[i]) * row_inv_var * k_norm_vec2[i]); + bias_vec1[i] = 
static_cast(static_cast(bias_vec1[i]) * + row_inv_var * k_norm_vec1[i]); + bias_vec2[i] = static_cast(static_cast(bias_vec2[i]) * + row_inv_var * k_norm_vec2[i]); } } } @@ -806,7 +819,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( } #pragma unroll for (int m_offset = 16; m_offset > 0; m_offset /= 2) { - local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset)); + local_max = + __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset)); } scale = __hdiv(448, local_max); @@ -821,8 +835,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( #pragma unroll for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { - cache_vec[i] = QuantToC8(scale, bias_vec1[i], max_bound, min_bound); - cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8(scale, bias_vec2[i], max_bound, min_bound); + cache_vec[i] = QuantToC8( + scale, bias_vec1[i], max_bound, min_bound); + cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8( + scale, bias_vec2[i], max_bound, min_bound); } if (head_idx < num_heads + gqa_group_size) { const int start_block_16 = @@ -867,7 +883,7 @@ __global__ void append_speculate_cache_int8_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -897,6 +913,7 @@ __global__ void append_speculate_cache_int8_rope_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -933,7 +950,8 @@ __global__ void append_speculate_cache_int8_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); if (qkv_out_scales) { @@ -995,7 +1013,8 @@ __global__ void append_speculate_cache_int8_rope_kernel( T scale; if (head_idx < num_heads + gqa_group_size) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec1); Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); Load(&sin_emb[new_emb_idx], &sin_emb_vec1); @@ -1057,8 +1076,10 @@ __global__ void append_speculate_cache_int8_rope_kernel( } #pragma unroll for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { - cache_vec[i] = QuantToC8(scale, bias_vec1[i], max_bound, min_bound); - cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8(scale, bias_vec2[i], max_bound, min_bound); + cache_vec[i] = QuantToC8( + scale, bias_vec1[i], max_bound, min_bound); + cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8( + scale, bias_vec2[i], max_bound, min_bound); } if (head_idx < num_heads + gqa_group_size) { const int start_block_16 = @@ -1102,7 +1123,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -1132,6 +1153,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -1171,7 +1193,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); if (qkv_out_scales) { @@ -1268,7 +1291,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( T scale; const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec1); Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); Load(&sin_emb[new_emb_idx], &sin_emb_vec1); @@ -1483,7 +1507,7 @@ __global__ void append_speculate_cache_int4_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -1516,6 +1540,7 @@ __global__ void append_speculate_cache_int4_rope_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -1562,7 +1587,8 @@ __global__ void append_speculate_cache_int4_rope_kernel( // Load(&qkv_out_scales[bias_idx], &out_scale_vec); // q rope const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - uint32_t new_emb_idx = rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll @@ -1653,7 +1679,8 @@ __global__ void append_speculate_cache_int4_rope_kernel( // &out_scale_vec2); if (head_idx < num_heads + gqa_group_size) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec1); Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); Load(&sin_emb[new_emb_idx], &sin_emb_vec1); @@ -1760,7 +1787,6 @@ __global__ void append_speculate_cache_int4_rope_kernel( } Store(cache_vec, &key_cache[tgt_cache_idx]); } else { - const uint32_t base_tgt_cache_idx = block_idx * gqa_group_size * HeadDim * half_block_size + kv_head_idx * HeadDim * half_block_size + @@ -1829,7 +1855,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -1862,6 +1888,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -2001,7 +2028,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( &right_out_scale_vec2); const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec1); Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); Load(&sin_emb[new_emb_idx], &sin_emb_vec1); @@ -2039,7 +2067,6 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( right_bias_vec1[i] = static_cast(input_right * cos_tmp + input_left * sin_tmp); - input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); cos_tmp = cos_emb_vec2[i];
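
Reviewer note (not part of the patch): every kernel touched above gains the same early-exit guard -- a token slot created only to pad the batch for CUDAGraph capture carries batch_id_per_token[token_id] == -1, and the kernel must return before it indexes cu_seqlens_q or block_tables with that value. The following is a minimal standalone sketch of that pattern; the kernel and variable names here are illustrative only and do not appear in the patch.

    // guard_demo.cu -- illustrative sketch of the CUDAGraph-padding guard.
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void write_cache_guard_demo(
        const int* __restrict__ batch_id_per_token,  // [num_tokens], -1 for padded slots
        const int* __restrict__ cu_seqlens_q,        // [bsz + 1]
        int* __restrict__ out,                       // [num_tokens]
        int num_tokens) {
      const int token_id = blockIdx.x * blockDim.x + threadIdx.x;
      if (token_id >= num_tokens) return;

      const int bid = batch_id_per_token[token_id];
      if (bid == -1) return;  // CUDAGraph padding slot: nothing to write

      // Only real tokens reach the per-batch indexing below.
      out[token_id] = cu_seqlens_q[bid];
    }

    int main() {
      // Two real tokens for batch 0, plus one slot appended for graph capture.
      const int num_tokens = 3;
      int h_bid[num_tokens] = {0, 0, -1};
      int h_cu[2] = {0, 2};
      int h_out[num_tokens] = {-7, -7, -7};

      int *d_bid, *d_cu, *d_out;
      cudaMalloc(&d_bid, sizeof(h_bid));
      cudaMalloc(&d_cu, sizeof(h_cu));
      cudaMalloc(&d_out, sizeof(h_out));
      cudaMemcpy(d_bid, h_bid, sizeof(h_bid), cudaMemcpyHostToDevice);
      cudaMemcpy(d_cu, h_cu, sizeof(h_cu), cudaMemcpyHostToDevice);
      cudaMemcpy(d_out, h_out, sizeof(h_out), cudaMemcpyHostToDevice);

      write_cache_guard_demo<<<1, 32>>>(d_bid, d_cu, d_out, num_tokens);
      cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

      // Expected: "0 0 -7" -- the padded slot is left untouched.
      printf("%d %d %d\n", h_out[0], h_out[1], h_out[2]);

      cudaFree(d_bid);
      cudaFree(d_cu);
      cudaFree(d_out);
      return 0;
    }

The same reasoning explains why the grid-stride kernels (e.g. append_speculate_cache_T_rope_qk_norm_kernel) use return rather than continue on the -1 check: padded tokens are appended after all real tokens, so once a thread's linear index lands in the padded region every later index it would visit is padded as well.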