@@ -125,9 +125,9 @@ def _apply_mask(
125125 masks = []
126126 if mask_ref is not None :
127127 if k_in_lanes :
128- mask = pl . load ( mask_ref , ( slice ( None ), k_slice ))
128+ mask = mask_ref [:, k_slice ]
129129 else :
130- mask = pl . load ( mask_ref , ( k_slice , slice ( None )))
130+ mask = mask_ref [ k_slice , :]
131131
132132 snm = jnp .where (should_not_mask , 1 , 0 )
133133 masks .append (jnp .bitwise_or (mask , jnp .broadcast_to (snm , mask .shape )) != 0 )
@@ -156,7 +156,7 @@ def _apply_mask(
156156 k_sequence = k_offset + jax .lax .broadcasted_iota (
157157 jnp .int32 , (k_slice .size , bq ), 0
158158 )
159- q_sequence = pl . load ( q_sequence_ref , ( pl . ds ( 1 ), slice ( None ))) # [1, bq]
159+ q_sequence = q_sequence_ref [: 1 , :] # [1, bq]
160160 q_sequence = jnp .broadcast_to (q_sequence , (k_slice .size , bq ))
161161
162162 assert q_sequence .shape == k_sequence .shape
@@ -170,7 +170,7 @@ def _apply_mask(
170170
171171 if q_segment_ids_ref is not None :
172172 if k_in_lanes :
173- kv_ids = pl . load ( kv_segment_ids_ref , ( pl . ds ( 1 ) , k_slice )) # [1, k_slice]
173+ kv_ids = kv_segment_ids_ref [: 1 , k_slice ] # [1, k_slice]
174174 repeats , rem = divmod (kv_ids .shape [1 ], NUM_LANES )
175175 if rem :
176176 raise NotImplementedError (f"block_kv must be a multiple of { NUM_LANES } " )
@@ -181,9 +181,9 @@ def _apply_mask(
181181 if rem :
182182 raise NotImplementedError (f"block_q must be a multiple of { NUM_LANES } " )
183183 kv_ids = pltpu .repeat (
184- pl . load ( kv_segment_ids_ref , ( k_slice , slice ( None ))) , repeats , axis = 1
184+ kv_segment_ids_ref [ k_slice , :] , repeats , axis = 1
185185 ) # [k_slice, bq]
186- q_ids = pl . load ( q_segment_ids_ref , ( pl . ds ( 1 ), slice ( None ))) # [1, bq]
186+ q_ids = q_segment_ids_ref [: 1 , :] # [1, bq]
187187 masks .append (q_ids == kv_ids )
188188
189189 if masks :
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
228228 slice_k = pl .ds (kv_compute_index * bkv_compute , bkv_compute )
229229
230230 q = q_ref [...]
231- k = pl . load ( k_ref , ( slice_k , slice ( None )))
231+ k = k_ref [ slice_k , :]
232232 qk = jax .lax .dot_general (
233233 q , k , NT_DIM_NUMBERS , preferred_element_type = jnp .float32
234234 )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
256256 )
257257
258258 sv_dims = NN_DIM_NUMBERS
259- v = pl . load ( v_ref , ( slice_k , slice ( None )))
259+ v = v_ref [ slice_k , :]
260260
261261 to_float32 = lambda x : x .astype (jnp .float32 )
262262 v = to_float32 (v )
0 commit comments