Skip to content

Commit f6677ef

Browse files
committed
Fix bugs: pass num_reqs into the jitted Eagle3 draft-input helpers as an explicit int32 array argument
Signed-off-by: Lihao Ran <[email protected]>
1 parent 82f849a commit f6677ef

File tree

2 files changed

33 additions and 32 deletions

2 files changed

33 additions and 32 deletions

tpu_inference/runner/compilation_manager.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -520,15 +520,11 @@ def prepare_inputs_in_jit_wrapper(
520520
block_tables,
521521
):
522522
target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._prepare_inputs_in_jit(
523-
token_indices_cpu,
524-
new_query_start_loc_cpu,
525-
new_seq_lens_cpu,
526-
input_ids,
527-
aux_hidden_states,
528-
attention_metadata,
529-
next_token_ids,
530-
block_tables,
531-
)
523+
token_indices_cpu, new_query_start_loc_cpu,
524+
new_seq_lens_cpu, input_ids, aux_hidden_states,
525+
attention_metadata, next_token_ids, block_tables,
526+
jnp.asarray([self.runner.input_batch.num_reqs],
527+
dtype=jnp.int32))
532528
return target_hidden_states, input_ids, last_token_indices
533529

534530
token_indices_cpu = np.ones((num_tokens, ), dtype=np.int32)
@@ -605,6 +601,8 @@ def draft_model_fn_wrapper(
605601
target_token_ids,
606602
next_token_ids,
607603
block_tables_unreshaped,
604+
jnp.asarray([self.runner.input_batch.num_reqs],
605+
dtype=jnp.int32),
608606
num_tokens=num_tokens,
609607
)
610608

tpu_inference/spec_decode/jax/eagle3.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,11 @@ def _concate_hidden_states(
6161
# Concat aux hidden states along feature dim.
6262
return jnp.concatenate(aux_hidden_states, axis=-1)
6363

64-
def _prepare_input_ids(self, query_start_loc: jax.Array,
65-
target_token_ids: jax.Array,
66-
next_token_ids: jax.Array,
67-
num_reqs: int) -> tuple[jnp.ndarray, jnp.ndarray]:
64+
@functools.partial(jax.jit, static_argnums=(0, ))
65+
def _prepare_input_ids(
66+
self, query_start_loc: jax.Array, target_token_ids: jax.Array,
67+
next_token_ids: jax.Array,
68+
num_reqs: jax.Array) -> tuple[jnp.ndarray, jnp.ndarray]:
6869
"""JIT-compiled helper for preparing the input IDs for the draft model."""
6970

7071
last_token_indices = query_start_loc[1:] - 1
@@ -152,17 +153,17 @@ def prepare_inputs(
152153
draft_kv_cache_group_id = num_kv_cache_groups - 1
153154
block_tables = self.runner.input_batch.block_table[
154155
draft_kv_cache_group_id].get_device_tensor()
156+
# Number of active requests in this step (un-padded count).
157+
num_reqs = self.runner.input_batch.num_reqs
158+
155159
if num_rejected_tokens is None:
156160
target_hidden_states, input_ids, last_token_indices, block_tables = self._prepare_draft_inputs_in_jit(
157161
self._concate_hidden_states(aux_hidden_states),
158162
attn_metadata.query_start_loc, input_ids, next_token_ids,
159-
block_tables)
163+
block_tables, jnp.asarray([num_reqs], dtype=jnp.int32))
160164
attn_metadata = replace(attn_metadata, block_tables=block_tables)
161165
return target_hidden_states, input_ids, last_token_indices, attn_metadata
162166

163-
# Number of active requests in this step (un-padded count).
164-
num_reqs = self.runner.input_batch.num_reqs
165-
166167
# Host copies from the metadata prepared by the runner.
167168
query_start_loc_cpu = attn_metadata.query_start_loc_cpu
168169
seq_lens_cpu = attn_metadata.seq_lens_cpu
@@ -217,16 +218,14 @@ def prepare_inputs(
217218
# Update seq_lens for active requests only: new_seq_lens = s - n.
218219
new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
219220

220-
token_indices = jnp.asarray(token_indices_cpu, dtype=jnp.int32)
221-
query_start_loc, seq_lens = device_array(self.mesh, (
222-
new_query_start_loc_cpu,
223-
new_seq_lens_cpu,
224-
))
221+
query_start_loc, seq_lens, token_indices = device_array(
222+
self.mesh,
223+
(new_query_start_loc_cpu, new_seq_lens_cpu, token_indices_cpu))
225224

226-
return self._prepare_inputs_in_jit(token_indices, query_start_loc,
227-
seq_lens, input_ids,
228-
aux_hidden_states, attn_metadata,
229-
next_token_ids, block_tables)
225+
return self._prepare_inputs_in_jit(
226+
token_indices, query_start_loc, seq_lens, input_ids,
227+
aux_hidden_states, attn_metadata, next_token_ids, block_tables,
228+
jnp.asarray([num_reqs], dtype=jnp.int32))
230229

231230
@functools.partial(jax.jit, static_argnums=(0, ))
232231
def _prepare_inputs_in_jit(
@@ -239,6 +238,7 @@ def _prepare_inputs_in_jit(
239238
attn_metadata: AttentionMetadata,
240239
next_token_ids: jax.Array,
241240
block_tables: jax.Array,
241+
num_reqs: jax.Array,
242242
) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
243243

244244
# Select tokens and hidden states.
@@ -261,23 +261,26 @@ def _prepare_inputs_in_jit(
261261

262262
target_hidden_states, input_ids, last_token_indices, block_tables = self._prepare_draft_inputs_in_jit(
263263
target_hidden_states, query_start_loc, target_token_ids,
264-
next_token_ids, block_tables)
264+
next_token_ids, block_tables, num_reqs)
265265

266266
attn_metadata = replace(attn_metadata, block_tables=block_tables)
267267
return target_hidden_states, input_ids, last_token_indices, attn_metadata
268268

269269
@functools.partial(jax.jit, static_argnums=(0, ))
270270
def _prepare_draft_inputs_in_jit(
271-
self, target_hidden_states: jax.Array, query_start_loc: jax.Array,
272-
target_token_ids: jax.Array, next_token_ids: jax.Array,
273-
block_tables: jax.Array
271+
self,
272+
target_hidden_states: jax.Array,
273+
query_start_loc: jax.Array,
274+
target_token_ids: jax.Array,
275+
next_token_ids: jax.Array,
276+
block_tables: jax.Array,
277+
num_reqs: jax.Array,
274278
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
275279
target_hidden_states = self.combine_hidden_states_fn(
276280
self.state, target_hidden_states)
277281

278282
input_ids, last_token_indices = self._prepare_input_ids(
279-
query_start_loc, target_token_ids, next_token_ids,
280-
self.runner.input_batch.num_reqs)
283+
query_start_loc, target_token_ids, next_token_ids, num_reqs)
281284
# NOTE(pooyam): For now, we don't support multimodal.
282285

283286
block_tables = block_tables.reshape(-1)

0 commit comments

Comments
 (0)