@@ -150,6 +150,13 @@ def _substitute_placeholder_token(
         next_tokens: jax.Array, placeholder_num: int):
     """Substitute placeholder tokens from TPU for async scheduler
 
+    Padding for parallelisation of _substitute_placeholder_token_fn:
+        [1, 3] => [1, 3, 0, 2, 4, 5, 6, 7, 8]
+    The reason for this special padding, instead of padding with -1, is
+    an edge case where the last element needs to be updated and padding
+    is also required: indices padded with -1 all alias the last element,
+    so _substitute_placeholder_token_fn would repeatedly overwrite it
+    with its original value. Although this scenario is unlikely to occur
+    in vLLM, this scheme eliminates the risk entirely.
+
     Args:
         input_ids: possible input_ids size
         token_in_tpu_cur_input_indices: placeholder indices in input_ids. Same length as input_ids.
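
The edge case called out in the docstring above can be reproduced in a few lines. The sketch below is illustrative only (the array contents and the identity-write behaviour of the padded slots are assumptions for the sake of the example, not code from this PR); it shows how -1-padded indices all alias the last position, so a real update there races with the padded writes.

import jax.numpy as jnp

input_ids = jnp.arange(10, 19)  # 9 original token ids: [10 .. 18]
# Real substitutions at positions 1 and 8; the rest is -1 padding.
indices = jnp.array([1, 8, -1, -1, -1, -1, -1, -1, -1])
# Each padded slot rewrites the original value of the slot it aliases
# (18, the last element); position 8 also has a real new token, 108.
values = jnp.array([101, 108, 18, 18, 18, 18, 18, 18, 18])
# With duplicate scatter indices, the order in which conflicting
# .at[...].set writes land is implementation-defined in JAX, so
# input_ids[8] may end up as the stale 18 instead of 108.
out = input_ids.at[indices].set(values)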
@@ -1011,10 +1018,14 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
                 token_in_tpu_cur_input_indices) > 0:
             assert self._pre_async_results is not None
             idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
-            padded_token_in_tpu_cur_input_indices = np.pad(
-                token_in_tpu_cur_input_indices, (0, idx_pad_len),
-                mode='constant',
-                constant_values=-1)
+
+            # Pad following the scheme described in the docstring of
+            # self._substitute_placeholder_token_fn: fill the tail with
+            # the indices missing from token_in_tpu_cur_input_indices.
+            full_range = np.arange(0, len(input_ids))
+            missing_values = np.setdiff1d(full_range,
+                                          token_in_tpu_cur_input_indices)
+            padded_token_in_tpu_cur_input_indices = np.concatenate(
+                (token_in_tpu_cur_input_indices, missing_values))
+
             padded_token_in_tpu_pre_next_tokens_indices = np.pad(
                 token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len),
                 mode='constant',
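
For reference, a standalone sketch of the new padding computation, reproducing the diff above in isolation; the input length of 9 matches the docstring example and is an assumption of this sketch, not a value from the PR.

import numpy as np

token_in_tpu_cur_input_indices = np.array([1, 3])
input_ids_len = 9  # assumed length, matching the docstring example

# Fill the tail with the indices *not* being substituted, so every
# padded slot targets a distinct position that is simply rewritten
# with its original value.
full_range = np.arange(0, input_ids_len)
missing_values = np.setdiff1d(full_range, token_in_tpu_cur_input_indices)
padded = np.concatenate((token_in_tpu_cur_input_indices, missing_values))
print(padded)  # [1 3 0 2 4 5 6 7 8]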