@@ -133,6 +133,13 @@ def _substitute_placeholder_token(
                                   next_tokens: jax.Array, placeholder_num: int):
     """Substitute placeholder tokens from TPU for async scheduler

+    Padding for parallelisation of _substitute_placeholder_token_fn:
+        [1, 3] => [1, 3, 0, 2, 4, 5, 6, 7, 8]
+    We pad with the indices that are not being updated, rather than with -1:
+    when the last index needs updating and padding is required, -1 aliases
+    that last index, so the padded entries would repeatedly overwrite the
+    last element with its original value. Unlikely in vLLM, but best avoided.
+
     Args:
         input_ids: input ids, padded to the full possible size
         token_in_tpu_cur_input_indices: placeholder indices in input_ids; same length as input_ids.
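
For illustration, here is a minimal standalone sketch of the padding scheme the
new docstring describes (plain NumPy; the length 9 and the variable names are
assumptions for the example, not taken from the commit):

    import numpy as np

    # Placeholder positions to substitute, padded out to the full input
    # length with every index that is NOT being updated.
    input_len = 9
    cur_indices = np.array([1, 3])

    missing = np.setdiff1d(np.arange(input_len), cur_indices)
    padded = np.concatenate((cur_indices, missing))
    print(padded)  # [1 3 0 2 4 5 6 7 8]

    # The result is a permutation of range(input_len): every position is
    # written exactly once. Padding with -1 instead would alias index 8,
    # so all padding slots would write to the same (last) position.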
@@ -937,10 +944,14 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
                 token_in_tpu_cur_input_indices) > 0:
             assert self._pre_async_results is not None
             idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
-            padded_token_in_tpu_cur_input_indices = np.pad(
-                token_in_tpu_cur_input_indices, (0, idx_pad_len),
-                mode='constant',
-                constant_values=-1)
+
+            # Pad with the missing indices, as explained in the docstring of
+            # self._substitute_placeholder_token_fn.
+            full_range = np.arange(0, len(input_ids))
+            missing_values = np.setdiff1d(full_range,
+                                          token_in_tpu_cur_input_indices)
+            padded_token_in_tpu_cur_input_indices = np.concatenate(
+                (token_in_tpu_cur_input_indices, missing_values))
             padded_token_in_tpu_pre_next_tokens_indices = np.pad(
                 token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len),
                 mode='constant',
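
As a sanity check, a hypothetical snippet (made-up sizes and values, not part
of the commit) showing the two properties the replacement padding provides:
the index array keeps the fixed length the jitted substitute function sees,
and it contains each position exactly once:

    import numpy as np

    input_ids = np.array([11, 0, 13, 0, 15, 16, 17, 18])  # 0 = placeholder
    token_in_tpu_cur_input_indices = np.array([1, 3])

    full_range = np.arange(0, len(input_ids))
    missing_values = np.setdiff1d(full_range,
                                  token_in_tpu_cur_input_indices)
    padded = np.concatenate((token_in_tpu_cur_input_indices, missing_values))

    # Same length as input_ids, so array shapes stay stable.
    assert len(padded) == len(input_ids)
    # A permutation of range(len(input_ids)): no index is written twice, so
    # a real update to the last position can never collide with a padding
    # write, unlike -1 padding.
    assert np.array_equal(np.sort(padded), full_range)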