Skip to content

Commit 182e30c

Browse files
committed
Resolve comments
Signed-off-by: Lihao Ran <[email protected]>
1 parent 93d2e08 commit 182e30c

File tree

3 files changed

+7
-22
lines changed

3 files changed

+7
-22
lines changed

tpu_inference/runner/compilation_manager.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -630,15 +630,7 @@ def draft_model_fn_wrapper(
630630
NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
631631

632632
self._run_compilation(
633-
"eagle3_select_and_stack_draft_token_ids_in_jit",
634-
self.runner.drafter._select_and_stack_draft_token_ids_in_jit,
635-
hidden_states,
636-
last_token_indices,
637-
num_tokens=num_tokens,
638-
)
639-
640-
self._run_compilation(
641-
"eagle3_select_positions_and_hidden_states_in_jit",
633+
"eagle3_select_inputs_for_loop_in_jit",
642634
self.runner.drafter._select_inputs_for_loop_in_jit,
643635
positions,
644636
hidden_states,

tpu_inference/runner/speculative_decoding_manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,10 @@ def propose_eagle3_draft_token_ids(
142142
last_token_indices=last_token_indices,
143143
target_hidden_states=target_hidden_states,
144144
)
145-
result = draft_token_ids.tolist()
146-
return result
145+
draft_token_ids = np.array(draft_token_ids)
146+
if draft_token_ids.ndim == 1:
147+
draft_token_ids = np.expand_dims(draft_token_ids, axis=-1)
148+
return draft_token_ids.tolist()
147149

148150
def get_spec_decode_metadata(
149151
self,

tpu_inference/spec_decode/jax/eagle3.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def _prepare_draft_inputs_in_jit(
284284

285285
return target_hidden_states, input_ids, last_token_indices, block_tables
286286

287+
@functools.partial(jax.jit, static_argnums=(0, ))
287288
def _select_draft_token_ids(
288289
self,
289290
hidden_states: jax.Array,
@@ -292,16 +293,6 @@ def _select_draft_token_ids(
292293
sample_hidden_states = hidden_states[last_token_indices]
293294
return self._get_draft_token_ids_in_jit(sample_hidden_states)
294295

295-
@functools.partial(jax.jit, static_argnums=(0, ))
296-
def _select_and_stack_draft_token_ids_in_jit(
297-
self,
298-
hidden_states: jax.Array,
299-
last_token_indices: jax.Array,
300-
) -> jax.Array:
301-
draft_token_ids = self._select_draft_token_ids(hidden_states,
302-
last_token_indices)
303-
return jnp.stack([draft_token_ids], axis=1)
304-
305296
@functools.partial(jax.jit, static_argnums=(0, ))
306297
def _get_draft_token_ids_in_jit(self,
307298
hidden_states: jax.Array) -> jax.Array:
@@ -343,7 +334,7 @@ def propose(
343334
)
344335

345336
if self.num_speculative_tokens == 1:
346-
return kv_caches, self._select_and_stack_draft_token_ids_in_jit(
337+
return kv_caches, self._select_draft_token_ids(
347338
hidden_states, last_token_indices)
348339

349340
positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_in_jit(

0 commit comments

Comments
 (0)