99from jax .sharding import Mesh
1010from jax .sharding import PartitionSpec as P
1111
12- from tpu_inference .kernels .ragged_paged_attention .v3 .kernel import \
13- ragged_paged_attention
12+ from tpu_inference .kernels .ragged_paged_attention .v3 .kernel_hd64 import \
13+ ragged_paged_attention_hd64
1414from tpu_inference .layers .common .attention_metadata import AttentionMetadata
1515from tpu_inference .layers .jax .base import create_param
1616from tpu_inference .layers .jax .rope import GptOssRotaryEmbedding
@@ -155,13 +155,13 @@ def attention(
155155 P (), # page_indices_flat: Replicated
156156 P (), # query_start_loc: Replicated
157157 P (), # distribution: Replicated
158- # P(('model')), # sinks
158+ P (('model' )), # sinks
159159 )
160160 out_specs = (self .attn_o_tnh , kv_cache_spec )
161161
162162 def _ragged_paged_attention_wrapper (* args ):
163163 # Pass the GPT-OSS specific parameters to the kernel
164- return ragged_paged_attention (
164+ return ragged_paged_attention_hd64 (
165165 * args ,
166166 sm_scale = self .sm_scale ,
167167 sliding_window = md .sliding_window ,
@@ -183,7 +183,7 @@ def _ragged_paged_attention_wrapper(*args):
183183 md .block_tables ,
184184 md .query_start_loc ,
185185 md .request_distribution ,
186- # sinks,
186+ sinks ,
187187 )
188188 return kv_cache , output_TNH
189189
@@ -213,14 +213,6 @@ def __call__(self,
213213 q_TNH , k_TKH = self .rope (q_TNH , k_TKH , md .input_positions )
214214
215215 with jax .named_scope ("attn_op" ):
216- # Padding H dim of q,k,v to be the multiple of 128
217- multiple_of_128 = ((self .head_dim - 1 ) // 128 + 1 ) * 128
218- q_TNH = jnp .pad (q_TNH , ((0 , 0 ), (0 , 0 ),
219- (0 , multiple_of_128 - self .head_dim )))
220- k_TKH = jnp .pad (k_TKH , ((0 , 0 ), (0 , 0 ),
221- (0 , multiple_of_128 - self .head_dim )))
222- v_TKH = jnp .pad (v_TKH , ((0 , 0 ), (0 , 0 ),
223- (0 , multiple_of_128 - self .head_dim )))
224216 new_kv_cache , attn_out_TNH = self .attention (
225217 kv_cache , q_TNH , k_TKH , v_TKH , self .sinks_N .value , md ,
226218 self .mesh )
0 commit comments