
Commit 1a1dfb5 (parent: 6242fc3)

[RPA] Implement attention sink

Signed-off-by: Kyuyeun Kim <[email protected]>

File tree: 1 file changed, +38 and -6 lines

  • tpu_inference/kernels/ragged_paged_attention/v3/kernel.py

tpu_inference/kernels/ragged_paged_attention/v3/kernel.py

Lines changed: 38 additions & 6 deletions
@@ -35,7 +35,7 @@ def ref_ragged_paged_attention(
     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     distribution: jax.Array,  # i32[3]
-    *,
+    attention_sink: jax.Array | None = None,  # [actual_num_q_heads]
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
     soft_cap: float | None = None,
@@ -143,7 +143,17 @@ def ref_ragged_paged_attention(
     if soft_cap is not None:
         attn = soft_cap * jnp.tanh(attn / soft_cap)
     attn += jnp.where(mask, mask_value, 0.0)
-    attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
+
+    if attention_sink is not None:
+        reshaped_attention_sink = attention_sink.reshape(actual_num_q_heads, 1, 1)
+        reshaped_attention_sink = jnp.repeat(
+            reshaped_attention_sink, q_len, axis=1
+        )
+        attn = jnp.concat([reshaped_attention_sink, attn], axis=2)
+        attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
+        attn = attn[..., 1:]
+    else:
+        attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
 
     out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
     if v_scale is not None:
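Note: in the reference path above, prepending the per-head sink logit as an extra score column, taking the softmax, and then dropping that column is equivalent to adding exp(sink) to the softmax denominator. A minimal sketch with toy values (the shapes and numbers below are illustrative, not the kernel's):

import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 0.5]])  # toy [q, k] scores for one head
sink = jnp.array([0.3])                # one sink logit for that head

# Path 1: prepend the sink as a virtual key, softmax, then drop that column.
p1 = jax.nn.softmax(
    jnp.concatenate([sink[:, None], logits], axis=-1), axis=-1)[..., 1:]

# Path 2: ordinary softmax whose denominator also counts exp(sink).
m = logits.max(axis=-1, keepdims=True)
e = jnp.exp(logits - m)
p2 = e / (e.sum(axis=-1, keepdims=True) + jnp.exp(sink[:, None] - m))

print(jnp.allclose(p1, p2))  # True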
@@ -232,6 +242,7 @@ def _ragged_paged_attention_kernel(
     sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
     bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
     bkv_update_ids_ref,  # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+    attention_sink_ref,  # [actual_num_kv_heads]
     # Input
     q_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref,  # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
@@ -371,7 +382,21 @@ def load_with_init(ref, init_val):
         s = soft_cap * jnp.tanh(s / soft_cap)
     s += jnp.where(mask, mask_value, 0.0)
     s_rowmax = jnp.max(s, axis=1, keepdims=True)
-    m_prev = load_with_init(head_m_ref, -jnp.inf)
+
+    if attention_sink_ref is not None:
+        start_idx = kv_head_idx * num_q_heads_per_kv_head
+        m_prevs = []
+
+        for i in range(num_q_heads_per_kv_head):
+            m_prev_init = attention_sink_ref[start_idx + i]
+            m_prevs.append(
+                load_with_init(head_m_ref[i::num_q_heads_per_kv_head], m_prev_init)
+            )
+        m_prev = jnp.stack(m_prevs, axis=1).reshape(head_m_ref.shape)
+    else:
+        m_prev_init = -jnp.inf
+        m_prev = load_with_init(head_m_ref, m_prev_init)
+
     m_curr = jnp.maximum(m_prev, s_rowmax)
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
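Note: the kernel reads the sink per q head. start_idx = kv_head_idx * num_q_heads_per_kv_head selects this kv head's block of q heads, and head_m_ref[i::num_q_heads_per_kv_head] touches only the rows that belong to q head i. A minimal layout sketch; the interleaved row layout and shapes here are assumptions for illustration, not the kernel's actual buffers:

import jax.numpy as jnp

num_q_heads_per_kv_head = 2
rows_per_q_head = 3
sink = jnp.array([0.25, -1.5])  # one sink logit per q head of this kv head

# Assume the running-max buffer interleaves q heads along rows: row j -> q head j % 2.
head_m = jnp.full((num_q_heads_per_kv_head * rows_per_q_head, 1), -jnp.inf)

m_prevs = []
for i in range(num_q_heads_per_kv_head):
    # Every num_q_heads_per_kv_head-th row, starting at i, gets q head i's sink logit.
    m_prevs.append(jnp.full_like(head_m[i::num_q_heads_per_kv_head], sink[i]))
m_prev = jnp.stack(m_prevs, axis=1).reshape(head_m.shape)

print(m_prev.ravel())  # [ 0.25 -1.5  0.25 -1.5  0.25 -1.5 ]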
@@ -382,7 +407,7 @@ def load_with_init(ref, init_val):
 
     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
-    l_prev = load_with_init(head_l_ref, 0.0)
+    l_prev = load_with_init(head_l_ref, 1.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
     o_prev = load_with_init(head_acc_ref, 0.0)
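Note: the l initializer changing from 0.0 to 1.0 is what folds the sink into the running softmax denominator. With m_prev initialized to the sink logit, the exp(m_prev - m_curr) * l_prev term contributes exp(sink - m); without a sink, m_prev starts at -inf, so the extra 1.0 is multiplied by exp(-inf - m_curr) = 0 and has no effect. A minimal sketch of that recurrence with toy numbers (not the kernel's blocking, shapes, or helpers):

import jax.numpy as jnp

def streaming_denominator(scores, m0, l0):
    # Flash-attention style running (max, sum) update over KV blocks.
    m, l = m0, l0
    for block in jnp.split(scores, 2):  # pretend KV arrives in two blocks
        m_new = jnp.maximum(m, block.max())
        l = jnp.exp(m - m_new) * l + jnp.exp(block - m_new).sum()
        m = m_new
    return l * jnp.exp(m)  # undo the max shift so we can compare directly

scores = jnp.array([0.3, 1.7, -0.4, 2.2])
sink = 0.5

with_sink = streaming_denominator(scores, m0=jnp.float32(sink), l0=1.0)
print(jnp.allclose(with_sink, jnp.exp(scores).sum() + jnp.exp(sink)))  # True

no_sink = streaming_denominator(scores, m0=-jnp.inf, l0=1.0)
print(jnp.allclose(no_sink, jnp.exp(scores).sum()))  # True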
@@ -1255,6 +1280,7 @@ def ragged_paged_attention(
     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     distribution: jax.Array,  # i32[3]
+    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
@@ -1286,6 +1312,7 @@ def ragged_paged_attention(
     distribution: (i, j, k) represents that sequences[0:i] are decode-only,
       sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
       k is also the total number of sequences.
+    attention_sink: optional attention sink for each kv head.
     actual_head_dim: the actual head size of the attention. Here we assume k and
       v have the same actual head size.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
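Note: a caller now supplies one f32 logit per q head, or None to keep the previous behavior. A hypothetical usage sketch; the head count and the zero values are placeholders, and the other arguments are abbreviated since their construction is unchanged by this commit:

import jax.numpy as jnp

actual_num_q_heads = 8  # assumed for illustration
attention_sink = jnp.zeros((actual_num_q_heads,), dtype=jnp.float32)  # e.g. a learned parameter

# ragged_paged_attention(
#     queries, kv_cache, kv_lens, page_indices, cu_q_lens, distribution,
#     attention_sink,
#     sm_scale=sm_scale,
# )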
@@ -1426,8 +1453,13 @@ def ragged_paged_attention(
         jnp.full((4, ), -1, jnp.int32),
         # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
         jnp.full((6, ), -1, jnp.int32),
+        attention_sink,
     )
 
+    n_scalar_prefetches = len(scalar_prefetches)
+    if attention_sink is None:
+        n_scalar_prefetches -= 1
+
     scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
     kernel = jax.named_scope(scope_name)(
         pl.pallas_call(
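Note: len(scalar_prefetches) counts the attention_sink entry even when it is None, while JAX drops None when flattening pytrees, so the flattened operand list is presumably one entry shorter in that case; that is why one is subtracted. A small sketch of just that counting behavior with toy arrays (not the kernel's actual prefetch operands):

import jax
import jax.numpy as jnp

with_sink = (jnp.zeros(3, jnp.int32), jnp.zeros(6, jnp.int32), jnp.zeros(8, jnp.float32))
without_sink = (jnp.zeros(3, jnp.int32), jnp.zeros(6, jnp.int32), None)

print(len(with_sink), len(jax.tree_util.tree_leaves(with_sink)))        # 3 3
print(len(without_sink), len(jax.tree_util.tree_leaves(without_sink)))  # 3 2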
@@ -1464,8 +1496,8 @@ def ragged_paged_attention(
                 dtype=kv_cache.dtype),
         ],
         input_output_aliases={
-            7: 0,
-            9: 1
+            n_scalar_prefetches: 0,
+            n_scalar_prefetches + 2: 1,
         },
         name=scope_name,
     ))
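Note: the aliases were previously hard-coded to operand indices 7 and 9; deriving them from n_scalar_prefetches keeps them pointing at the same inputs whether or not the sink prefetch is present. An illustrative sketch of the resulting mapping, where the count of 8 is an assumption (the old hard-coded 7 plus the new attention_sink entry):

n_scalar_prefetches = 8  # assumed: previous 7 scalar prefetches plus attention_sink
input_output_aliases = {
    n_scalar_prefetches: 0,      # aliased input -> output 0
    n_scalar_prefetches + 2: 1,  # input two slots later -> output 1
}
print(input_output_aliases)  # {8: 0, 10: 1}
# With attention_sink=None the count drops back to 7, giving {7: 0, 9: 1},
# which matches the previously hard-coded values.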
