
Commit 1a21402

async code optimisation
1 parent db4a874 commit 1a21402

File tree

2 files changed: +14 -24 lines changed

tests/e2e/test_async_scheduler.py

Lines changed: 6 additions & 2 deletions
@@ -76,7 +76,8 @@ def _test_performance_helper(
                   max_model_len=800,
                   max_num_seqs=24,
                   max_num_batched_tokens=512,
-                  enable_prefix_caching=False)
+                  enable_prefix_caching=False,
+                  async_scheduling=0)
 
     start_time = time.time()
     _ = ref_llm.generate(test_prompts, sampling_config)
@@ -144,7 +145,10 @@ def _test_correctness_helper(
     with monkeypatch.context():
         test_prompts = get_test_prompts()
 
-        ref_llm = LLM(model=model_name, max_model_len=1024, max_num_seqs=100)
+        ref_llm = LLM(model=model_name,
+                      max_model_len=1024,
+                      max_num_seqs=100,
+                      async_scheduling=0)
         ref_outputs = ref_llm.generate(test_prompts, sampling_config)
 
         del ref_llm
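
Below is a hedged usage sketch of the pattern both test helpers now follow: the reference LLM is built with async scheduling explicitly disabled so it acts as a synchronous baseline. The vllm import path, model name, prompt, and sampling settings are illustrative assumptions, not taken from the test file.

# A minimal sketch, assuming the tests use vLLM's offline LLM API; the model
# name, prompt, and sampling settings below are placeholders.
from vllm import LLM, SamplingParams

ref_llm = LLM(model="some-reference-model",   # placeholder model name
              max_model_len=1024,
              max_num_seqs=100,
              async_scheduling=0)              # force the synchronous baseline
sampling_config = SamplingParams(temperature=0.0, max_tokens=16)
ref_outputs = ref_llm.generate(["Hello, world"], sampling_config)
del ref_llm  # release the reference engine, as the test does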

tpu_inference/runner/tpu_jax_runner.py

Lines changed: 8 additions & 22 deletions
@@ -146,26 +146,12 @@ def _substitute_placeholder_token(
         f"Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \
         f"Got: {input_ids.shape=}, {token_in_tpu_cur_input_indices.shape=}, {token_in_tpu_pre_next_tokens_indices.shape=}"
 
-    def updated_input_ids_array(i: int,
-                                current_input_ids: jax.Array) -> jax.Array:
-        """
-        Iteratively updates the input_ids for all placeholders.
-
-        Args:
-            i: The current loop index.
-            current_input_ids: The loop carry state (the input_ids being modified).
-
-        Returns:
-            The updated input_ids array.
-        """
-        update_idx = token_in_tpu_cur_input_indices[i]
-        value_idx = token_in_tpu_pre_next_tokens_indices[i]
-        new_token_value = next_tokens[value_idx]
-        return current_input_ids.at[update_idx].set(new_token_value)
-
-    return jax.lax.fori_loop(0, placeholder_num, updated_input_ids_array,
-                             input_ids)
-
+    # Update the input_ids for all placeholders in a single vectorized pass.
+    mask = jnp.arange(input_ids.shape[0]) < placeholder_num
+    new_token_values = next_tokens[token_in_tpu_pre_next_tokens_indices]
+    original_values = input_ids[token_in_tpu_cur_input_indices]
+    update_values = jnp.where(mask, new_token_values, original_values)
+    return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
@@ -940,9 +926,9 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
             assert self._pre_async_results is not None
             idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
             padded_token_in_tpu_cur_input_indices = np.pad(
-                token_in_tpu_cur_input_indices, (0, idx_pad_len))
+                token_in_tpu_cur_input_indices, (0, idx_pad_len), mode='constant', constant_values=-1)
             padded_token_in_tpu_pre_next_tokens_indices = np.pad(
-                token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len))
+                token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len), mode='constant', constant_values=-1)
             with self.maybe_forbid_compile:
                 input_ids = self._substitute_placeholder_token_fn(
                     input_ids, padded_token_in_tpu_cur_input_indices,
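
To make the fori_loop replacement concrete, here is a small standalone sketch of the same masked gather/scatter technique. The helper name, array values, and shapes are illustrative assumptions; only the update logic mirrors the diff. The mask limits real updates to the first placeholder_num slots, and the padded -1 indices simply write their original value back, which is why padding with constant_values=-1 is safe.

# A minimal, self-contained sketch of the vectorized placeholder substitution;
# names and values are illustrative, not the runner's real inputs.
import jax.numpy as jnp

def substitute_placeholders(input_ids, cur_indices, next_token_indices,
                            next_tokens, placeholder_num):
    # Only the first `placeholder_num` entries of the padded index arrays
    # are real; the remainder is -1 padding.
    mask = jnp.arange(input_ids.shape[0]) < placeholder_num
    new_values = next_tokens[next_token_indices]        # gather new tokens
    original_values = input_ids[cur_indices]            # gather current values
    # Masked-out (padded) slots write their original value back to index -1,
    # which leaves input_ids unchanged there.
    update_values = jnp.where(mask, new_values, original_values)
    return input_ids.at[cur_indices].set(update_values)  # single scatter

input_ids = jnp.array([10, 11, 0, 13, 0, 15])            # 0 marks placeholders
next_tokens = jnp.array([101, 102])
cur_indices = jnp.array([2, 4, -1, -1, -1, -1])          # padded to len(input_ids)
next_token_indices = jnp.array([0, 1, -1, -1, -1, -1])
print(substitute_placeholders(input_ids, cur_indices, next_token_indices,
                              next_tokens, placeholder_num=2))
# Expected output: [ 10  11 101  13 102  15]

Compared with the deleted jax.lax.fori_loop, this issues one gather and one scatter instead of placeholder_num sequential .at[i].set(...) updates, which is presumably the optimisation the commit message refers to.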
