533186d to 339b296
Pull request overview
Adds support for multi-token prediction (MTP) speculative decoding with num_speculative_tokens > 1, including scheduler/block allocation updates and per-step attention metadata setup during draft proposal.
Changes:
- Allow `num_speculative_tokens` values in 1..3 via config validation update.
- Allocate extra KV-cache blocks up-front to cover additional MTP draft tokens.
- Update Eagle speculative proposer to update attention metadata during iterative draft generation and pass `ScheduledBatch` through runner APIs.
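The config-side change in the first bullet can be sketched roughly as follows (the function name is hypothetical; the actual validation lives in atom/config.py):

```python
def validate_num_speculative_tokens(n: int) -> int:
    # The PR widens the accepted range to 1..3 (inclusive).
    if not 1 <= n <= 3:
        raise ValueError(
            f"num_speculative_tokens must be in [1, 3], got {n}"
        )
    return n
```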
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| atom/spec_decode/eagle.py | Adds batch-aware MTP loop logic and rebuilds attn metadata for subsequent draft steps |
| atom/model_engine/scheduler.py | Allocates extra blocks to support MTP draft tokens |
| atom/model_engine/model_runner.py | Adjusts draft token selection indexing and passes batch into proposer |
| atom/model_engine/block_manager.py | Updates allocation/append logic for extra tokens; removes prefix-caching behavior |
| atom/config.py | Expands allowed speculative token count and updates error message |
```python
slot_mapping = [
    block_table[pos // self.runner.block_size] * self.runner.block_size
    + (pos % self.runner.block_size)
    for block_table, seq_len in zip(batch.block_tables, context_lens)
```

`context_lens` is a torch tensor, so the `seq_len` elements yielded by `zip(batch.block_tables, context_lens)` are tensor scalars. Using those tensor scalars inside `range(...)` will raise a `TypeError` (`range` expects Python ints). Convert `context_lens` to a Python list/ints first (e.g., iterate over `context_lens.tolist()` or use `seq_len.item()`), and ensure `pos` computations use Python ints.

Suggested change:

```python
# Convert context_lens tensor to a list of Python ints so that seq_len and pos
# used below are plain integers (required by range and indexing).
context_lens_list = context_lens.tolist()
slot_mapping = [
    block_table[pos // self.runner.block_size] * self.runner.block_size
    + (pos % self.runner.block_size)
    for block_table, seq_len in zip(batch.block_tables, context_lens_list)
```
atom/spec_decode/eagle.py (outdated)
```python
num_blocks_per_seq = [
    (ctx + block_size - 1) // block_size
    for ctx in context_lens.tolist()
]
kv_indptr = np.cumsum(num_blocks_per_seq)
sum_blocks = kv_indptr[-1]
sum_blocks_before_converted = sum(
    [(i + block_ratio - 1) // block_ratio for i in num_blocks_per_seq]
)

def prepare_kv_indices():
    var["kv_indices"].np[:sum_blocks_before_converted] = np.fromiter(
        itertools.chain.from_iterable(batch.block_tables),
        dtype=np.int32,
        count=sum_blocks_before_converted,
```

`kv_indptr`/`sum_blocks` are computed from `num_blocks_per_seq` (unconverted), but `kv_indices` is only populated up to `sum_blocks_before_converted` (converted by `block_ratio`). Later, `kv_indices` is copied to GPU with length `sum_blocks`, which can leave a tail uninitialized or make `kv_indptr` inconsistent with the number of indices. Align the `kv_indices` fill/count with the same block accounting used by `kv_indptr` (either both converted or both unconverted).

Suggested change:

```python
# Compute the number of KV blocks per sequence at block_size,
# then convert to the effective block count according to block_ratio.
num_blocks_per_seq = [
    (ctx + block_size - 1) // block_size
    for ctx in context_lens.tolist()
]
num_blocks_per_seq_converted = [
    (n + block_ratio - 1) // block_ratio for n in num_blocks_per_seq
]
kv_indptr = np.cumsum(num_blocks_per_seq_converted)
sum_blocks = kv_indptr[-1]

def prepare_kv_indices():
    var["kv_indices"].np[:sum_blocks] = np.fromiter(
        itertools.chain.from_iterable(batch.block_tables),
        dtype=np.int32,
        count=sum_blocks,
```
atom/spec_decode/eagle.py (outdated)
```python
var["kv_indptr"].np[1 : scheduled_bs + 1] = kv_indptr
var["kv_indptr"].np[scheduled_bs + 1 : bs + 1] = sum_blocks
var["kv_last_page_lens"].np[:scheduled_bs] = (
    batch.last_block_num_tokens if self.block_size != 1 else 1
```

This references `self.block_size`, but in this hunk the block size is accessed via `self.runner.block_size` (and a local `block_size` is also defined above). If `self.block_size` is not defined on this class, this will raise `AttributeError`. Use the same source (the local `block_size` or `self.runner.block_size`) consistently.

Suggested change:

```python
    batch.last_block_num_tokens if block_size != 1 else 1
```
```python
num_additional_tokens = self.mtp_k - 1
self.block_manager.allocate(seq, num_additional_tokens)
```

Scheduler now requests allocation for additional MTP tokens, but `BlockManager.can_allocate()` still checks only `seq.num_blocks`. If scheduling relies on `can_allocate()` to gate allocations, the check can underestimate the required blocks and `allocate()` may index into an empty `free_block_ids` list. Update `can_allocate()` (or the scheduling check) to account for `num_additional_tokens` the same way `allocate()` computes `num_blocks`.
```python
final_len = seq_len + num_new_tokens - 1
needed_blocks = (final_len + self.block_size - 1) // self.block_size

while len(block_table) < needed_blocks:
    block_id = self.free_block_ids[0]
    self._allocate_block(block_id)
    block_table.append(block_id)
```

`may_append()` appears intended to ensure capacity for `num_new_tokens`. Using `final_len = seq_len + num_new_tokens - 1` under-allocates at boundaries (e.g., when `seq_len % block_size == 0` and `num_new_tokens == 1`, it will not allocate a new block even though appending one token requires it). Compute needed blocks based on `seq_len + num_new_tokens` (or an equivalent inclusive formula) so the next append doesn't overflow the current block table.
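The boundary case described above can be checked with plain integer arithmetic (the values are illustrative; `inclusive=False` reproduces the buggy formula, `inclusive=True` the proposed fix):

```python
def needed_blocks(seq_len: int, num_new_tokens: int, block_size: int,
                  inclusive: bool) -> int:
    # inclusive=False: buggy `final_len = seq_len + num_new_tokens - 1`
    # inclusive=True:  corrected `final_len = seq_len + num_new_tokens`
    final_len = seq_len + num_new_tokens - (0 if inclusive else 1)
    return (final_len + block_size - 1) // block_size

# Appending 1 token when seq_len sits exactly on a block boundary:
print(needed_blocks(4, 1, 4, inclusive=False))  # buggy: 1 (no new block)
print(needed_blocks(4, 1, 4, inclusive=True))   # fixed: 2
```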
```python
#     else -1
# )
# block_id = (
#     self.hash_to_block_id.get(h, -1) if self.enable_prefix_caching else -1
# )
# if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
#     cache_miss = True
# if cache_miss:
#     block_id = self.free_block_ids[0]
#     block = self._allocate_block(block_id)
# else:
#     seq.num_cached_tokens += self.block_size
#     if block_id in self.used_block_ids:
#         block = self.blocks[block_id]
#         block.ref_count += 1
#     else:
#         block = self._allocate_block(block_id)
#     if h != -1:
#         block.update(h, token_ids)
#         self.hash_to_block_id[h] = block_id
# seq.block_table.append(block_id)
```

A large, commented-out prior implementation is left in the file. This makes the current allocation behavior harder to audit and increases maintenance burden. Consider removing the commented block (git history preserves it) or replacing it with a short explanatory comment describing why prefix caching was removed/disabled for now.

Suggested change:

```python
# NOTE: A previous prefix-caching-based allocation implementation was removed
# to simplify auditing of the current behavior. Refer to git history if the
# legacy prefix caching logic needs to be revisited or restored.
```
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
```python
# Hash remains -1 until block is full (consistent with allocate logic)
assert last_block.hash == -1, last_block.block_id

final_len = seq_len + num_new_tokens - 1
```

`final_len = seq_len + num_new_tokens - 1` under-allocates when `seq_len % block_size == 0` (e.g., appending 1 token at a block boundary will not allocate a new block). Compute the required length as `seq_len + num_new_tokens` (no `- 1`) so boundary cases correctly trigger new block allocation.

Suggested change:

```python
final_len = seq_len + num_new_tokens
```
```python
def can_allocate(self, seq: Sequence) -> bool:
    return len(self.free_block_ids) >= seq.num_blocks
```

`allocate()` now allocates based on `seq.num_tokens + num_additional_tokens`, but `can_allocate()` still checks only `seq.num_blocks`. This can allow scheduling to proceed even when there aren't enough free blocks for the extra speculative tokens, leading to `self.free_block_ids[0]` failures inside `allocate()`. Update `can_allocate()` to account for `num_additional_tokens` (or accept it as a parameter) and ensure scheduling uses the same calculation.

Suggested change:

```python
def can_allocate(self, seq: Sequence, num_additional_tokens: int = 0) -> bool:
    """
    Check whether there are enough free blocks to allocate for the sequence,
    optionally accounting for additional (e.g., speculative) tokens.
    """
    total_tokens = seq.num_tokens + max(0, num_additional_tokens)
    num_blocks = (total_tokens + seq.block_size - 1) // seq.block_size
    return len(self.free_block_ids) >= num_blocks
```
```python
slot_mapping = [
    block_table[pos // self.runner.block_size] * self.runner.block_size
    + (pos % self.runner.block_size)
    for block_table, seq_len in zip(batch.block_tables, context_lens)
```

Iterating `context_lens` (a `torch.Tensor`) yields 0-d tensors, so `range(seq_len - max_seqlen_q, seq_len)` will raise a `TypeError` unless `seq_len` is converted to a Python int (e.g., `seq_len.item()` or using a precomputed `context_lens_list = context_lens.tolist()`). Convert `seq_len` to an int before using it in `range()`.

Suggested change:

```python
context_lens_list = context_lens.tolist()
slot_mapping = [
    block_table[pos // self.runner.block_size] * self.runner.block_size
    + (pos % self.runner.block_size)
    for block_table, seq_len in zip(batch.block_tables, context_lens_list)
```
```python
num_blocks_per_seq = [
    (ctx + block_size - 1) // block_size
    for ctx in context_lens.tolist()
]
kv_indptr = np.cumsum(num_blocks_per_seq)
sum_blocks = kv_indptr[-1]
```

```python
for bt in batch.block_tables:
    n = len(bt)
    dst[offset : offset + n] = bt
```

`kv_indptr`/`sum_blocks` are computed from `num_blocks_per_seq` (based on `context_lens`), but `prepare_kv_indices()` copies `len(bt)` entries for each sequence. If `batch.block_tables` include preallocated extra blocks (which this PR introduces for MTP), `len(bt)` can exceed `num_blocks_per_seq`, causing `kv_indices` content to disagree with `kv_indptr` and the `sum_blocks` slice used later. Copy only the first `num_blocks_per_seq[i]` entries per sequence (and keep `offset` consistent with that), so `kv_indices` and `kv_indptr` describe the same packed structure.

Suggested change:

```python
for i, bt in enumerate(batch.block_tables):
    n = num_blocks_per_seq[i]
    dst[offset : offset + n] = bt[:n]
```
```python
if batch.is_dummy_run:
    return draft_token_ids
```

Returning early on `is_dummy_run` exits after filling only the current `i` column, leaving later draft positions uninitialized for `mtp_k > 1`. If callers expect a fully shaped/filled `[batch, mtp_k]` result even in dummy runs, this will produce incorrect outputs. Consider moving the dummy-run guard outside the loop (or ensuring remaining columns are deterministically filled) while still skipping the expensive metadata rebuilds.
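A minimal pure-Python sketch of the suggested shape guarantee (names, the placeholder value, and the stand-in draft step are hypothetical, not the PR's actual tensors):

```python
def propose(batch_size: int, mtp_k: int, is_dummy_run: bool):
    # draft[b][i] is the i-th draft token for sequence b; 100 + i stands in
    # for the token produced by the real draft model at step i.
    draft = [[None] * mtp_k for _ in range(batch_size)]
    for i in range(mtp_k):
        for b in range(batch_size):
            draft[b][i] = 100 + i
        if is_dummy_run:
            # Instead of returning early (leaving columns i+1..mtp_k-1
            # uninitialized), fill the rest deterministically and stop,
            # so the result always has the full [batch, mtp_k] shape.
            for b in range(batch_size):
                for j in range(i + 1, mtp_k):
                    draft[b][j] = 0
            break
    return draft

print(propose(2, 3, is_dummy_run=True))  # [[100, 0, 0], [100, 0, 0]]
```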
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.
```python
# Hash remains -1 until block is full (consistent with allocate logic)
assert last_block.hash == -1, last_block.block_id

final_len = seq_len + num_new_tokens - 1
```

`final_len = seq_len + num_new_tokens - 1` under-allocates blocks when appending across a block boundary (e.g., `seq_len` exactly divisible by `block_size` and `num_new_tokens=1` yields no new block). Compute needed blocks from `seq_len + num_new_tokens` (without `- 1`) so the KV cache grows correctly.

Suggested change:

```python
final_len = seq_len + num_new_tokens
```
```python
context_lens = positions + 1
slot_mapping = [
    block_table[pos // self.runner.block_size] * self.runner.block_size
    + (pos % self.runner.block_size)
    for block_table, seq_len in zip(batch.block_tables, context_lens)
    for pos in range(seq_len - max_seqlen_q, seq_len)
]
```

`context_lens` is a torch tensor, so iterating it yields 0-d tensors; using those in `range(...)` will raise a `TypeError` because `range()` expects Python ints. Convert sequence lengths to Python ints first (e.g., iterate over `context_lens.tolist()` or use `int(seq_len.item())`) before calling `range()`.
```python
slot_mapping = [
    block_table[pos // self.runner.block_size] * self.runner.block_size
    + (pos % self.runner.block_size)
    for block_table, seq_len in zip(batch.block_tables, context_lens)
    for pos in range(seq_len - max_seqlen_q, seq_len)
```

This builds `slot_mapping` via nested Python loops on every MTP step, which can become a hot-path cost for larger batches. If possible, consider precomputing per-sequence last-slot indices using vectorized ops (e.g., compute last positions with tensor ops and gather from block tables), then fill the host buffer without per-token Python iteration.

Suggested change:

```python
# With max_seqlen_q == 1, we only need the last position for each sequence.
last_positions = context_lens - max_seqlen_q
slot_mapping = [
    block_table[pos // self.runner.block_size] * self.runner.block_size
    + (pos % self.runner.block_size)
    for block_table, pos in zip(batch.block_tables, last_positions)
```
```python
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = torch.index_select(input_ids, 0, bonus_logits_indices)
draft_token_ids = torch.index_select(input_ids, 0, target_logits_indices + 1)
```

Adding `+ 1` to `target_logits_indices` can produce out-of-bounds indices if any element points to the last row in `input_ids`. If the intent is to select the next token relative to `target_logits_indices`, ensure the indices are constructed so they never reference the final element (or filter/clamp appropriately), otherwise this will crash at runtime.

Suggested change:

```python
# Use the token immediately after each target logits index, clamped to the
# last valid position to avoid out-of-bounds indices.
next_token_indices = target_logits_indices + 1
max_valid_index = input_ids.size(0) - 1
next_token_indices = torch.clamp(next_token_indices, max=max_valid_index)
draft_token_ids = torch.index_select(input_ids, 0, next_token_indices)
```
```python
self.total_draft_tokens += self.mtp_k
self.total_accepted_tokens += num_accepted_tokens - self.mtp_k
self.total_accepted_tokens += max(0, num_accepted_tokens - 1)
```

With the updated accounting, `total_draft_tokens` is incremented by `mtp_k` per step (and the comment above notes the last token is not a draft). This makes the `total_draft_tokens` / `total_accepted_tokens` naming misleading (they no longer clearly represent "draft tokens" and "accepted tokens" in the literal sense). Consider renaming these counters to reflect what they measure now (e.g., `total_spec_tokens_per_step` / `total_draft_tokens_accepted`) to avoid confusion in downstream logging/metrics.
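To see why the two metrics diverge, here is a worked example with illustrative numbers (the bonus-token accounting is an assumption about how average tokens per step is commonly defined in speculative decoding, not taken from this PR):

```python
mtp_k = 3      # draft tokens proposed per decode step
steps = 100    # decode steps taken
total_draft_tokens = mtp_k * steps   # 300 proposed drafts
total_accepted_tokens = 180          # drafts that survived verification

# Fraction of proposed drafts that were accepted:
acceptance_rate = total_accepted_tokens / total_draft_tokens   # 0.6

# Tokens emitted per step: accepted drafts per step plus the one
# verified "bonus" token each step always produces.
avg_tokens_per_step = total_accepted_tokens / steps + 1        # 2.8
```

The two values answer different questions, which is why reusing one key name for both is confusing.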
```python
stats = {
    "total_draft_tokens": self.scheduler.total_draft_tokens,
    "total_accepted_tokens": self.scheduler.total_accepted_tokens,
    "acceptance_rate": avg_tokens_per_step,
}
```

The stats key `acceptance_rate` now contains `avg_tokens_per_step`, and the log line prints "Avg tokens/step" but then labels it as "(acceptance rate)". This is confusing/misleading output. Use a dedicated key like `avg_tokens_per_step` (and/or log both avg tokens/step and acceptance rate explicitly) and fix the label to match the metric.
```python
logger.info(
    f" Avg tokens/step: {stats['acceptance_rate']:.2f} "
    f"(acceptance rate)"
)
```

The same mislabeling applies to this log line: the message prints "Avg tokens/step" but labels the value "(acceptance rate)".
ChuanLi1101 left a comment:

Left some comments; I think Copilot captured more than I did. Please consider revising and resubmitting?
@@ -65,35 +65,48 @@ def _deallocate_block(self, block_id: int):

```python
def can_allocate(self, seq: Sequence) -> bool:
    return len(self.free_block_ids) >= seq.num_blocks
```

`can_allocate()` was not updated in sync. `allocate()` now computes the required number of blocks from `num_tokens + num_additional_tokens`, but `can_allocate()` still checks only `seq.num_blocks`. When there are not enough free blocks, the scheduler will still let the request through, which may cause `self.free_block_ids[0]` inside `allocate()` to crash on an empty list.
```python
# Hash remains -1 until block is full (consistent with allocate logic)
assert last_block.hash == -1, last_block.block_id

final_len = seq_len + num_new_tokens - 1
```

With `final_len = seq_len + num_new_tokens - 1`, when `seq_len % block_size == 0` and `num_new_tokens == 1` (appending a token at a block boundary), no new block is allocated even though one is actually needed.
```python
context.positions = positions
context.is_prefill = False

context_lens = positions + 1
```

`context_lens = positions + 1` is a `torch.Tensor`. When iterating over it, `seq_len` is a 0-d tensor, so `range(seq_len - max_seqlen_q, seq_len)` will raise a `TypeError` (`range` requires Python ints).
```python
    [(i + block_ratio - 1) // block_ratio for i in num_blocks_per_seq]
)

def prepare_kv_indices():
```

In `prepare_kv_indices()`, `n = len(bt)` copies the entire block table. But `allocate()` now preallocates extra blocks (`num_additional_tokens`), so `len(bt)` may exceed `num_blocks_per_seq[i]`, making the `kv_indices` contents inconsistent with the `kv_indptr` packing.
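The fix this comment describes can be illustrated with plain lists (the block IDs and per-sequence counts below are made up for illustration):

```python
# Per-sequence block tables; the last entry in each is a preallocated
# extra MTP block that is not yet covered by context_lens.
block_tables = [[0, 1, 2, 9], [3, 4, 8]]
num_blocks_per_seq = [3, 2]   # in-use block counts derived from context_lens

# Pack kv_indices so it agrees with the kv_indptr built from
# num_blocks_per_seq: copy only the in-use prefix of each table.
kv_indices = []
for n, bt in zip(num_blocks_per_seq, block_tables):
    kv_indices.extend(bt[:n])

print(kv_indices)  # [0, 1, 2, 3, 4]
```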
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist