@@ -61,10 +61,11 @@ def _concate_hidden_states(
6161 # Concat aux hidden states along feature dim.
6262 return jnp .concatenate (aux_hidden_states , axis = - 1 )
6363
64- def _prepare_input_ids (self , query_start_loc : jax .Array ,
65- target_token_ids : jax .Array ,
66- next_token_ids : jax .Array ,
67- num_reqs : int ) -> tuple [jnp .ndarray , jnp .ndarray ]:
64+ @functools .partial (jax .jit , static_argnums = (0 , ))
65+ def _prepare_input_ids (
66+ self , query_start_loc : jax .Array , target_token_ids : jax .Array ,
67+ next_token_ids : jax .Array ,
68+ num_reqs : jax .Array ) -> tuple [jnp .ndarray , jnp .ndarray ]:
6869 """JIT-compiled helper for preparing the input IDs for the draft model."""
6970
7071 last_token_indices = query_start_loc [1 :] - 1
@@ -152,17 +153,17 @@ def prepare_inputs(
152153 draft_kv_cache_group_id = num_kv_cache_groups - 1
153154 block_tables = self .runner .input_batch .block_table [
154155 draft_kv_cache_group_id ].get_device_tensor ()
156+ # Number of active requests in this step (un-padded count).
157+ num_reqs = self .runner .input_batch .num_reqs
158+
155159 if num_rejected_tokens is None :
156160 target_hidden_states , input_ids , last_token_indices , block_tables = self ._prepare_draft_inputs_in_jit (
157161 self ._concate_hidden_states (aux_hidden_states ),
158162 attn_metadata .query_start_loc , input_ids , next_token_ids ,
159- block_tables )
163+ block_tables , jnp . asarray ([ num_reqs ], dtype = jnp . int32 ) )
160164 attn_metadata = replace (attn_metadata , block_tables = block_tables )
161165 return target_hidden_states , input_ids , last_token_indices , attn_metadata
162166
163- # Number of active requests in this step (un-padded count).
164- num_reqs = self .runner .input_batch .num_reqs
165-
166167 # Host copies from the metadata prepared by the runner.
167168 query_start_loc_cpu = attn_metadata .query_start_loc_cpu
168169 seq_lens_cpu = attn_metadata .seq_lens_cpu
@@ -217,16 +218,14 @@ def prepare_inputs(
217218 # Update seq_lens for active requests only: new_seq_lens = s - n.
218219 new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
219220
220- token_indices = jnp .asarray (token_indices_cpu , dtype = jnp .int32 )
221- query_start_loc , seq_lens = device_array (self .mesh , (
222- new_query_start_loc_cpu ,
223- new_seq_lens_cpu ,
224- ))
221+ query_start_loc , seq_lens , token_indices = device_array (
222+ self .mesh ,
223+ (new_query_start_loc_cpu , new_seq_lens_cpu , token_indices_cpu ))
225224
226- return self ._prepare_inputs_in_jit (token_indices , query_start_loc ,
227- seq_lens , input_ids ,
228- aux_hidden_states , attn_metadata ,
229- next_token_ids , block_tables )
225+ return self ._prepare_inputs_in_jit (
226+ token_indices , query_start_loc , seq_lens , input_ids ,
227+ aux_hidden_states , attn_metadata , next_token_ids , block_tables ,
228+ jnp . asarray ([ num_reqs ], dtype = jnp . int32 ) )
230229
231230 @functools .partial (jax .jit , static_argnums = (0 , ))
232231 def _prepare_inputs_in_jit (
@@ -239,6 +238,7 @@ def _prepare_inputs_in_jit(
239238 attn_metadata : AttentionMetadata ,
240239 next_token_ids : jax .Array ,
241240 block_tables : jax .Array ,
241+ num_reqs : jax .Array ,
242242 ) -> tuple [jax .Array , jax .Array , jax .Array , AttentionMetadata ]:
243243
244244 # Select tokens and hidden states.
@@ -261,23 +261,26 @@ def _prepare_inputs_in_jit(
261261
262262 target_hidden_states , input_ids , last_token_indices , block_tables = self ._prepare_draft_inputs_in_jit (
263263 target_hidden_states , query_start_loc , target_token_ids ,
264- next_token_ids , block_tables )
264+ next_token_ids , block_tables , num_reqs )
265265
266266 attn_metadata = replace (attn_metadata , block_tables = block_tables )
267267 return target_hidden_states , input_ids , last_token_indices , attn_metadata
268268
269269 @functools .partial (jax .jit , static_argnums = (0 , ))
270270 def _prepare_draft_inputs_in_jit (
271- self , target_hidden_states : jax .Array , query_start_loc : jax .Array ,
272- target_token_ids : jax .Array , next_token_ids : jax .Array ,
273- block_tables : jax .Array
271+ self ,
272+ target_hidden_states : jax .Array ,
273+ query_start_loc : jax .Array ,
274+ target_token_ids : jax .Array ,
275+ next_token_ids : jax .Array ,
276+ block_tables : jax .Array ,
277+ num_reqs : jax .Array ,
274278 ) -> tuple [jax .Array , jax .Array , jax .Array , jax .Array ]:
275279 target_hidden_states = self .combine_hidden_states_fn (
276280 self .state , target_hidden_states )
277281
278282 input_ids , last_token_indices = self ._prepare_input_ids (
279- query_start_loc , target_token_ids , next_token_ids ,
280- self .runner .input_batch .num_reqs )
283+ query_start_loc , target_token_ids , next_token_ids , num_reqs )
281284 # NOTE(pooyam): For now, we don't support multimodal.
282285
283286 block_tables = block_tables .reshape (- 1 )
0 commit comments