Skip to content

Commit 0ef54cf

Browse files
committed
async scheduler: fix `_substitute_placeholder_token_fn` bug
Signed-off-by: cychiuak <[email protected]>
1 parent 9483188 commit 0ef54cf

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

tpu_inference/runner/tpu_jax_runner.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ def _substitute_placeholder_token(
133133
next_tokens: jax.Array, placeholder_num: int):
134134
"""Substitute placeholder tokens from TPU for async scheduler
135135
136+
Padding for parallelisation of the _substitute_placeholder_token_fn
137+
[1, 3] => [1, 3, 0, 2, 4, 5, 6, 7, 8]
138+
The reason for such a special padding instead of padding with -1 is:
139+
There is an edge case in which the end index needs to be updated and padding is required.
140+
If we pad the array with -1, _substitute_placeholder_token_fn will repeatedly overwrite the last element with its original value.
141+
Although such a scenario is unlikely to happen in vLLM, it is best to eliminate any potential risks.
142+
136143
Args:
137144
input_ids: possible input_ids size
138145
token_in_tpu_cur_input_indices: replace holder idx in input_ids. Length the same to input_ids.
@@ -937,10 +944,14 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
937944
token_in_tpu_cur_input_indices) > 0:
938945
assert self._pre_async_results is not None
939946
idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
940-
padded_token_in_tpu_cur_input_indices = np.pad(
941-
token_in_tpu_cur_input_indices, (0, idx_pad_len),
942-
mode='constant',
943-
constant_values=-1)
947+
948+
# Pad according to the instructions written inside self._substitute_placeholder_token_fn
949+
full_range = np.arange(0, len(input_ids))
950+
missing_values = np.setdiff1d(full_range,
951+
token_in_tpu_cur_input_indices)
952+
padded_token_in_tpu_cur_input_indices = np.concatenate(
953+
(token_in_tpu_cur_input_indices, missing_values))
954+
944955
padded_token_in_tpu_pre_next_tokens_indices = np.pad(
945956
token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len),
946957
mode='constant',

0 commit comments

Comments
 (0)