diff --git a/tests/spec_decode/test_eagle3.py b/tests/spec_decode/test_eagle3.py
index 730bb807c..3a0340661 100644
--- a/tests/spec_decode/test_eagle3.py
+++ b/tests/spec_decode/test_eagle3.py
@@ -74,13 +74,22 @@ def test_prepare_inputs():
     proposer = _create_proposer("eagle3", 1)
     num_reqs = 3
     max_num_seqs = 128
+    max_num_blocks_per_req = 10  # Mock value
     # Mock runner attributes
     proposer.runner.input_batch.num_reqs = num_reqs
     proposer.runner.num_tokens_paddings = runner_utils.get_token_paddings(
         min_token_size=16, max_token_size=1024, padding_gap=0)
-    proposer.runner._select_from_array_fn = (
-        lambda array, indices: array[indices])
+
+    # Mocks required by _prepare_draft_inputs helper
+    proposer.combine_hidden_states_fn = lambda state, h: h  # Mock passthrough
+    proposer.state = None  # Mock state
+    proposer.runner.input_batch.block_table = [mock.MagicMock()]
+    # Mock the block table return value (2D array)
+    (proposer.runner.input_batch.block_table[0].get_device_tensor.return_value
+     ) = jnp.zeros((num_reqs, max_num_blocks_per_req), dtype=jnp.int32)
+
+    # --- Setup sequence data ---
     qsl_cpu = np.zeros(max_num_seqs + 1, dtype=np.int32)
     query_lens = np.zeros(max_num_seqs, dtype=np.int32)
     query_lens[:num_reqs] = [4, 7, 5]
@@ -102,12 +111,17 @@ def test_prepare_inputs():
     num_rejected_tokens_cpu = np.zeros(max_num_seqs, dtype=np.int32)
     num_rejected_tokens_cpu[:num_reqs] = [1, 3, 2]
     num_rejected_tokens = jnp.array(num_rejected_tokens_cpu)
+    # This is only used in the _prepare_input_ids helper.
+    # It must be padded to max_num_seqs (128) to match the mask in jnp.where.
+    next_token_ids_cpu = np.zeros(max_num_seqs, dtype=np.int32)
+    next_token_ids_cpu[:num_reqs] = [1, 2, 3]  # Valid tokens for active reqs
+    next_token_ids = jnp.array(next_token_ids_cpu)
     attn_metadata = AttentionMetadata(
         seq_lens=jnp.array(sl_cpu),
         input_positions=jnp.arange(total_tokens),
         query_start_loc=jnp.array(qsl_cpu),
-        block_tables=jnp.array([]),
+        block_tables=jnp.array([]),  # This will be replaced by the mock
         request_distribution=None,
     )
     attn_metadata.query_start_loc_cpu = qsl_cpu
@@ -129,20 +143,26 @@ def test_prepare_inputs():
     expected_total_tokens = runner_utils.get_padded_token_len(
         proposer.runner.num_tokens_paddings, expected_total_tokens)
+    expected_last_token_indices = jnp.array(expected_new_qsl[1:] - 1)
+
     # Execute
-    updated_metadata, target_token_ids, target_hidden_states = (
+    target_hidden_states, input_ids, last_token_indices, updated_metadata = (
         proposer.prepare_inputs(attn_metadata, input_ids, aux_hidden_states,
-                                num_rejected_tokens))
+                                next_token_ids, num_rejected_tokens))
     # Assertions
     assert jnp.array_equal(updated_metadata.query_start_loc,
                            jnp.array(expected_new_qsl))
     assert jnp.array_equal(updated_metadata.seq_lens,
                            jnp.array(expected_new_seq_lens))
-    assert target_token_ids.shape == (expected_total_tokens, )
+
+    assert jnp.array_equal(last_token_indices, expected_last_token_indices)
+
+    assert input_ids.shape == (expected_total_tokens, )
     # NOTE: We don't check the content of target_token_ids for padded requests
     # as it's complicated to construct the expected tensor. The shape check
    # and the qsl/seq_len checks are sufficient to validate the logic.
+    # The concatenated hidden state shape should be (..., hidden_size * 3)
     assert target_hidden_states.shape == (expected_total_tokens,
                                           hidden_size * 3)
@@ -161,29 +181,59 @@ def test_propose(method, num_speculative_tokens):
     total_tokens = seq_len_1 + seq_len_2
     base_token_ids = [42, 60]

-    def mock_model_fn(state, kv_caches, input_ids, hidden_states,
+    def mock_model_fn(state, kv_caches, input_ids, target_hidden_states,
                       attn_metadata):
+        """
+        Mock model_fn.
+        Returns: (kv_caches, hidden_states_for_logits, residual_tuple)
+
+        - On first call (num_tokens == total_tokens):
+          Populate hidden_states_for_logits[last_token_indices] with
+          base_token_ids.
+          Populate residual_tuple[0][last_token_indices] with base_token_ids.
+        - On loop calls (num_tokens == batch_size):
+          Use input_ids (previous draft token) to generate the new token
+          (input_ids + 1).
+          Populate hidden_states_for_logits with (input_ids + 1).
+          Populate residual_tuple[0] with (input_ids + 1).
+        """
         num_tokens = input_ids.shape[0]
-        new_hidden_states = jnp.zeros((num_tokens, hidden_size))
+
+        # This will be used for logits (output 2)
+        hidden_states_for_logits = jnp.zeros((num_tokens, hidden_size))
+        # This will be fed into the next step (output 3, item 0)
+        residual_hidden_states = jnp.zeros((num_tokens, hidden_size))
         if num_tokens == total_tokens:
-            # First call in propose. Set hidden states for last tokens
-            # to produce the first draft tokens.
+            # First call in propose.
+            # `propose` will select from last_token_indices.
             last_token_indices = attn_metadata.query_start_loc[1:] - 1
-            # The proposer uses next_token_ids to set the last token of each
-            # sequence in input_ids. We mock this behavior by directly using
-            # next_token_ids to generate the first draft tokens.
-            # The mock `compute_logits` will use hidden_states[:, 0]
-            # to generate tokens.
-            new_hidden_states = new_hidden_states.at[
+
+            # Set logits output
+            hidden_states_for_logits = hidden_states_for_logits.at[
                 last_token_indices, 0].set(jnp.array(base_token_ids))
-        else:  # Subsequent calls in the loop
-            new_hidden_states = new_hidden_states.at[:, 0].set(input_ids + 1)
-        return kv_caches, new_hidden_states, new_hidden_states
+            # Set residual for next step
+            residual_hidden_states = residual_hidden_states.at[
+                last_token_indices, 0].set(jnp.array(base_token_ids))
+        else:
+            # Subsequent calls in the loop.
+            # input_ids is the previous draft token (shape `batch_size`).
+            # Mock logic: next token = previous token + 1.
+            next_token_ids_encoded = input_ids + 1
+
+            # Set logits output
+            hidden_states_for_logits = hidden_states_for_logits.at[:, 0].set(
+                next_token_ids_encoded)
+
+            # Set residual for next step
+            residual_hidden_states = residual_hidden_states.at[:, 0].set(
+                next_token_ids_encoded)
+
+        # Return (kv_caches, hidden_states, residual_tuple)
+        return kv_caches, hidden_states_for_logits, (residual_hidden_states, )

     def mock_compute_logits_fn(state, hidden_states, lora_metadata):
         # Create deterministic logits from hidden_states.
+        # Takes the value from hidden_states[:, 0]
         token_ids = hidden_states[:, 0].astype(jnp.int32)
         return jax.nn.one_hot(token_ids, vocab_size)
@@ -198,39 +248,55 @@ def mock_combine_hidden_states_fn(state, hidden_states):
     # Inputs
     kv_caches = [None] * 1  # Mock kv_caches
-    next_token_ids = jnp.array(base_token_ids, dtype=jnp.int32)
+
+    # Create the 2D table first, as this is what the (unused) mock expects
+    block_tables_2d = jnp.zeros((batch_size, 10), dtype=jnp.int32)
+
     attn_metadata = AttentionMetadata(
         seq_lens=jnp.array([seq_len_1, seq_len_2]),
         input_positions=jnp.concatenate(
             [jnp.arange(seq_len_1), jnp.arange(seq_len_2)]),
         query_start_loc=jnp.array([0, seq_len_1, total_tokens]),
-        block_tables=jnp.zeros((2, 10), dtype=jnp.int32),
+        # Pass the FLATTENED table to simulate output of prepare_inputs
+        block_tables=block_tables_2d.reshape(-1),
         request_distribution=None,
     )
+
+    # These are the inputs to `propose`
+    # input_ids (from prepare_inputs)
     target_token_ids = jnp.zeros(total_tokens, dtype=jnp.int32)
+    # target_hidden_states (from prepare_inputs)
     target_hidden_states = jnp.zeros((total_tokens, hidden_size))
+    # last_token_indices (from prepare_inputs)
+    last_token_indices = attn_metadata.query_start_loc[1:] - 1
     # Mock runner for block tables
+    # This mock isn't actually used by propose(), but we'll set it
+    # to the 2D table for correctness, as that's what
+    # _prepare_draft_inputs (called by prepare_inputs) would expect.
     proposer.runner.input_batch.num_reqs = batch_size
     proposer.runner.input_batch.block_table = [mock.MagicMock()]
     (proposer.runner.input_batch.block_table[0].get_device_tensor.return_value
-     ) = attn_metadata.block_tables
-    proposer.runner._select_from_array_fn = (
-        lambda array, indices: array[indices])
+     ) = block_tables_2d
     # Execute
     _, draft_token_ids = proposer.propose(
         kv_caches,
-        next_token_ids,
-        attn_metadata,
         target_token_ids,
+        attn_metadata,
+        last_token_indices,
         target_hidden_states,
     )
     # Assertions
     assert draft_token_ids.shape == (batch_size, num_speculative_tokens)
+    # Check the generated tokens
+    # Step 0: base_token_ids [42, 60]
+    # Step 1: [43, 61]
+    # Step 2: [44, 62]
+    # ...
     expected_tokens = np.zeros((batch_size, num_speculative_tokens),
                                dtype=np.int64)
     for i in range(batch_size):
diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py
index c7706aa0d..769d31d96 100644
--- a/tpu_inference/runner/compilation_manager.py
+++ b/tpu_inference/runner/compilation_manager.py
@@ -7,7 +7,6 @@ import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
-from vllm.utils import cdiv

 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
@@ -125,7 +124,7 @@ def model_fn_wrapper(
             layer_name_to_kvcache_index,
             lora_metadata,
         ):
-            kv_caches, hidden_states, aux_hidden_states = self.runner.model_fn(
+            kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
                 layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
@@ -263,16 +262,6 @@ def _precompile_select_from_array(self) -> None:
             only_equal_paddings=True,
         )
-        self._precompile_select_from_array_helper(
-            name="select hidden states for eagle-3",
-            source_paddings=self.runner.num_tokens_paddings,
-            indices_paddings=[self.runner.max_num_reqs],
-            hidden_dim=hsize,
-            sharding=NamedSharding(self.runner.mesh,
-                                   PartitionSpec(None, None)),
-            check_should_skip_padding=False,
-        )
-
     def _precompile_compute_logits(self) -> None:
         logger.info("Compiling compute_logits with different input shapes.")
         hsize = self.runner.model_config.get_hidden_size()
@@ -432,22 +421,14 @@ def _precompile_eagle3_helpers(self) -> None:
         logger.info(
             "Compiling eagle3 jitted helpers with different input shapes.")
         hidden_size = self.runner.model_config.get_hidden_size()
-        draft_hidden_size = self.runner.vllm_config.speculative_config.draft_model_config.hf_config.hidden_size * 3
         dtype = self.runner.model_config.dtype
         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
         draft_kv_cache_group_id = num_kv_cache_groups - 1
-        block_tables = jnp.ones(
-            (self.runner.max_num_reqs,
-             cdiv(self.runner.max_model_len, self.runner.block_size)),
-            jnp.int32)
-        self._run_compilation(
-            "eagle3_reshape_block",
-            self.runner.drafter._reshape_block_tables,
-            block_tables,
-        )
         block_tables = self.runner.input_batch.block_table[
-            draft_kv_cache_group_id].get_device_tensor().reshape(-1)
+            draft_kv_cache_group_id].get_cpu_tensor().reshape(-1)
+        block_tables_first_spec = jax.device_put(
+            block_tables, NamedSharding(self.runner.mesh, PartitionSpec()))
         block_tables_loop = jax.device_put(
             block_tables, NamedSharding(self.runner.mesh,
                                         PartitionSpec(None, )))
@@ -458,28 +439,20 @@ def _precompile_eagle3_helpers(self) -> None:
                                               jnp.int32)
         query_start_loc = self._create_dummy_tensor(
             (self.runner.max_num_reqs + 1, ), jnp.int32)
-        self._run_compilation("_prepare_input_loop for the first loop",
-                              self.runner.drafter._prepare_input_loop,
-                              selected_positions, seq_lens, block_tables)
-        self._run_compilation("_prepare_input_loop for the subsequent loops",
-                              self.runner.drafter._prepare_input_loop,
-                              selected_positions, seq_lens, block_tables_loop)
+        self._run_compilation(
+            "_update_inputs_for_loop_speculation for the first loop",
+            self.runner.drafter._update_inputs_for_loop_speculation,
+            selected_positions, seq_lens, block_tables)
+        self._run_compilation(
+            "_update_inputs_for_loop_speculation for the subsequent loops",
+            self.runner.drafter._update_inputs_for_loop_speculation,
+            selected_positions, seq_lens, block_tables_loop)
         request_distribution = np.array([0, 0, 0], dtype=np.int32)
         request_distribution = device_array(self.runner.mesh,
                                             request_distribution)
         for num_reqs_padding in self.runner.num_reqs_paddings:
-            logits = self._create_dummy_tensor(
-                (num_reqs_padding, self.runner.vocab_size), jnp.bfloat16,
-                NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
-            self._run_compilation(
-                "_get_draft_token_ids",
-                self.runner.drafter._get_draft_token_ids,
-                logits,
-                num_reqs=num_reqs_padding,
-            )
-
             for i in range(1, self.runner.drafter.num_speculative_tokens + 1):
                 draft_token_ids_list = [
                     self._create_dummy_tensor(
@@ -488,7 +461,7 @@ def _precompile_eagle3_helpers(self) -> None:
                     for _ in range(i)
                 ]
                 self._run_compilation(
-                    "_stack_draft_token_ids",
+                    "eagle3_stack_draft_token_ids",
                     self.runner.drafter._stack_draft_token_ids,
                     draft_token_ids_list,
                     num_reqs=num_reqs_padding,
@@ -498,46 +471,59 @@ def _precompile_eagle3_helpers(self) -> None:
             hidden_states = self._create_dummy_tensor(
                 (num_logits, hidden_size), jnp.bfloat16)
             self._run_compilation(
-                "drafter_compute_logits",
-                self.runner.drafter.compute_logits_fn,
-                self.runner.drafter.state,
+                "eagle3_get_draft_token_ids",
+                self.runner.drafter._get_draft_token_ids,
                 hidden_states,
-                None,
                 num_logits=num_logits,
             )
-        position_indices = self._create_dummy_tensor(
-            (self.runner.max_num_reqs, ), jnp.int32)
-        next_token_ids = self._create_dummy_tensor(
-            (self.runner.max_num_reqs, ), jnp.int32)
         input_ids_loop = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32,
             NamedSharding(self.runner.mesh, PartitionSpec()))
         target_hidden_state_loop = self._create_dummy_tensor(
             (self.runner.max_num_reqs, hidden_size), dtype,
             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
+        next_token_ids = self._create_dummy_tensor(
+            (self.runner.max_num_reqs, ), jnp.int32)
+        last_token_indices = self._create_dummy_tensor(
+            (self.runner.max_num_reqs, ), jnp.int32)
         for num_tokens in self.runner.num_tokens_paddings:
-            positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
-            self._run_compilation(
-                "select_from_array [select input positions for eagle3]",
-                self.runner._select_from_array_fn,
-                positions,
-                position_indices,
-                num_tokens=num_tokens)
-
             aux_hidden_states = [
                 self._create_dummy_tensor((num_tokens, hidden_size), dtype),
                 self._create_dummy_tensor((num_tokens, hidden_size), dtype),
                 self._create_dummy_tensor((num_tokens, hidden_size), dtype),
             ]
-            self._run_compilation(
-                "eagle3_concate_hidden_states",
-                self.runner.drafter._concate_hidden_states,
-                aux_hidden_states,
-                num_tokens=num_tokens,
+
+            positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
+            attention_metadata = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables_first_spec,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
             )
-            input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
+            def filter_token_and_prepare_initial_inputs_wrapper(
+                token_indices,
+                query_start_loc,
+                seq_lens,
+                input_ids,
+                aux_hidden_states,
+                attention_metadata,
+                next_token_ids,
+                num_reqs,
+            ):
+                target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
+                    token_indices, query_start_loc, seq_lens, input_ids,
+                    aux_hidden_states, attention_metadata, next_token_ids,
+                    num_reqs)
+                return target_hidden_states, input_ids, last_token_indices
+
+            token_indices = self._create_dummy_tensor((num_tokens, ),
+                                                      jnp.int32)
+            input_ids = self._create_dummy_tensor(
+                (num_tokens, ), jnp.int32,
+                NamedSharding(self.runner.mesh, PartitionSpec()))
             aux_hidden_states = [
                 self._create_dummy_tensor(
                     (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None, None))),
                 self._create_dummy_tensor(
                     (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None, None))),
                 self._create_dummy_tensor(
                     (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None, None))),
             ]
@@ -552,30 +538,22 @@ def _precompile_eagle3_helpers(self) -> None:
-            for num_indices in self.runner.num_tokens_paddings:
-                indices = jnp.ones((num_indices, ), dtype=jnp.int32)
-                self._run_compilation(
-                    "select_from_array [select input ids for eagle3]",
-                    self.runner._select_from_array_fn,
-                    input_ids,
-                    indices,
-                    num_tokens=num_tokens,
-                    num_indices=num_indices)
-                self._run_compilation(
-                    "select_from_array [select hidden states for eagle3]",
-                    self.runner.drafter._select_target_hidden_states,
-                    aux_hidden_states, indices)
-
-            attention_metadata = AttentionMetadata(
-                input_positions=positions,
-                block_tables=block_tables,
-                seq_lens=seq_lens,
-                query_start_loc=query_start_loc,
-                request_distribution=request_distribution,
+            self._run_compilation(
+                "eagle3_filter_token_and_prepare_initial_inputs",
+                filter_token_and_prepare_initial_inputs_wrapper,
+                token_indices,
+                query_start_loc,
+                seq_lens,
+                input_ids,
+                aux_hidden_states,
+                attention_metadata,
+                next_token_ids,
+                device_array(
+                    self.runner.mesh,
+                    np.asarray([self.runner.input_batch.num_reqs],
+                               dtype=jnp.int32)),
+                num_tokens=num_tokens,
             )
-            target_hidden_states = self._create_dummy_tensor(
-                (num_tokens, hidden_size), dtype,
-                NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))

             def draft_model_fn_wrapper(
                 state,
@@ -590,8 +568,11 @@ def draft_model_fn_wrapper(
                 self.runner.kv_caches = kv_caches
                 return hidden_states

+            target_hidden_states = self._create_dummy_tensor(
+                (num_tokens, hidden_size), dtype,
+                NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
             self._run_compilation(
-                "draft_model_fn",
+                "eagle3_draft_model_fn",
                 draft_model_fn_wrapper,
                 self.runner.drafter.state,
                 self.runner.kv_caches,
@@ -600,6 +581,22 @@ def draft_model_fn_wrapper(
                 attention_metadata,
                 num_tokens=num_tokens,
             )
+            target_token_ids = self._create_dummy_tensor((num_tokens, ),
+                                                          jnp.int32)
+
+            self._run_compilation(
+                "eagle3_prepare_hidden_states_and_input_ids",
+                self.runner.drafter._prepare_hidden_states_and_input_ids,
+                aux_hidden_states,
+                query_start_loc,
+                target_token_ids,
+                next_token_ids,
+                device_array(
+                    self.runner.mesh,
+                    np.asarray([self.runner.input_batch.num_reqs],
+                               dtype=jnp.int32)),
+                num_tokens=num_tokens,
+            )
             attention_metadata.query_start_loc = jax.device_put(
                 attention_metadata.query_start_loc,
@@ -618,25 +615,26 @@ def draft_model_fn_wrapper(
                 num_tokens=num_tokens,
             )
-            target_hidden_states = self._create_dummy_tensor(
-                (num_tokens, draft_hidden_size), dtype)
+            hidden_states = self._create_dummy_tensor(
+                (num_tokens, hidden_size), jnp.bfloat16,
+                NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
+
             self._run_compilation(
-                "draft_model_combine_hidden_states_fn",
-                self.runner.drafter.combine_hidden_states_fn,
-                self.runner.drafter.state,
-                target_hidden_states,
+                "eagle3_select_inputs_for_loop_speculation",
+                self.runner.drafter._select_inputs_for_loop_speculation,
+                positions,
+                hidden_states,
+                hidden_states,
+                last_token_indices,
                 num_tokens=num_tokens,
             )
-            target_token_ids = self._create_dummy_tensor((num_tokens, ),
-                                                          jnp.int32)
             self._run_compilation(
-                "_prepare_input_ids",
-                self.runner.drafter._prepare_input_ids,
-                query_start_loc,
-                target_token_ids,
-                next_token_ids,
-                self.runner.input_batch.num_reqs,
+                "eagle3_select_draft_token_ids",
+                self.runner.drafter._select_draft_token_ids,
+                hidden_states,
+                last_token_indices,
+                num_tokens=num_tokens,
             )

     def _precompile_structured_decoding(self) -> None:
diff --git a/tpu_inference/runner/speculative_decoding_manager.py b/tpu_inference/runner/speculative_decoding_manager.py
index 78e682a01..9dbf060d7 100644
--- a/tpu_inference/runner/speculative_decoding_manager.py
+++ b/tpu_inference/runner/speculative_decoding_manager.py
@@ -127,21 +127,25 @@ def propose_eagle3_draft_token_ids(
                 self.runner.mesh, np.array(num_rejected_tokens, dtype=jnp.int32))
-        attn_metadata, target_token_ids, target_hidden_states = self.runner.drafter.prepare_inputs(
+        target_hidden_states, input_ids, last_token_indices, attn_metadata = self.runner.drafter.prepare_inputs(
             attn_metadata,
             input_ids,
             aux_hidden_states,
+            next_token_ids,
             num_rejected_tokens,
         )
+
         self.runner.kv_caches, draft_token_ids = self.runner.drafter.propose(
             kv_caches=self.runner.kv_caches,
-            next_token_ids=next_token_ids,
+            input_ids=input_ids,
             attn_metadata=attn_metadata,
-            target_token_ids=target_token_ids,
+            last_token_indices=last_token_indices,
             target_hidden_states=target_hidden_states,
         )
-        result = draft_token_ids.tolist()
-        return result
+        draft_token_ids = np.array(draft_token_ids)
+        if draft_token_ids.ndim == 1:
+            draft_token_ids = np.expand_dims(draft_token_ids, axis=-1)
+        return draft_token_ids.tolist()

     def get_spec_decode_metadata(
         self,
diff --git a/tpu_inference/spec_decode/jax/eagle3.py b/tpu_inference/spec_decode/jax/eagle3.py
index 1177e358b..86a460943 100644
--- a/tpu_inference/spec_decode/jax/eagle3.py
+++ b/tpu_inference/spec_decode/jax/eagle3.py
@@ -55,22 +55,10 @@ def load_model(self, target_model: Any) -> None:
         self.state.model.embed_tokens = target_model.model.embed

     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _concate_hidden_states(self, aux_hidden_states):
-        """JIT-compiled helper for concatenating auxiliary hidden states."""
-        # Concat aux hidden states along feature dim.
-        return jnp.concatenate(aux_hidden_states, axis=-1)
-
-    @functools.partial(jax.jit, static_argnums=(0, ))
-    def _select_target_hidden_states(self, aux_hidden_states, token_indices):
-        """JIT-compiled helper for selecting target hidden states."""
-        return jnp.concatenate([h[token_indices] for h in aux_hidden_states],
-                               axis=-1)
-
-    @functools.partial(jax.jit, static_argnums=(0, ))
-    def _prepare_input_ids(self, query_start_loc: jax.Array,
-                           target_token_ids: jax.Array,
-                           next_token_ids: jax.Array,
-                           num_reqs: int) -> tuple[jnp.ndarray, jnp.ndarray]:
+    def _prepare_input_ids(
+            self, query_start_loc: jax.Array, target_token_ids: jax.Array,
+            next_token_ids: jax.Array,
+            num_reqs: jax.Array) -> tuple[jnp.ndarray, jnp.ndarray]:
         """JIT-compiled helper for preparing the input IDs for the draft model."""
         last_token_indices = query_start_loc[1:] - 1
@@ -93,7 +81,10 @@ def _prepare_input_ids(self, query_start_loc: jax.Array,
         return input_ids, last_token_indices

     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _prepare_input_loop(self, positions, seq_lens, block_tables):
+    def _update_inputs_for_loop_speculation(
+        self, positions: jax.Array, seq_lens: jax.Array,
+        block_tables: jax.Array
+    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
         """JIT-compiled helper for preparing inputs in the loop of prediction."""
         positions += 1
@@ -121,29 +112,39 @@ def _prepare_input_loop(self, positions, seq_lens, block_tables):
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

-    @functools.partial(jax.jit, static_argnums=(0, ))
-    def _reshape_block_tables(self, block_tables: jax.Array) -> jax.Array:
-        """JIT-compiled helper for reshaping block tables."""
-        return block_tables.reshape(-1)
-
-    @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self, logits: jax.Array) -> jnp.ndarray:
-        """JIT-compiled helper for getting draft token IDs from logits."""
-        return jnp.argmax(logits, axis=-1)
-
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _stack_draft_token_ids(
             self, draft_token_ids_list: list[jax.Array]) -> jnp.ndarray:
         """JIT-compiled helper for stacking draft token IDs."""
         return jnp.stack(draft_token_ids_list, axis=1)

+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _prepare_hidden_states_and_input_ids(
+        self,
+        aux_hidden_states: tuple[jax.Array, ...],
+        query_start_loc: jax.Array,
+        target_token_ids: jax.Array,
+        next_token_ids: jax.Array,
+        num_reqs: jax.Array,
+    ) -> tuple[jax.Array, jax.Array, jax.Array]:
+        target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
+        target_hidden_states = self.combine_hidden_states_fn(
+            self.state, target_hidden_states)
+
+        input_ids, last_token_indices = self._prepare_input_ids(
+            query_start_loc, target_token_ids, next_token_ids, num_reqs)
+        # NOTE(pooyam): For now, we don't support multimodal.
+
+        return target_hidden_states, input_ids, last_token_indices
+
     def prepare_inputs(
         self,
         attn_metadata: AttentionMetadata,
         input_ids: jax.Array,
         aux_hidden_states: tuple[jax.Array, ...],
+        next_token_ids: jax.Array,
         num_rejected_tokens: Optional[jax.Array] = None,
-    ) -> tuple[AttentionMetadata, jnp.ndarray, jnp.ndarray]:
+    ) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
         """Prepare drafter inputs based on target forward outputs.
         Mirrors the GPU reference logic but adapted to TPU/JAX types:
@@ -159,13 +160,26 @@ def prepare_inputs(
         assert aux_hidden_states is not None and len(aux_hidden_states) > 0, (
             "EAGLE3 requires auxiliary hidden states from the target model.")
-        if num_rejected_tokens is None:
-            return attn_metadata, input_ids, self._concate_hidden_states(
-                aux_hidden_states)
-
+        # The last KV cache group is for the draft model.
+        num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
+        draft_kv_cache_group_id = num_kv_cache_groups - 1
+        block_tables = self.runner.input_batch.block_table[
+            draft_kv_cache_group_id].get_cpu_tensor().reshape(-1)
         # Number of active requests in this step (un-padded count).
         num_reqs = self.runner.input_batch.num_reqs
+        if num_rejected_tokens is None:
+            num_reqs = device_array(self.mesh,
+                                    np.asarray([num_reqs], dtype=jnp.int32))
+            # block_tables = device_array(self.mesh, block_tables)
+            attn_metadata = replace(attn_metadata,
+                                    block_tables=device_array(
+                                        self.mesh, block_tables))
+            target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
+                aux_hidden_states, attn_metadata.query_start_loc, input_ids,
+                next_token_ids, num_reqs)
+            return target_hidden_states, input_ids, last_token_indices, attn_metadata
+
         # Host copies from the metadata prepared by the runner.
         query_start_loc_cpu = attn_metadata.query_start_loc_cpu
         seq_lens_cpu = attn_metadata.seq_lens_cpu
@@ -217,43 +231,85 @@ def prepare_inputs(
             token_indices_cpu = np.pad(token_indices_cpu, (0, pad_width),
                                        "constant",
                                        constant_values=0)
-        token_indices = jnp.asarray(token_indices_cpu, dtype=jnp.int32)
+        # Update seq_lens for active requests only: new_seq_lens = s - n.
+        new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
+
+        query_start_loc, seq_lens, token_indices, num_reqs, block_tables = device_array(
+            self.mesh,
+            (new_query_start_loc_cpu, new_seq_lens_cpu, token_indices_cpu,
+             np.asarray([num_reqs], dtype=jnp.int32), block_tables))
+
+        attn_metadata = replace(attn_metadata, block_tables=block_tables)
+        return self._filter_token_and_prepare_initial_inputs(
+            token_indices, query_start_loc, seq_lens, input_ids,
+            aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _filter_token_and_prepare_initial_inputs(
+        self,
+        token_indices: jax.Array,
+        query_start_loc: jax.Array,
+        seq_lens: jax.Array,
+        input_ids: jax.Array,
+        aux_hidden_states: tuple[jax.Array, ...],
+        attn_metadata: AttentionMetadata,
+        next_token_ids: jax.Array,
+        num_reqs: jax.Array,
+    ) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
+
         # Select tokens and hidden states.
-        target_token_ids = self.runner._select_from_array_fn(
-            input_ids, token_indices)
-        target_hidden_states = self._select_target_hidden_states(
-            aux_hidden_states, token_indices)
+        target_token_ids = input_ids[token_indices]

         # Update positions to match the selected tokens.
         if attn_metadata.input_positions.ndim == 2:
             input_positions = attn_metadata.input_positions[:, token_indices]
         else:
-            input_positions = self.runner._select_from_array_fn(
-                attn_metadata.input_positions, token_indices)
-
-        # Update seq_lens for active requests only: new_seq_lens = s - n.
-        new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
+            input_positions = attn_metadata.input_positions[token_indices]

-        query_start_loc, seq_lens = device_array(self.mesh, (
-            new_query_start_loc_cpu,
-            new_seq_lens_cpu,
-        ))
-
-        # Return updated metadata with positions, qsl, and seq_lens.
-        updated_attn = AttentionMetadata(
+        attn_metadata = AttentionMetadata(
             input_positions=input_positions,
             block_tables=attn_metadata.block_tables,
             seq_lens=seq_lens,
             query_start_loc=query_start_loc,
             request_distribution=attn_metadata.request_distribution,
         )
-        return updated_attn, target_token_ids, target_hidden_states
+
+        target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
+            [h[token_indices] for h in aux_hidden_states], query_start_loc,
+            target_token_ids, next_token_ids, num_reqs)
+
+        return target_hidden_states, input_ids, last_token_indices, attn_metadata
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _select_draft_token_ids(
+        self,
+        hidden_states: jax.Array,
+        last_token_indices: jax.Array,
+    ) -> jax.Array:
+        sample_hidden_states = hidden_states[last_token_indices]
+        return self._get_draft_token_ids(sample_hidden_states)
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
+        lora_metadata = None
+        logits = self.compute_logits_fn(self.state, hidden_states,
+                                        lora_metadata)
+        return jnp.argmax(logits, axis=-1)
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _select_inputs_for_loop_speculation(
+        self, positions: jax.Array, residual: jax.Array,
+        hidden_states: jax.Array, last_token_indices: jax.Array
+    ) -> tuple[jax.Array, jax.Array, jax.Array]:
+        return positions[last_token_indices], residual[
+            last_token_indices], self._select_draft_token_ids(
+                hidden_states, last_token_indices)

     def propose(
         self,
         kv_caches: list[jax.Array],
-        next_token_ids: jnp.ndarray,  # [batch_size]
+        input_ids: jax.Array,
         attn_metadata: AttentionMetadata,
-        target_token_ids,
+        last_token_indices,
         target_hidden_states,
     ) -> tuple[list[jax.Array], jnp.ndarray]:
         """Proposes draft tokens using the draft model.
@@ -262,22 +318,7 @@ def propose(
             draft token IDs.
         """
-        target_hidden_states = self.combine_hidden_states_fn(
-            self.state, target_hidden_states)
-
-        input_ids, last_token_indices = self._prepare_input_ids(
-            attn_metadata.query_start_loc, target_token_ids, next_token_ids,
-            self.runner.input_batch.num_reqs)
-        # NOTE(pooyam): For now, we don't support multimodal.
-
-        # The last KV cache group is for the draft model.
-        num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
-        draft_kv_cache_group_id = num_kv_cache_groups - 1
-        block_tables = self.runner.input_batch.block_table[
-            draft_kv_cache_group_id].get_device_tensor()
-        block_tables = self._reshape_block_tables(block_tables)
-        attn_metadata = replace(attn_metadata, block_tables=block_tables)
-
+        # input_ids and target_hidden_states for the first speculation have
+        # been prepared in prepare_inputs() to improve performance.
         kv_caches, hidden_states, residual = self.model_fn(
             self.state,
             kv_caches,
             input_ids,
             target_hidden_states,
             attn_metadata,
         )
-        sample_hidden_states = self.runner._select_from_array_fn(
-            hidden_states, last_token_indices)
-        lora_metadata = None
-        logits = self.compute_logits_fn(self.state, sample_hidden_states,
-                                        lora_metadata)
-        draft_token_ids = self._get_draft_token_ids(logits)
-        draft_token_ids_list = [draft_token_ids]
-
         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
-            return kv_caches, self._stack_draft_token_ids(draft_token_ids_list)
+            return kv_caches, self._select_draft_token_ids(
+                hidden_states, last_token_indices)

-        positions = self.runner._select_from_array_fn(
-            attn_metadata.input_positions, last_token_indices)
-        hidden_states = self.runner._select_from_array_fn(
-            residual[0], last_token_indices)
+        positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
+            attn_metadata.input_positions, residual[0], hidden_states,
+            last_token_indices)
+
+        draft_token_ids_list = [draft_token_ids]

         for _ in range(self.num_speculative_tokens - 1):
             input_ids_loop = draft_token_ids_list[-1]
-            positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables = self._prepare_input_loop(
+            positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables = self._update_inputs_for_loop_speculation(
                 positions, attn_metadata.seq_lens, attn_metadata.block_tables)

             attn_metadata = replace(
@@ -323,9 +358,7 @@ def propose(
                 attn_metadata,
             )
             hidden_states = residual[0]
-            logits = self.compute_logits_fn(self.state, new_hidden_states,
-                                            lora_metadata)
-            draft_token_ids = self._get_draft_token_ids(logits)
+            draft_token_ids = self._get_draft_token_ids(new_hidden_states)
             draft_token_ids_list.append(draft_token_ids)

         # [batch_size, num_speculative_tokens]
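
For reviewers, a minimal sketch of how a caller drives the reworked drafter API after this
change, mirroring propose_eagle3_draft_token_ids in speculative_decoding_manager.py above.
The wrapper name propose_drafts and the standalone drafter/runner arguments are illustrative
only; the call signatures and the 1-D/2-D handling of the returned draft tokens follow what
this diff introduces.

    import numpy as np

    def propose_drafts(drafter, runner, attn_metadata, input_ids,
                       aux_hidden_states, next_token_ids, num_rejected_tokens):
        # prepare_inputs() now also consumes next_token_ids and returns the
        # first-step draft inputs plus last_token_indices, so propose() no
        # longer has to recompute them on every call.
        (target_hidden_states, input_ids, last_token_indices,
         attn_metadata) = drafter.prepare_inputs(attn_metadata, input_ids,
                                                 aux_hidden_states,
                                                 next_token_ids,
                                                 num_rejected_tokens)
        runner.kv_caches, draft_token_ids = drafter.propose(
            kv_caches=runner.kv_caches,
            input_ids=input_ids,
            attn_metadata=attn_metadata,
            last_token_indices=last_token_indices,
            target_hidden_states=target_hidden_states,
        )
        # With num_speculative_tokens == 1, propose() returns a 1-D array,
        # hence the expand_dims before converting to a nested list.
        draft_token_ids = np.array(draft_token_ids)
        if draft_token_ids.ndim == 1:
            draft_token_ids = np.expand_dims(draft_token_ids, axis=-1)
        return draft_token_ids.tolist()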