Rewrite attention sink from eviction to ring buffer (#18821)
kirklandsign wants to merge 1 commit into main
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18821. Note: links to docs will display an error until the docs builds have completed.
As of commit 973c88b with merge base 5e8a0df: ❌ 1 new failure, 18 cancelled jobs, 5 unrelated failures. The cancelled jobs should be retried; the broken-trunk failures were already present on the merge base (rebase onto the viable/strict branch to avoid them). ❗ There is 1 currently active SEV; if this PR is affected, view it below.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating diff in D100216687.
Pull request overview
Replaces the eviction-based attention sink implementation with a torch.export-compatible ring-buffer KV cache design, updates the attention path to rely on ring-buffer masking, and rewrites the associated tests and configuration.
Changes:
- Update attention sink config parsing/validation to accept "<sink_size>,<window_size>" (removing the eviction batch size).
- Implement the ring-buffer-based attention sink KV cache and cache-position management, and adjust AttentionMHA.forward to use ring-buffer masking after KV updates.
- Rewrite the attention sink tests, add an example YAML config, and update BUCK deps for the new test behavior.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| extension/llm/export/config/llm_config.py | Validates 2-field use_attention_sink format and updates error messaging. |
| examples/models/llama/source_transformation/attention_sink.py | Replaces eviction approach with ring-buffer KV cache + sink-preserving cache index manager; removes forward monkey-patch. |
| examples/models/llama/attention.py | Updates AttentionMHA.forward to treat ring-buffer caches specially (mask computed after KV update; skip bounds check). |
| examples/models/llama/model.py | Parses 2-field attention sink config and relaxes RoPE max-context constraint to >= sink_size + window_size. |
| examples/models/llama/source_transformation/test_attention_sink.py | Rewrites tests to cover ring-buffer sink preservation, wrapping, and masking behaviors. |
| examples/models/llama/config/test_llm_config.py | Updates config validation tests for the new 2-field format. |
| examples/models/llama/config/llama_attention_sink.yaml | Adds example configuration for attention sink with ring-buffer sizing guidance. |
| examples/models/llama/BUCK | Adjusts attention_sink_test deps/preloads to support the rewritten tests. |
```python
pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
delta = pos_q - cache_positions

# Valid if position is filled (>= 0) and causal (delta >= 0)
is_valid = (cache_positions >= 0) & (delta >= 0)
```
_create_causal_mask_for_attention_sink builds pos_q via torch.arange() on the default device. If cache_positions is moved to CUDA (e.g., model.to('cuda')), this will raise a device-mismatch error when computing delta = pos_q - cache_positions. Create pos_q (and any scalar constants used in torch.where) on cache_positions.device to keep masking device-agnostic.
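For reference, the masking rule the helper implements (a slot is attendable only if it has been filled and is causally at or before the query position) can be sketched in plain Python, independent of device placement; the function name and example values below are illustrative, not from the PR:

```python
def causal_mask_row(query_pos, cache_positions):
    """One row of the attention-sink mask, in plain Python.

    cache_positions holds the original token position stored in each
    cache slot, or -1 for an empty slot. A slot is attendable iff it
    is filled (>= 0) and causal (its position <= query position).
    """
    return [p >= 0 and query_pos - p >= 0 for p in cache_positions]

# Cache with two filled sink slots, one filled ring slot, one empty slot:
mask = causal_mask_row(5, [0, 1, 4, -1])
```

The torch version computes the same predicate with broadcasting, which is why all operands (including `pos_q`) must live on `cache_positions.device`.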
```python
# Sink tokens go to fixed slots; window tokens use ring buffer
indices = torch.where(
    orig_indices < self.sink_size,
    orig_indices,
    self.sink_size + (orig_indices - self.sink_size) % self.ring_size,
)
```
The ring-buffer index expression computes (orig_indices - sink_size) % self.ring_size unconditionally (torch.where does not short-circuit). If window_size=0 (so ring_size==0), this will raise a modulo-by-zero error even when all orig_indices are sink tokens. Add an explicit guard for ring_size==0 (either disallow window_size=0 or handle the sink-only case without modulo).
Suggested change:
```diff
-# Sink tokens go to fixed slots; window tokens use ring buffer
-indices = torch.where(
-    orig_indices < self.sink_size,
-    orig_indices,
-    self.sink_size + (orig_indices - self.sink_size) % self.ring_size,
-)
+# torch.where does not short-circuit, so guard the sink-only case to
+# avoid evaluating modulo by zero when ring_size == 0.
+if self.ring_size == 0:
+    torch._check(
+        bool((orig_indices < self.sink_size).all().item()),
+        "Positions beyond sink_size are invalid when ring_size is 0",
+    )
+    indices = orig_indices
+else:
+    # Sink tokens go to fixed slots; window tokens use ring buffer
+    indices = torch.where(
+        orig_indices < self.sink_size,
+        orig_indices,
+        self.sink_size + (orig_indices - self.sink_size) % self.ring_size,
+    )
```
```python
# Update cache_positions exactly like original CachePositionsManager
full_t = torch.full((self.max_context_length,), -1, dtype=torch.long)
arange_tensor = torch.arange(self.max_context_length, dtype=torch.long)
cache_positions = torch.where(
    arange_tensor < start_pos, self.cache_positions, full_t
)
```
calculate_positions_and_update_indices constructs full_t and arange_tensor on CPU by default; if this module is moved to another device, torch.where(...) will fail due to mixed-device inputs. Allocate these tensors on self.cache_positions.device (and similarly ensure orig_indices/indices are on a consistent device) to support model.to(device).
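Device placement aside, the semantics of that `torch.where` can be emulated without torch: slots with index below `start_pos` keep their recorded position, everything at or beyond it is marked empty. A pure-Python illustration (the helper name is hypothetical):

```python
def reset_cache_positions(cache_positions, start_pos):
    """Mirror of the torch.where above: keep entries for slots already
    written (slot index < start_pos), mark the rest empty (-1)."""
    return [p if i < start_pos else -1
            for i, p in enumerate(cache_positions)]
```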
```python
# Use torch._check for export compatibility (data-dependent guard)
torch._check(input_pos[0].item() + seq_len <= self.max_context_length)
```
RopeWithAttentionSink.get_freqs drops Rope.get_freqs’ torch._check_is_size(input_pos_item) guard and uses input_pos[0] instead of input_pos[-1]. This can allow negative positions (or multi-element input_pos) to slip through and fail later during narrow/indexing. Consider mirroring Rope.get_freqs’ size check and using the same element convention as the base class.
Suggested change:
```diff
-# Use torch._check for export compatibility (data-dependent guard)
-torch._check(input_pos[0].item() + seq_len <= self.max_context_length)
+input_pos_item = input_pos[-1].item()
+torch._check_is_size(input_pos_item)
+# Use torch._check for export compatibility (data-dependent guard)
+torch._check(input_pos_item + seq_len <= self.max_context_length)
```
```diff
 attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
-assert len(attention_sink_params) == 3
+assert len(attention_sink_params) == 2, (
+    f"use_attention_sink expects exactly 2 comma-separated values "
+    f"(sink_size,window_size), got {len(attention_sink_params)}"
+)
```
This PR changes use_attention_sink from 3 parameters to 2, but other call sites still assume 3 (e.g., examples/models/llama/eval_llama_lib.py asserts len==3 around line ~350, and examples/models/llama/export_llama_lib.py’s CLI help still documents 3 values around line ~594). Please update those to avoid runtime assertion failures / misleading CLI docs.
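One way to keep those call sites consistent is a single shared parser for the new 2-field format instead of a repeated split/assert; a hypothetical helper (not part of the PR):

```python
def parse_attention_sink(config_str):
    """Parse a 2-field use_attention_sink string into
    (sink_size, window_size), validating the new format."""
    parts = config_str.split(",")
    if len(parts) != 2:
        raise ValueError(
            "use_attention_sink must be structured like "
            "'<sink_size>,<window_size>', got %r" % config_str
        )
    sink_size, window_size = (int(p) for p in parts)
    if sink_size < 0 or window_size < 0:
        raise ValueError("sink_size and window_size must be non-negative")
    return sink_size, window_size
```

A stale 3-field string such as "4,2044,1" then fails fast with a clear message rather than tripping a bare assert deep in export.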
```diff
 def _validate_attention_sink(self):
     if self.use_attention_sink:
         attention_sink_params = self.use_attention_sink.split(",")
-        if len(attention_sink_params) != 3:
+        if len(attention_sink_params) != 2:
             raise ValueError(
-                "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+                "The value of use_attention_sink must be structured like '<sink_size>,<window_size>'"
             )
```
ModelConfig’s docstring above still describes use_attention_sink as '<sink_size>,<window_size>,<batch_eviction_size>' (and gives a 3-value example), but validation now enforces exactly 2 values. Please update the documentation string/comments to match the new 2-field format to prevent confusion.
Summary: Replace the eviction-based attention sink implementation with a torch.export-compatible ring-buffer approach, and rewrite all tests.
Key changes:
- RopeWithAttentionSink: simplified to pass through original positions (no more position shifting or k re-rotation).
- KVCacheWithAttentionSink: uses a ring buffer with index_copy_ instead of dynamic eviction (torch.cat/narrow/shift). Cache layout: [sink slots | ring buffer]. Sets is_ring_buffer=True so AttentionMHA.forward handles masking natively.
- CachePositionsManagerWithSink: new module that maps positions to cache indices, with sink tokens in fixed slots and window tokens in the ring-buffer region.
- AttentionMHA.forward: ring-buffer models skip the start_pos bounds check and compute their own causal mask after the KV cache update.
- Remove eviction_batch_size from all interfaces (no longer needed).
- Remove the attention_sink_forward monkey-patch and rerotate_k dead code.
- Add llama_attention_sink.yaml example config.
- Rewrite 16 eviction-based tests as 18 ring-buffer tests covering sink preservation, ring wrapping, causal masking, and degenerate cases.
Differential Revision: D100216687
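As a sanity check of the [sink slots | ring buffer] layout described in the summary, here is a minimal pure-Python simulation of where a stream of token positions lands in the cache; sink_size=2 and window_size=4 are illustrative values, and the function is a sketch, not the PR's implementation:

```python
def simulate_cache(num_tokens, sink_size, window_size):
    """Simulate final cache slot contents: sink tokens pin to fixed
    slots, later tokens wrap within the ring-buffer region."""
    cache = [-1] * (sink_size + window_size)
    for pos in range(num_tokens):
        if pos < sink_size:
            slot = pos
        else:
            slot = sink_size + (pos - sink_size) % window_size
        cache[slot] = pos
    return cache

# After 8 tokens with sink_size=2, window_size=4: sink slots still hold
# positions 0 and 1, while the ring slots hold the newest window, with
# positions 6 and 7 having overwritten positions 2 and 3.
```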