@@ -217,30 +217,30 @@ def prepare_inputs(
217217 # Update seq_lens for active requests only: new_seq_lens = s - n.
218218 new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
219219
220- return self ._prepare_inputs_in_jit (token_indices_cpu ,
221- new_query_start_loc_cpu ,
222- new_seq_lens_cpu , input_ids ,
220+ token_indices = jnp .asarray (token_indices_cpu , dtype = jnp .int32 )
221+ query_start_loc , seq_lens = device_array (self .mesh , (
222+ new_query_start_loc_cpu ,
223+ new_seq_lens_cpu ,
224+ ))
225+
226+ return self ._prepare_inputs_in_jit (token_indices , query_start_loc ,
227+ seq_lens , input_ids ,
223228 aux_hidden_states , attn_metadata ,
224229 next_token_ids , block_tables )
225230
226231 @functools .partial (jax .jit , static_argnums = (0 , ))
227232 def _prepare_inputs_in_jit (
228233 self ,
229- token_indices_cpu : np . ndarray ,
230- new_query_start_loc_cpu : np . ndarray ,
231- new_seq_lens_cpu : np . ndarray ,
234+ token_indices : jax . Array ,
235+ query_start_loc : jax . Array ,
236+ seq_lens : jax . Array ,
232237 input_ids : jax .Array ,
233238 aux_hidden_states : tuple [jax .Array , ...],
234239 attn_metadata : AttentionMetadata ,
235240 next_token_ids : jax .Array ,
236241 block_tables : jax .Array ,
237242 ) -> tuple [jax .Array , jax .Array , jax .Array , AttentionMetadata ]:
238243
239- token_indices = jnp .asarray (token_indices_cpu , dtype = jnp .int32 )
240- query_start_loc , seq_lens = device_array (self .mesh , (
241- new_query_start_loc_cpu ,
242- new_seq_lens_cpu ,
243- ))
244244 # Select tokens and hidden states.
245245 target_token_ids = input_ids [token_indices ]
246246 target_hidden_states = jnp .concatenate (
@@ -259,7 +259,7 @@ def _prepare_inputs_in_jit(
259259 request_distribution = attn_metadata .request_distribution ,
260260 )
261261
262- target_hidden_states , input_ids , last_token_indices , block_tables = self ._prepare_draft_inputs (
262+ target_hidden_states , input_ids , last_token_indices , block_tables = self ._prepare_draft_inputs_in_jit (
263263 target_hidden_states , query_start_loc , target_token_ids ,
264264 next_token_ids , block_tables )
265265
@@ -272,16 +272,6 @@ def _prepare_draft_inputs_in_jit(
272272 target_token_ids : jax .Array , next_token_ids : jax .Array ,
273273 block_tables : jax .Array
274274 ) -> tuple [jax .Array , jax .Array , jax .Array , jax .Array ]:
275- return self ._prepare_draft_inputs (target_hidden_states ,
276- query_start_loc , target_token_ids ,
277- next_token_ids , block_tables )
278-
279- def _prepare_draft_inputs (
280- self , target_hidden_states : jax .Array , query_start_loc : jax .Array ,
281- target_token_ids : jax .Array , next_token_ids : jax .Array ,
282- block_tables : jax .Array
283- ) -> tuple [jax .Array , jax .Array , jax .Array , jax .Array ]:
284-
285275 target_hidden_states = self .combine_hidden_states_fn (
286276 self .state , target_hidden_states )
287277
@@ -294,13 +284,13 @@ def _prepare_draft_inputs(
294284
295285 return target_hidden_states , input_ids , last_token_indices , block_tables
296286
297- @functools .partial (jax .jit , static_argnums = (0 , ))
298- def _select_draft_token_ids_in_jit (
287+ def _select_draft_token_ids (
299288 self ,
300289 hidden_states : jax .Array ,
301290 last_token_indices : jax .Array ,
302291 ) -> jax .Array :
303- return self ._select_draft_token_ids (hidden_states , last_token_indices )
292+ sample_hidden_states = hidden_states [last_token_indices ]
293+ return self ._get_draft_token_ids_in_jit (sample_hidden_states )
304294
305295 @functools .partial (jax .jit , static_argnums = (0 , ))
306296 def _select_and_stack_draft_token_ids_in_jit (
@@ -312,30 +302,22 @@ def _select_and_stack_draft_token_ids_in_jit(
312302 last_token_indices )
313303 return jnp .stack ([draft_token_ids ], axis = 1 )
314304
315- def _select_draft_token_ids (
316- self ,
317- hidden_states : jax .Array ,
318- last_token_indices : jax .Array ,
319- ) -> jax .Array :
320- sample_hidden_states = hidden_states [last_token_indices ]
321- return self ._get_draft_token_ids (sample_hidden_states )
322-
323- def _get_draft_token_ids (self , hidden_states : jax .Array ) -> jax .Array :
305+ @functools .partial (jax .jit , static_argnums = (0 , ))
306+ def _get_draft_token_ids_in_jit (self ,
307+ hidden_states : jax .Array ) -> jax .Array :
324308 lora_metadata = None
325309 logits = self .compute_logits_fn (self .state , hidden_states ,
326310 lora_metadata )
327311 return jnp .argmax (logits , axis = - 1 )
328312
329313 @functools .partial (jax .jit , static_argnums = (0 , ))
330- def _get_draft_token_ids_in_jit (self ,
331- hidden_states : jax .Array ) -> jax .Array :
332- return self ._get_draft_token_ids (hidden_states )
333-
334- @functools .partial (jax .jit , static_argnums = (0 , ))
335- def _select_positions_and_hidden_states_in_jit (
336- self , positions : jax .Array , hidden_states : jax .Array ,
314+ def _select_inputs_for_loop_in_jit (
315+ self , positions : jax .Array , residual : jax .Array ,
316+ hidden_states : jax .Array ,
337317 last_token_indices : jax .Array ) -> tuple [jax .Array , jax .Array ]:
338- return positions [last_token_indices ], hidden_states [last_token_indices ]
318+ return positions [last_token_indices ], residual [
319+ last_token_indices ], self ._select_draft_token_ids (
320+ hidden_states , last_token_indices )
339321
340322 def propose (
341323 self ,
@@ -364,14 +346,12 @@ def propose(
364346 return kv_caches , self ._select_and_stack_draft_token_ids_in_jit (
365347 hidden_states , last_token_indices )
366348
367- draft_token_ids = self ._select_draft_token_ids_in_jit (
368- hidden_states , last_token_indices )
349+ positions , hidden_states , draft_token_ids = self ._select_inputs_for_loop_in_jit (
350+ attn_metadata .input_positions , residual [0 ], hidden_states ,
351+ last_token_indices )
369352
370353 draft_token_ids_list = [draft_token_ids ]
371354
372- positions , hidden_states = self ._select_positions_and_hidden_states_in_jit (
373- attn_metadata .input_positions , residual [0 ], last_token_indices )
374-
375355 for _ in range (self .num_speculative_tokens - 1 ):
376356 input_ids_loop = draft_token_ids_list [- 1 ]
377357
0 commit comments