@@ -35,7 +35,7 @@ def ref_ragged_paged_attention(
     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     distribution: jax.Array,  # i32[3]
-    *,
+    attention_sink: jax.Array | None = None,  # [actual_num_q_heads]
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
     soft_cap: float | None = None,
@@ -143,7 +143,17 @@ def ref_ragged_paged_attention(
     if soft_cap is not None:
       attn = soft_cap * jnp.tanh(attn / soft_cap)
     attn += jnp.where(mask, mask_value, 0.0)
-    attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
+
+    if attention_sink is not None:
+      reshaped_attention_sink = attention_sink.reshape(actual_num_q_heads, 1, 1)
+      reshaped_attention_sink = jnp.repeat(
+          reshaped_attention_sink, q_len, axis=1
+      )
+      attn = jnp.concat([reshaped_attention_sink, attn], axis=2)
+      attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
+      attn = attn[..., 1:]
+    else:
+      attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)

     out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
     if v_scale is not None:
@@ -232,6 +242,7 @@ def _ragged_paged_attention_kernel(
     sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
     bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
     bkv_update_ids_ref,  # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+    attention_sink_ref,  # [actual_num_q_heads]
     # Input
     q_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref,  # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
@@ -371,7 +382,21 @@ def load_with_init(ref, init_val):
       s = soft_cap * jnp.tanh(s / soft_cap)
     s += jnp.where(mask, mask_value, 0.0)
     s_rowmax = jnp.max(s, axis=1, keepdims=True)
-    m_prev = load_with_init(head_m_ref, -jnp.inf)
+
+    if attention_sink_ref is not None:
+      start_idx = kv_head_idx * num_q_heads_per_kv_head
+      m_prevs = []
+
+      for i in range(num_q_heads_per_kv_head):
+        m_prev_init = attention_sink_ref[start_idx + i]
+        m_prevs.append(
+            load_with_init(head_m_ref[i::num_q_heads_per_kv_head], m_prev_init)
+        )
+      m_prev = jnp.stack(m_prevs, axis=1).reshape(head_m_ref.shape)
+    else:
+      m_prev_init = -jnp.inf
+      m_prev = load_with_init(head_m_ref, m_prev_init)
+
     m_curr = jnp.maximum(m_prev, s_rowmax)
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
@@ -382,7 +407,7 @@ def load_with_init(ref, init_val):

     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
-    l_prev = load_with_init(head_l_ref, 0.0)
+    l_prev = load_with_init(head_l_ref, 1.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
     o_prev = load_with_init(head_acc_ref, 0.0)
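Seeding the running row max with the per-head sink logit and the running denominator with 1.0 is the online-softmax version of the same trick: at initialization the sink's contribution exp(sink - m) is exactly 1, and every later rescale by exp(m_prev - m_curr) keeps that term consistent. The 1.0 initialization is also safe when no sink is passed, because m_prev then starts at -inf and exp(m_prev - m_curr) zeroes the stale term on the first block. A minimal scalar sketch of the recurrence (standalone and illustrative; the block structure and names are assumptions, not the kernel's refs):

```python
import jax.numpy as jnp


def streaming_denominator_with_sink(logit_blocks, sink):
  # Flash-style running (m, l) update with an attention sink folded in:
  # seed m with the sink logit and l with exp(sink - m) == 1.0, then fold in
  # one block of key logits at a time, mirroring the kernel's update above.
  m = jnp.float32(sink)  # running max over everything seen so far
  l = jnp.float32(1.0)   # running sum of exp(x - m), starting with the sink
  for s in logit_blocks:
    m_curr = jnp.maximum(m, jnp.max(s))
    l = jnp.exp(m - m_curr) * l + jnp.sum(jnp.exp(s - m_curr))
    m = m_curr
  return m, l


blocks = [jnp.array([0.1, -0.3]), jnp.array([1.2, 0.4, -2.0])]
sink = 0.7
m, l = streaming_denominator_with_sink(blocks, sink)
# Matches the dense denominator that includes the sink's exp(sink) term.
dense = jnp.concatenate([jnp.array([sink]), *blocks])
assert jnp.allclose(l * jnp.exp(m), jnp.sum(jnp.exp(dense)), rtol=1e-5)
```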
@@ -1255,6 +1280,7 @@ def ragged_paged_attention(
     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     distribution: jax.Array,  # i32[3]
+    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
@@ -1286,6 +1312,7 @@ def ragged_paged_attention(
12861312 distribution: (i, j, k) represents that sequences[0:i] are decode-only,
12871313 sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
12881314 k is also the total number of sequences.
1315+ attention_sink: optional attention sink for each kv head.
12891316 actual_head_dim: the actual head size of the attention. Here we assume k and
12901317 v have the same actual head size.
12911318 sm_scale: the softmax scale which will be applied to the Q@K^T.
@@ -1426,8 +1453,13 @@ def ragged_paged_attention(
       jnp.full((4,), -1, jnp.int32),
       # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
       jnp.full((6,), -1, jnp.int32),
+      attention_sink,
   )

+  n_scalar_prefetches = len(scalar_prefetches)
+  if attention_sink is None:
+    n_scalar_prefetches -= 1
+
   scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
   kernel = jax.named_scope(scope_name)(
       pl.pallas_call(
@@ -1464,8 +1496,8 @@ def ragged_paged_attention(
               dtype=kv_cache.dtype),
          ],
          input_output_aliases={
-             7: 0,
-             9: 1
+             n_scalar_prefetches: 0,
+             n_scalar_prefetches + 2: 1,
          },
          name=scope_name,
   ))
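The alias keys have to move because `input_output_aliases` is keyed by flat operand position and the scalar-prefetch operands come first; appending `attention_sink` to `scalar_prefetches` shifts every tensor input up by one. A tiny self-contained check of that bookkeeping (the helper is illustrative, not part of this file):

```python
def alias_indices(n_scalar_prefetches):
  # The two aliased tensor inputs sit at fixed offsets right after the
  # scalar-prefetch operands, so their keys shift with the prefetch count.
  return {n_scalar_prefetches: 0, n_scalar_prefetches + 2: 1}


assert alias_indices(7) == {7: 0, 9: 1}    # previous hard-coded values (no sink)
assert alias_indices(8) == {8: 0, 10: 1}   # with attention_sink prefetched
```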