Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 92 additions & 26 deletions tests/spec_decode/test_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,22 @@ def test_prepare_inputs():
proposer = _create_proposer("eagle3", 1)
num_reqs = 3
max_num_seqs = 128
max_num_blocks_per_req = 10 # Mock value

# Mock runner attributes
proposer.runner.input_batch.num_reqs = num_reqs
proposer.runner.num_tokens_paddings = runner_utils.get_token_paddings(
min_token_size=16, max_token_size=1024, padding_gap=0)
proposer.runner._select_from_array_fn = (
lambda array, indices: array[indices])

# Mocks required by _prepare_draft_inputs helper
proposer.combine_hidden_states_fn = lambda state, h: h # Mock passthrough
proposer.state = None # Mock state
proposer.runner.input_batch.block_table = [mock.MagicMock()]
# Mock the block table return value (2D array)
(proposer.runner.input_batch.block_table[0].get_device_tensor.return_value
) = jnp.zeros((num_reqs, max_num_blocks_per_req), dtype=jnp.int32)

# --- Setup sequence data ---
qsl_cpu = np.zeros(max_num_seqs + 1, dtype=np.int32)
query_lens = np.zeros(max_num_seqs, dtype=np.int32)
query_lens[:num_reqs] = [4, 7, 5]
Expand All @@ -102,12 +111,17 @@ def test_prepare_inputs():
num_rejected_tokens_cpu = np.zeros(max_num_seqs, dtype=np.int32)
num_rejected_tokens_cpu[:num_reqs] = [1, 3, 2]
num_rejected_tokens = jnp.array(num_rejected_tokens_cpu)
# next_token_ids is only used in the _prepare_input_ids helper.
# It must be padded to max_num_seqs (128) to match the mask in jnp.where.
next_token_ids_cpu = np.zeros(max_num_seqs, dtype=np.int32)
next_token_ids_cpu[:num_reqs] = [1, 2, 3] # Valid tokens for active reqs
next_token_ids = jnp.array(next_token_ids_cpu)

attn_metadata = AttentionMetadata(
seq_lens=jnp.array(sl_cpu),
input_positions=jnp.arange(total_tokens),
query_start_loc=jnp.array(qsl_cpu),
block_tables=jnp.array([]),
block_tables=jnp.array([]), # This will be replaced by the mock
request_distribution=None,
)
attn_metadata.query_start_loc_cpu = qsl_cpu
Expand All @@ -129,20 +143,26 @@ def test_prepare_inputs():
expected_total_tokens = runner_utils.get_padded_token_len(
proposer.runner.num_tokens_paddings, expected_total_tokens)

expected_last_token_indices = jnp.array(expected_new_qsl[1:] - 1)

# Execute
updated_metadata, target_token_ids, target_hidden_states = (
target_hidden_states, input_ids, last_token_indices, updated_metadata = (
proposer.prepare_inputs(attn_metadata, input_ids, aux_hidden_states,
num_rejected_tokens))
next_token_ids, num_rejected_tokens))

# Assertions
assert jnp.array_equal(updated_metadata.query_start_loc,
jnp.array(expected_new_qsl))
assert jnp.array_equal(updated_metadata.seq_lens,
jnp.array(expected_new_seq_lens))
assert target_token_ids.shape == (expected_total_tokens, )

assert jnp.array_equal(last_token_indices, expected_last_token_indices)

assert input_ids.shape == (expected_total_tokens, )
# NOTE: We don't check the content of input_ids for padded requests
# as it's complicated to construct the expected tensor. The shape check
# and the qsl/seq_len checks are sufficient to validate the logic.
# The concatenated hidden state shape should be (..., hidden_size * 3)
assert target_hidden_states.shape == (expected_total_tokens,
hidden_size * 3)

Expand All @@ -161,29 +181,59 @@ def test_propose(method, num_speculative_tokens):
total_tokens = seq_len_1 + seq_len_2
base_token_ids = [42, 60]

def mock_model_fn(state, kv_caches, input_ids, hidden_states,
def mock_model_fn(state, kv_caches, input_ids, target_hidden_states,
attn_metadata):
"""
Mock model_fn.
Returns: (kv_caches, hidden_states_for_logits, residual_tuple)

- On first call (num_tokens == total_tokens):
Populate hidden_states_for_logits[last_token_indices] with base_token_ids.
Populate residual_tuple[0][last_token_indices] with base_token_ids.
- On loop calls (num_tokens == batch_size):
Use input_ids (previous draft token) to generate new token (input_ids + 1).
Populate hidden_states_for_logits with (input_ids + 1).
Populate residual_tuple[0] with (input_ids + 1).
"""
num_tokens = input_ids.shape[0]
new_hidden_states = jnp.zeros((num_tokens, hidden_size))

# This will be used for logits (output 2)
hidden_states_for_logits = jnp.zeros((num_tokens, hidden_size))
# This will be fed into the next step (output 3, item 0)
residual_hidden_states = jnp.zeros((num_tokens, hidden_size))

if num_tokens == total_tokens:
# First call in propose. Set hidden states for last tokens
# to produce the first draft tokens.
# First call in propose.
# `propose` will select from last_token_indices.
last_token_indices = attn_metadata.query_start_loc[1:] - 1
# The proposer uses next_token_ids to set the last token of each
# sequence in input_ids. We mock this behavior by directly using
# base_token_ids (the same values the test passes as next_token_ids)
# to generate the first draft tokens.
# The mock `compute_logits` will use hidden_states[:, 0]
# to generate tokens.
new_hidden_states = new_hidden_states.at[

# Set logits output
hidden_states_for_logits = hidden_states_for_logits.at[
last_token_indices, 0].set(jnp.array(base_token_ids))
else: # Subsequent calls in the loop
new_hidden_states = new_hidden_states.at[:, 0].set(input_ids + 1)

return kv_caches, new_hidden_states, new_hidden_states
# Set residual for next step
residual_hidden_states = residual_hidden_states.at[
last_token_indices, 0].set(jnp.array(base_token_ids))
else:
# Subsequent calls in the loop
# input_ids is the previous draft token (shape `batch_size`)
# Mock logic: next token = previous token + 1
next_token_ids_encoded = input_ids + 1

# Set logits output
hidden_states_for_logits = hidden_states_for_logits.at[:, 0].set(
next_token_ids_encoded)

# Set residual for next step
residual_hidden_states = residual_hidden_states.at[:, 0].set(
next_token_ids_encoded)

# Return (kv_caches, hidden_states, residual_tuple)
return kv_caches, hidden_states_for_logits, (residual_hidden_states, )

def mock_compute_logits_fn(state, hidden_states, lora_metadata):
# Create deterministic logits from hidden_states.
# Takes the value from hidden_states[:, 0]
token_ids = hidden_states[:, 0].astype(jnp.int32)
return jax.nn.one_hot(token_ids, vocab_size)

Expand All @@ -198,39 +248,55 @@ def mock_combine_hidden_states_fn(state, hidden_states):

# Inputs
kv_caches = [None] * 1 # Mock kv_caches
next_token_ids = jnp.array(base_token_ids, dtype=jnp.int32)

# Create the 2D table first, as this is what the (unused) mock expects
block_tables_2d = jnp.zeros((batch_size, 10), dtype=jnp.int32)

attn_metadata = AttentionMetadata(
seq_lens=jnp.array([seq_len_1, seq_len_2]),
input_positions=jnp.concatenate(
[jnp.arange(seq_len_1),
jnp.arange(seq_len_2)]),
query_start_loc=jnp.array([0, seq_len_1, total_tokens]),
block_tables=jnp.zeros((2, 10), dtype=jnp.int32),
# Pass the FLATTENED table to simulate output of prepare_inputs
block_tables=block_tables_2d.reshape(-1),
request_distribution=None,
)

# These are the inputs to `propose`
# input_ids (from prepare_inputs)
target_token_ids = jnp.zeros(total_tokens, dtype=jnp.int32)
# target_hidden_states (from prepare_inputs)
target_hidden_states = jnp.zeros((total_tokens, hidden_size))
# last_token_indices (from prepare_inputs)
last_token_indices = attn_metadata.query_start_loc[1:] - 1

# Mock runner for block tables
# This mock isn't actually used by propose(), but we'll set it
# to the 2D table for correctness, as that's what
# _prepare_draft_inputs (called by prepare_inputs) would expect.
proposer.runner.input_batch.num_reqs = batch_size
proposer.runner.input_batch.block_table = [mock.MagicMock()]
(proposer.runner.input_batch.block_table[0].get_device_tensor.return_value
) = attn_metadata.block_tables
proposer.runner._select_from_array_fn = (
lambda array, indices: array[indices])
) = block_tables_2d

# Execute
_, draft_token_ids = proposer.propose(
kv_caches,
next_token_ids,
attn_metadata,
target_token_ids,
attn_metadata,
last_token_indices,
target_hidden_states,
)

# Assertions
assert draft_token_ids.shape == (batch_size, num_speculative_tokens)

# Check the generated tokens
# Step 0: base_token_ids [42, 60]
# Step 1: [43, 61]
# Step 2: [44, 62]
# ...
expected_tokens = np.zeros((batch_size, num_speculative_tokens),
dtype=np.int64)
for i in range(batch_size):
Expand Down
Loading