
Commit d989dc3

Authored by bzgoogle and co-author

[GPT-OSS] uncomment sink-related changes as kernel_hd64.py was merged (#966)

Signed-off-by: bzgoogle <beinuoz_google_com@t1v-n-fa0da4f0-w-0.us-central1-c.c.cloud-tpu-inference-test.internal>
Co-authored-by: bzgoogle <beinuoz_google_com@t1v-n-fa0da4f0-w-0.us-central1-c.c.cloud-tpu-inference-test.internal>

1 parent: 03d76de

File tree: 1 file changed (+5 / -13 lines)

tpu_inference/layers/jax/attention/gpt_oss_attention.py (5 additions, 13 deletions)
@@ -9,8 +9,8 @@
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P

-from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
-    ragged_paged_attention
+from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
+    ragged_paged_attention_hd64
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding
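
Note on the sink changes re-enabled below: GPT-OSS attention uses one learnable sink logit per head, which competes in the softmax normalizer but contributes no value. The following is a hedged, single-head sketch of that mechanism with toy shapes; it is illustrative only and not the kernel's actual implementation (the attend_with_sink name and all shapes are made up here).

import jax
import jax.numpy as jnp

def attend_with_sink(q, k, v, sink, scale):
    # q: (T, H), k/v: (S, H), sink: scalar logit for this head (illustrative).
    logits = (q @ k.T) * scale
    sink_col = jnp.full((q.shape[0], 1), sink)
    # The sink competes in the softmax, but its probability mass is dropped,
    # so the output can effectively "attend to nothing".
    probs = jax.nn.softmax(jnp.concatenate([logits, sink_col], axis=-1), axis=-1)
    return probs[:, :-1] @ v

q = jnp.ones((4, 64))
k = jnp.ones((6, 64))
v = jnp.ones((6, 64))
out = attend_with_sink(q, k, v, sink=0.5, scale=64 ** -0.5)
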
@@ -155,13 +155,13 @@ def attention(
             P(), # page_indices_flat: Replicated
             P(), # query_start_loc: Replicated
             P(), # distribution: Replicated
-            #P(('model')), # sinks
+            P(('model')), # sinks
         )
         out_specs = (self.attn_o_tnh, kv_cache_spec)

         def _ragged_paged_attention_wrapper(*args):
             # Pass the GPT-OSS specific parameters to the kernel
-            return ragged_paged_attention(
+            return ragged_paged_attention_hd64(
                 *args,
                 sm_scale=self.sm_scale,
                 sliding_window=md.sliding_window,
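
The uncommented P(('model')) entry gives the sinks vector its own partition spec so it can travel through the sharded kernel call alongside the head-sharded inputs. Below is a minimal, self-contained sketch of the idea, assuming a shard_map-style invocation; the mesh size, head count, and per_shard function are hypothetical and not the repo's code.

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Toy 1-D mesh over all local devices, named 'model' like the spec above.
mesh = Mesh(np.array(jax.devices()), ("model",))

# One sink logit per attention head; head count chosen to divide evenly
# across the mesh (hypothetical numbers, not the model's real config).
num_heads = 8 * len(jax.devices())
sinks = jnp.zeros((num_heads,), dtype=jnp.float32)

@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=(P("model"),),   # sinks sharded over 'model', like the heads
    out_specs=P("model"),
)
def per_shard(sinks_local):
    # Each device sees only its slice of the sinks vector.
    return sinks_local + 1.0

_ = per_shard(sinks)
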
@@ -183,7 +183,7 @@ def _ragged_paged_attention_wrapper(*args):
             md.block_tables,
             md.query_start_loc,
             md.request_distribution,
-            #sinks,
+            sinks,
         )
         return kv_cache, output_TNH

@@ -213,14 +213,6 @@ def __call__(self,
         q_TNH, k_TKH = self.rope(q_TNH, k_TKH, md.input_positions)

         with jax.named_scope("attn_op"):
-            # Padding H dim of q,k,v to be the multiple of 128
-            multiple_of_128 = ((self.head_dim - 1) // 128 + 1) * 128
-            q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
-                                    (0, multiple_of_128 - self.head_dim)))
-            k_TKH = jnp.pad(k_TKH, ((0, 0), (0, 0),
-                                    (0, multiple_of_128 - self.head_dim)))
-            v_TKH = jnp.pad(v_TKH, ((0, 0), (0, 0),
-                                    (0, multiple_of_128 - self.head_dim)))
             new_kv_cache, attn_out_TNH = self.attention(
                 kv_cache, q_TNH, k_TKH, v_TKH, self.sinks_N.value, md,
                 self.mesh)
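
For reference, the deleted block rounded the head dimension up to the next multiple of 128 and zero-padded q/k/v so the generic kernel could consume them; kernel_hd64 takes the narrower GPT-OSS heads directly. A standalone sketch of what that padding did, with toy shapes; head_dim = 64 is an assumption suggested by the kernel_hd64 name.

import jax.numpy as jnp

head_dim = 64  # assumed GPT-OSS head size, as the kernel_hd64 name suggests

# Round up to the next multiple of 128, then zero-pad the trailing (head)
# axis -- this mirrors the deleted lines above.
multiple_of_128 = ((head_dim - 1) // 128 + 1) * 128          # -> 128
q_TNH = jnp.zeros((4, 8, head_dim))                          # toy (T, N, H)
q_padded = jnp.pad(q_TNH, ((0, 0), (0, 0),
                           (0, multiple_of_128 - head_dim)))
assert q_padded.shape == (4, 8, 128)
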
