Skip to content

Commit 93d2e08

Browse files
committed
Put data transfer outside of jitted functions
Signed-off-by: Lihao Ran <[email protected]>
1 parent ecd79ac commit 93d2e08

File tree

2 files changed

+33
-61
lines changed

2 files changed

+33
-61
lines changed

tpu_inference/runner/compilation_manager.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -483,11 +483,8 @@ def _precompile_eagle3_helpers(self) -> None:
483483
target_hidden_state_loop = self._create_dummy_tensor(
484484
(self.runner.max_num_reqs, hidden_size), dtype,
485485
NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
486-
new_seq_lens_cpu = np.ones((self.runner.max_num_reqs, ), jnp.int32)
487486
next_token_ids = self._create_dummy_tensor(
488487
(self.runner.max_num_reqs, ), jnp.int32)
489-
new_query_start_loc_cpu = np.ones((self.runner.max_num_reqs + 1, ),
490-
jnp.int32)
491488
last_token_indices = self._create_dummy_tensor(
492489
(self.runner.max_num_reqs, ), jnp.int32)
493490
for num_tokens in self.runner.num_tokens_paddings:
@@ -535,6 +532,7 @@ def prepare_inputs_in_jit_wrapper(
535532
return target_hidden_states, input_ids, last_token_indices
536533

537534
token_indices_cpu = np.ones((num_tokens, ), dtype=np.int32)
535+
token_indices = jnp.asarray(token_indices_cpu, dtype=jnp.int32)
538536
input_ids = self._create_dummy_tensor(
539537
(num_tokens, ), jnp.int32,
540538
NamedSharding(self.runner.mesh, PartitionSpec()))
@@ -555,9 +553,9 @@ def prepare_inputs_in_jit_wrapper(
555553
self._run_compilation(
556554
"eagle3_prepare_inputs_in_jit",
557555
prepare_inputs_in_jit_wrapper,
558-
token_indices_cpu,
559-
new_query_start_loc_cpu,
560-
new_seq_lens_cpu,
556+
token_indices,
557+
query_start_loc,
558+
seq_lens,
561559
input_ids,
562560
aux_hidden_states,
563561
attention_metadata,
@@ -630,13 +628,6 @@ def draft_model_fn_wrapper(
630628
hidden_states = self._create_dummy_tensor(
631629
(num_tokens, hidden_size), jnp.bfloat16,
632630
NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
633-
self._run_compilation(
634-
"eagle3_select_draft_token_ids_in_jit",
635-
self.runner.drafter._select_draft_token_ids_in_jit,
636-
hidden_states,
637-
last_token_indices,
638-
num_tokens=num_tokens,
639-
)
640631

641632
self._run_compilation(
642633
"eagle3_select_and_stack_draft_token_ids_in_jit",
@@ -648,9 +639,10 @@ def draft_model_fn_wrapper(
648639

649640
self._run_compilation(
650641
"eagle3_select_positions_and_hidden_states_in_jit",
651-
self.runner.drafter._select_positions_and_hidden_states_in_jit,
642+
self.runner.drafter._select_inputs_for_loop_in_jit,
652643
positions,
653644
hidden_states,
645+
hidden_states,
654646
last_token_indices,
655647
num_tokens=num_tokens,
656648
)

tpu_inference/spec_decode/jax/eagle3.py

Lines changed: 27 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -217,30 +217,30 @@ def prepare_inputs(
217217
# Update seq_lens for active requests only: new_seq_lens = s - n.
218218
new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
219219

220-
return self._prepare_inputs_in_jit(token_indices_cpu,
221-
new_query_start_loc_cpu,
222-
new_seq_lens_cpu, input_ids,
220+
token_indices = jnp.asarray(token_indices_cpu, dtype=jnp.int32)
221+
query_start_loc, seq_lens = device_array(self.mesh, (
222+
new_query_start_loc_cpu,
223+
new_seq_lens_cpu,
224+
))
225+
226+
return self._prepare_inputs_in_jit(token_indices, query_start_loc,
227+
seq_lens, input_ids,
223228
aux_hidden_states, attn_metadata,
224229
next_token_ids, block_tables)
225230

226231
@functools.partial(jax.jit, static_argnums=(0, ))
227232
def _prepare_inputs_in_jit(
228233
self,
229-
token_indices_cpu: np.ndarray,
230-
new_query_start_loc_cpu: np.ndarray,
231-
new_seq_lens_cpu: np.ndarray,
234+
token_indices: jax.Array,
235+
query_start_loc: jax.Array,
236+
seq_lens: jax.Array,
232237
input_ids: jax.Array,
233238
aux_hidden_states: tuple[jax.Array, ...],
234239
attn_metadata: AttentionMetadata,
235240
next_token_ids: jax.Array,
236241
block_tables: jax.Array,
237242
) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
238243

239-
token_indices = jnp.asarray(token_indices_cpu, dtype=jnp.int32)
240-
query_start_loc, seq_lens = device_array(self.mesh, (
241-
new_query_start_loc_cpu,
242-
new_seq_lens_cpu,
243-
))
244244
# Select tokens and hidden states.
245245
target_token_ids = input_ids[token_indices]
246246
target_hidden_states = jnp.concatenate(
@@ -259,7 +259,7 @@ def _prepare_inputs_in_jit(
259259
request_distribution=attn_metadata.request_distribution,
260260
)
261261

262-
target_hidden_states, input_ids, last_token_indices, block_tables = self._prepare_draft_inputs(
262+
target_hidden_states, input_ids, last_token_indices, block_tables = self._prepare_draft_inputs_in_jit(
263263
target_hidden_states, query_start_loc, target_token_ids,
264264
next_token_ids, block_tables)
265265

@@ -272,16 +272,6 @@ def _prepare_draft_inputs_in_jit(
272272
target_token_ids: jax.Array, next_token_ids: jax.Array,
273273
block_tables: jax.Array
274274
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
275-
return self._prepare_draft_inputs(target_hidden_states,
276-
query_start_loc, target_token_ids,
277-
next_token_ids, block_tables)
278-
279-
def _prepare_draft_inputs(
280-
self, target_hidden_states: jax.Array, query_start_loc: jax.Array,
281-
target_token_ids: jax.Array, next_token_ids: jax.Array,
282-
block_tables: jax.Array
283-
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
284-
285275
target_hidden_states = self.combine_hidden_states_fn(
286276
self.state, target_hidden_states)
287277

@@ -294,13 +284,13 @@ def _prepare_draft_inputs(
294284

295285
return target_hidden_states, input_ids, last_token_indices, block_tables
296286

297-
@functools.partial(jax.jit, static_argnums=(0, ))
298-
def _select_draft_token_ids_in_jit(
287+
def _select_draft_token_ids(
299288
self,
300289
hidden_states: jax.Array,
301290
last_token_indices: jax.Array,
302291
) -> jax.Array:
303-
return self._select_draft_token_ids(hidden_states, last_token_indices)
292+
sample_hidden_states = hidden_states[last_token_indices]
293+
return self._get_draft_token_ids_in_jit(sample_hidden_states)
304294

305295
@functools.partial(jax.jit, static_argnums=(0, ))
306296
def _select_and_stack_draft_token_ids_in_jit(
@@ -312,30 +302,22 @@ def _select_and_stack_draft_token_ids_in_jit(
312302
last_token_indices)
313303
return jnp.stack([draft_token_ids], axis=1)
314304

315-
def _select_draft_token_ids(
316-
self,
317-
hidden_states: jax.Array,
318-
last_token_indices: jax.Array,
319-
) -> jax.Array:
320-
sample_hidden_states = hidden_states[last_token_indices]
321-
return self._get_draft_token_ids(sample_hidden_states)
322-
323-
def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
305+
@functools.partial(jax.jit, static_argnums=(0, ))
306+
def _get_draft_token_ids_in_jit(self,
307+
hidden_states: jax.Array) -> jax.Array:
324308
lora_metadata = None
325309
logits = self.compute_logits_fn(self.state, hidden_states,
326310
lora_metadata)
327311
return jnp.argmax(logits, axis=-1)
328312

329313
@functools.partial(jax.jit, static_argnums=(0, ))
330-
def _get_draft_token_ids_in_jit(self,
331-
hidden_states: jax.Array) -> jax.Array:
332-
return self._get_draft_token_ids(hidden_states)
333-
334-
@functools.partial(jax.jit, static_argnums=(0, ))
335-
def _select_positions_and_hidden_states_in_jit(
336-
self, positions: jax.Array, hidden_states: jax.Array,
314+
def _select_inputs_for_loop_in_jit(
315+
self, positions: jax.Array, residual: jax.Array,
316+
hidden_states: jax.Array,
337317
last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
338-
return positions[last_token_indices], hidden_states[last_token_indices]
318+
return positions[last_token_indices], residual[
319+
last_token_indices], self._select_draft_token_ids(
320+
hidden_states, last_token_indices)
339321

340322
def propose(
341323
self,
@@ -364,14 +346,12 @@ def propose(
364346
return kv_caches, self._select_and_stack_draft_token_ids_in_jit(
365347
hidden_states, last_token_indices)
366348

367-
draft_token_ids = self._select_draft_token_ids_in_jit(
368-
hidden_states, last_token_indices)
349+
positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_in_jit(
350+
attn_metadata.input_positions, residual[0], hidden_states,
351+
last_token_indices)
369352

370353
draft_token_ids_list = [draft_token_ids]
371354

372-
positions, hidden_states = self._select_positions_and_hidden_states_in_jit(
373-
attn_metadata.input_positions, residual[0], last_token_indices)
374-
375355
for _ in range(self.num_speculative_tokens - 1):
376356
input_ids_loop = draft_token_ids_list[-1]
377357

0 commit comments

Comments
 (0)