@@ -146,26 +146,12 @@ def _substitute_placeholder_token(
146146 f"Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \
147147 f"Got: { input_ids .shape = } , { token_in_tpu_cur_input_indices .shape = } , { token_in_tpu_pre_next_tokens_indices .shape = } "
148148
149- def updated_input_ids_array (i : int ,
150- current_input_ids : jax .Array ) -> jax .Array :
151- """
152- Iteratively updates the input_ids for all placeholders.
153- å
154- Args:
155- i: The current loop index.
156- current_input_ids: The loop carry state (the input_ids being modified).
157-
158- Returns:
159- The updated input_ids array.
160- """
161- update_idx = token_in_tpu_cur_input_indices [i ]
162- value_idx = token_in_tpu_pre_next_tokens_indices [i ]
163- new_token_value = next_tokens [value_idx ]
164- return current_input_ids .at [update_idx ].set (new_token_value )
165-
166- return jax .lax .fori_loop (0 , placeholder_num , updated_input_ids_array ,
167- input_ids )
168-
+    # Update the input_ids for all placeholders with a single masked scatter.
+    mask = jnp.arange(input_ids.shape[0]) < placeholder_num
+    new_token_values = next_tokens[token_in_tpu_pre_next_tokens_indices]
+    original_values = input_ids[token_in_tpu_cur_input_indices]
+    update_values = jnp.where(mask, new_token_values, original_values)
+    return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)

 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

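For reference, below is a minimal standalone sketch (toy arrays of my own, not part of the commit) checking that the masked scatter added above produces the same result as the removed jax.lax.fori_loop on the first placeholder_num index pairs, while keeping every array shape static for precompilation:

import jax
import jax.numpy as jnp

# Toy values only; per the assert above, input_ids and both index arrays
# share the same precompiled shape, with the index arrays padded with -1.
input_ids = jnp.array([10, 11, 12, 13, 14])
next_tokens = jnp.array([100, 200, 300])
token_in_tpu_cur_input_indices = jnp.array([1, 3, -1, -1, -1])
token_in_tpu_pre_next_tokens_indices = jnp.array([0, 2, -1, -1, -1])
placeholder_num = 2  # only the first two index pairs are real

# Vectorized form from the diff: padded positions are masked so they simply
# write their original value back (a no-op scatter entry).
mask = jnp.arange(input_ids.shape[0]) < placeholder_num
new_token_values = next_tokens[token_in_tpu_pre_next_tokens_indices]
original_values = input_ids[token_in_tpu_cur_input_indices]
update_values = jnp.where(mask, new_token_values, original_values)
vectorized = input_ids.at[token_in_tpu_cur_input_indices].set(update_values)

# The removed fori_loop form, for comparison.
def body(i, ids):
    return ids.at[token_in_tpu_cur_input_indices[i]].set(
        next_tokens[token_in_tpu_pre_next_tokens_indices[i]])

looped = jax.lax.fori_loop(0, placeholder_num, body, input_ids)

assert (vectorized == looped).all()  # both yield [10, 100, 12, 300, 14]
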
@@ -940,9 +926,9 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
         assert self._pre_async_results is not None
         idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
         padded_token_in_tpu_cur_input_indices = np.pad(
-            token_in_tpu_cur_input_indices, (0, idx_pad_len))
+            token_in_tpu_cur_input_indices, (0, idx_pad_len), mode='constant', constant_values=-1)
         padded_token_in_tpu_pre_next_tokens_indices = np.pad(
-            token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len))
+            token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len), mode='constant', constant_values=-1)
         with self.maybe_forbid_compile:
             input_ids = self._substitute_placeholder_token_fn(
                 input_ids, padded_token_in_tpu_cur_input_indices,
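A short sketch of the padding change (toy values, not from the commit): padding the index arrays with -1 instead of np.pad's default of 0 presumably keeps the padded scatter entries from colliding with a genuine placeholder at index 0; they point at the last element instead, where the masked substitution writes the original value back, making them no-ops.

import numpy as np

# Toy values; the real padded length comes from the precompiled input_ids shape.
token_in_tpu_cur_input_indices = np.array([4, 7])
idx_pad_len = 3
padded = np.pad(token_in_tpu_cur_input_indices, (0, idx_pad_len),
                mode='constant', constant_values=-1)
print(padded)  # [ 4  7 -1 -1 -1]  -- the -1 slots become no-op scatter entries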