Skip to content

Commit 10f13d0

Browse files
committed
async scheduler fix _substitute_placeholder_token_fn bug
Signed-off-by: cychiuak <[email protected]>
1 parent aad4c55 commit 10f13d0

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

tpu_inference/runner/tpu_jax_runner.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ def _substitute_placeholder_token(
150150
next_tokens: jax.Array, placeholder_num: int):
151151
"""Substitute placeholder tokens from TPU for async scheduler
152152
153+
Padding for parallelisation of the substitute_placeholder_token_fn
154+
[1, 3] => [1, 3, 0, 2, 4, 5, 6, 7, 8]
155+
The reason for such a special padding instead of padding with -1 is:
156+
There is an edge case in which the end index needs to be updated while padding is also required.
157+
If we padded the array with -1, _substitute_placeholder_token_fn would repeatedly overwrite the final element with its original value
158+
Although such a scenario is unlikely to happen in vLLM, it is best to eliminate any potential risks.
159+
153160
Args:
154161
input_ids: possible input_ids size
155162
token_in_tpu_cur_input_indices: placeholder indices in input_ids to replace. Length is the same as input_ids.
@@ -1011,10 +1018,14 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
10111018
token_in_tpu_cur_input_indices) > 0:
10121019
assert self._pre_async_results is not None
10131020
idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
1014-
padded_token_in_tpu_cur_input_indices = np.pad(
1015-
token_in_tpu_cur_input_indices, (0, idx_pad_len),
1016-
mode='constant',
1017-
constant_values=-1)
1021+
1022+
# Pad according to the instructions written inside self._substitute_placeholder_token_fn
1023+
full_range = np.arange(0, len(input_ids))
1024+
missing_values = np.setdiff1d(full_range,
1025+
token_in_tpu_cur_input_indices)
1026+
padded_token_in_tpu_cur_input_indices = np.concatenate(
1027+
(token_in_tpu_cur_input_indices, missing_values))
1028+
10181029
padded_token_in_tpu_pre_next_tokens_indices = np.pad(
10191030
token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len),
10201031
mode='constant',

0 commit comments

Comments
 (0)