[draft patch] implement zero-copy KV cache trim via set_shape() for indirect KV cache states #35721
Open
xzhan34 wants to merge 2 commits into openvinotoolkit:master from
Conversation
Force-pushed from accb5ea to 282f3bf
wangleis approved these changes on May 8, 2026
… KV cache states

Combined port of three commits from openvino.mx/thirdparty/openvino:

- 56813a3 [GPU] Add get_shape/set_shape overrides to VariableStateIndirectKVCache
- 9985818 fix(GPU): implement set_shape/get_state for compressed KV cache VariableState
- 4f81622 fix(GPU): use padding-aware zero-copy KV cache trim in set_shape()

Problem
-------
VariableStateIndirectKVCache and its compressed variant inherit from MultiTensorState -> VariableStateBase -> IVariableState. The base class IVariableState::get_shape() and set_shape() throw ov::NotImplemented, so any caller that tries to query or resize the KV cache shape on GPU variable states (e.g. speculative decoding KV trim) hits an unimplemented exception, forcing an expensive CPU-side copy fallback that costs ~178 ms per verify step.

Additionally, even when set_shape() was partially added (commit 1), it performed a naive buffer resize without adjusting GPU buffer padding or the beam table state, causing:

- SDPA kernel reading stale data from misaligned buffer offsets
- Progressive text degeneration after ~100 trim cycles (repetitive "and, and, ..." or ". . . ." output patterns)

For compressed KV cache (kv_cache_precision=u8), set_shape() also needs to resize the compression scale/zero-point tensors, and get_state() threw unconditionally.

Changes
-------
multi_tensor_variable_state.hpp:
- Add get_shape() and set_shape() overrides to VariableStateIndirectKVCache
- Add set_shape() override to VariableStateIndirectKVCacheCompressed
- Add get_concat_axis() accessor for derived-class use
- Change m_beam_axis/m_concat_axis from private to protected

multi_tensor_variable_state.cpp:
- VariableStateIndirectKVCache::get_shape(): delegate to m_hidden_states[0]
- VariableStateIndirectKVCache::set_shape(): padding-aware zero-copy trim that adjusts data_padding._upper_size on the concat axis for both KV data (m_hidden_states[0]) and the beam table (m_hidden_states[1]); metadata only, zero GPU data movement
- VariableStateIndirectKVCacheCompressed::set_shape(): delegates KV + beam table to the base class, then applies the same padding-based trim to the compression scale (m_hidden_states[2]) and optional zero-point (m_hidden_states[3]) tensors on their sequence axis

How it works
------------
GPU SDPA uses padded buffer layouts where allocated_size = used_size + upper_padding. To trim the KV cache from seq_len=N to seq_len=M (M < N):

1. Set the new partial shape to M on the concat axis
2. Increase upper_padding by (N - M) to keep the total allocated size constant
3. No GPU memory allocation or data copy is needed; only layout metadata changes

This preserves per-head strides in the GPU buffer, so SDPA kernels continue to access data at correct offsets.

Performance impact (Qwen3-Omni 4B DFlash speculative decoding, GPU)
-------------------------------------------------------------------
- KV trim: 178 ms/call (CPU copy) -> 0.22 ms/call (zero-copy) [823x faster]
- TPOT: 187 ms -> 108 ms [42% lower]
- Throughput: 5.4 -> 9.4 tokens/s [74% higher]
- Text quality: no degeneration after 100+ trim cycles

Usage
-----
This fix is transparent to users. It activates automatically when:

1. Using the GPU plugin with indirect KV cache (the default for LLM inference)
2. Calling VariableState::set_shape() to trim the KV cache (e.g. speculative decoding verify-and-trim, or any KV cache management that resizes states)
3. Using either FP16 or compressed (u8) KV cache precision mode

No API changes, no new configuration flags.
Existing code that calls set_shape() on GPU variable states will automatically use the fast padding-based path instead of falling back to CPU copy.
Tests cover:
- get_shape() returns correct initial shape
- set_shape() zero-copy trim adjusts layout without data movement
- Padding adjustment preserves padded dims (upper_size increases)
- Beam table state is trimmed in sync with the KV state
- Multiple consecutive trims accumulate padding correctly
- get_concat_axis() returns the configured axis
- Compressed variant trims KV, scale, and optional zero-point states
- set_state()/get_state() on the compressed variant throws (unsupported)
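To make the mechanism concrete, here is a small self-contained C++ model of the invariant the trim relies on; the struct and field names are illustrative stand-ins, not the plugin's actual layout types.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative model of a GPU buffer layout with upper padding on the
// concat (sequence) axis; a stand-in for the plugin's real layout type.
struct SeqLayout {
    int64_t used_seq_len;   // visible sequence length on the concat axis
    int64_t upper_padding;  // allocated-but-unused elements past the end
    int64_t allocated() const { return used_seq_len + upper_padding; }
};

// Zero-copy trim: shrink the visible length and grow the padding by the
// same amount, so the allocation (and per-head strides) never change.
void trim(SeqLayout& l, int64_t new_len) {
    assert(new_len <= l.used_seq_len);
    l.upper_padding += l.used_seq_len - new_len;  // metadata-only update
    l.used_seq_len = new_len;
}

int main() {
    SeqLayout kv{1024, 0};
    const int64_t before = kv.allocated();
    trim(kv, 896);                     // roll back rejected draft tokens
    assert(kv.allocated() == before);  // no reallocation, no data copy
    assert(kv.used_seq_len == 896 && kv.upper_padding == 128);
}
```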
Force-pushed from 282f3bf to 9d2a677
rkazants (Collaborator) requested changes on May 8, 2026:
Please provide a proper PR description with a small code snippet to reproduce the problem, and share the JIRA ticket.
Can you please provide the prompt that you used for this code generation?
Thanks!
Combined port of three commits from openvino.mx/thirdparty/openvino:

- 56813a3 [GPU] Add get_shape/set_shape overrides to VariableStateIndirectKVCache
- 9985818 fix(GPU): implement set_shape/get_state for compressed KV cache VariableState
- 4f81622 fix(GPU): use padding-aware zero-copy KV cache trim in set_shape()
Problem
VariableStateIndirectKVCache and its compressed variant inherit from MultiTensorState -> VariableStateBase -> IVariableState. The base class IVariableState::get_shape() and set_shape() throw ov::NotImplemented. This means any caller that tries to query or resize the KV cache shape on GPU variable states (e.g. speculative decoding KV trim) hits an unimplemented exception, forcing an expensive CPU-side copy fallback that costs ~178 ms per verify step.
Additionally, even when set_shape() was partially added (commit 1), it performed a naive buffer resize without adjusting GPU buffer padding or the beam table state, causing:

- SDPA kernel reading stale data from misaligned buffer offsets
- Progressive text degeneration after ~100 trim cycles (repetitive "and, and, ..." or ". . . ." output patterns)
For compressed KV cache (kv_cache_precision=u8), set_shape() also needs to resize compression scale/zero-point tensors, and get_state() threw unconditionally.
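A minimal sketch of how the failure surfaces, assuming the state's set_shape() is reachable from application code as described above; the model path and target shape below are placeholders:

```cpp
#include <openvino/openvino.hpp>
#include <iostream>

int main() {
    ov::Core core;
    // Placeholder path: any stateful LLM with an indirect KV cache on GPU.
    auto compiled = core.compile_model("model.xml", "GPU");
    auto request = compiled.create_infer_request();

    for (auto& state : request.query_state()) {
        try {
            // Before this patch the GPU KV cache states inherit
            // IVariableState::set_shape(), which throws ov::NotImplemented.
            state.set_shape({1, 32, 896, 128});  // illustrative target shape
        } catch (const ov::Exception& e) {
            std::cerr << state.get_name() << ": " << e.what() << "\n";
        }
    }
}
```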
Changes
multi_tensor_variable_state.hpp:

- Add get_shape() and set_shape() overrides to VariableStateIndirectKVCache
- Add set_shape() override to VariableStateIndirectKVCacheCompressed
- Add get_concat_axis() accessor for derived-class use
- Change m_beam_axis/m_concat_axis from private to protected
multi_tensor_variable_state.cpp:

- VariableStateIndirectKVCache::get_shape(): delegate to m_hidden_states[0]
- VariableStateIndirectKVCache::set_shape(): padding-aware zero-copy trim that adjusts data_padding._upper_size on the concat axis for both KV data (m_hidden_states[0]) and the beam table (m_hidden_states[1]); metadata only, zero GPU data movement
- VariableStateIndirectKVCacheCompressed::set_shape(): delegates KV + beam table to the base class, then applies the same padding-based trim to the compression scale (m_hidden_states[2]) and optional zero-point (m_hidden_states[3]) tensors on their sequence axis
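A self-contained sketch of the override structure those changes describe; the stand-in types below only model the m_hidden_states ordering ([0] KV data, [1] beam table, [2] scales, [3] optional zero-points) and are not the real class declarations:

```cpp
#include <cstdint>
#include <vector>

// Stand-in for one state tensor's layout on its sequence/concat axis.
struct TensorMeta {
    int64_t seq_len;
    int64_t upper_pad;
};

// Metadata-only trim shared by all state tensors.
static void trim_axis(TensorMeta& t, int64_t new_len) {
    t.upper_pad += t.seq_len - new_len;  // allocation stays constant
    t.seq_len = new_len;
}

// Models VariableStateIndirectKVCache.
struct IndirectKVCacheState {
    std::vector<TensorMeta> m_hidden_states;  // [0] KV data, [1] beam table
    virtual void set_seq_len(int64_t n) {
        trim_axis(m_hidden_states[0], n);     // KV data on the concat axis
        trim_axis(m_hidden_states[1], n);     // beam table, kept in sync
    }
    virtual ~IndirectKVCacheState() = default;
};

// Models VariableStateIndirectKVCacheCompressed (kv_cache_precision=u8).
struct CompressedKVCacheState : IndirectKVCacheState {
    void set_seq_len(int64_t n) override {
        IndirectKVCacheState::set_seq_len(n);  // KV + beam table first
        trim_axis(m_hidden_states[2], n);      // compression scales
        if (m_hidden_states.size() > 3)
            trim_axis(m_hidden_states[3], n);  // optional zero-points
    }
};
```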
How it works
GPU SDPA uses padded buffer layouts where allocated_size = used_size + upper_padding. To trim the KV cache from seq_len=N to seq_len=M (M < N):

1. Set the new partial shape to M on the concat axis
2. Increase upper_padding by (N - M) to keep the total allocated size constant
3. No GPU memory allocation or data copy is needed; only layout metadata changes

This preserves per-head strides in the GPU buffer, so SDPA kernels continue to access data at correct offsets.
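Concretely, with illustrative numbers: trimming from seq_len=1024 to seq_len=896 on an initially unpadded buffer leaves allocated_size at 1024 (used_size 896 + upper_padding 128); the 128 trimmed positions stay allocated but are hidden by the layout metadata, so kernels never read them.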
Performance impact (Qwen3-Omni 4B DFlash speculative decoding, GPU)
- KV trim: 178 ms/call (CPU copy) -> 0.22 ms/call (zero-copy) [823x faster]
- TPOT: 187 ms -> 108 ms [42% lower]
- Throughput: 5.4 -> 9.4 tokens/s [74% higher]
- Text quality: no degeneration after 100+ trim cycles
Usage
This fix is transparent to users. It activates automatically when:

1. Using the GPU plugin with indirect KV cache (the default for LLM inference)
2. Calling VariableState::set_shape() to trim the KV cache (e.g. speculative-decoding verify-and-trim, or any KV cache management that resizes states)
3. Using either FP16 or compressed (u8) KV cache precision mode
No API changes, no new configuration flags. Existing code that calls set_shape() on GPU variable states will automatically use the fast padding-based path instead of falling back to CPU copy.
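For concreteness, a hedged sketch of the calling pattern this serves (a speculative-decoding verify-and-trim step); it assumes get_shape()/set_shape() are exposed on the state object as this PR describes, and treating axis 2 as the sequence axis is an illustrative assumption about the KV tensor layout:

```cpp
#include <openvino/openvino.hpp>

// Roll back every KV cache state to `kept_len` tokens after a verify step.
// get_shape()/set_shape() on the state are the methods this PR implements;
// axis 2 as the sequence axis is an assumption for illustration only.
void trim_kv_cache(ov::InferRequest& request, int64_t kept_len) {
    for (auto& state : request.query_state()) {
        ov::PartialShape shape = state.get_shape();
        if (shape.rank().get_length() > 2 && shape[2].get_length() > kept_len) {
            shape[2] = kept_len;     // drop rejected draft tokens
            state.set_shape(shape);  // zero-copy padding trim on GPU
        }
    }
}
```

Before this patch, the same loop would either hit the unimplemented exception or fall back to the ~178 ms CPU copy; afterwards each set_shape() call is the 0.22 ms metadata update measured above.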
Details:
Tickets:
AI Assistance: