
feat(qwen3_5): Cross-turn image KV cache#828

Closed

AirRunner wants to merge 4 commits into Blaizzy:main from AirRunner:feat/image-end-index

Conversation

@AirRunner

This PR adds two small additions to qwen3_5 that are required by a cross-turn image KV cache implementation in mlx-engine (tracking issue lmstudio-ai/mlx-engine#287).

Currently every conversation turn re-runs the vision tower and re-prefills the full context from scratch, even when the same image was already processed in the previous turn. The mlx-engine fix saves a KV cache checkpoint right after the image tokens and restores it on subsequent turns, but it needs two pieces of information from mlx-vlm that were not previously exposed:

  1. Where does the image block end? (image_end_index): so the engine knows where to split the saved checkpoint from the text suffix to prefill.
  2. Can the vision tower be run for a subset of images? (get_partial_input_embeddings): so that when a new image is added to a conversation, only that image goes through the vision tower; the KV state for earlier images is reused as-is.
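To make the first boundary concrete, here is a rough sketch of the engine-side split that image_end_index enables. The helper name and shapes are hypothetical, not mlx-engine code:

```python
def split_for_prefill(input_ids, image_end_index):
    """Hypothetical engine-side use of image_end_index: everything up to the
    boundary is covered by the restored KV checkpoint; the remainder is the
    text suffix that still needs prefilling."""
    if image_end_index is None:
        return [], list(input_ids)  # text-only turn: prefill everything
    return list(input_ids[:image_end_index]), list(input_ids[image_end_index:])
```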

Changes

mlx_vlm/models/base.py

Adds an optional image_end_index: int | None field to the InputEmbeddingsFeatures dataclass. The field is excluded from to_dict() (it is engine metadata, not a model kwarg).
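A minimal sketch of that exclusion, assuming a stripped-down version of the dataclass (the real one in base.py carries the actual model kwargs):

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class InputEmbeddingsFeatures:
    """Illustrative minimal version of the dataclass in base.py."""

    inputs_embeds: Any = None
    image_end_index: Optional[int] = None  # engine metadata, not a model kwarg

    def to_dict(self):
        # Exclude the engine-only field so it is never forwarded to the model.
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.name != "image_end_index"
        }
```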

mlx_vlm/models/qwen3_5/qwen3_5.py

get_input_embeddings computes and returns image_end_index: the position of the first non-visual token after the last image/video token block. Uses the existing arange * mask pattern already present in the file.
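A plain-NumPy sketch of that computation. The token ids are placeholders and the real code operates on mx.array with the model's configured image/video token ids:

```python
import numpy as np

IMAGE_TOKEN_ID = 151655  # placeholder ids; the real ones come from the config
VIDEO_TOKEN_ID = 151656


def image_end_index(input_ids):
    """Position of the first non-visual token after the last image/video
    token block, via the arange * mask pattern; None for text-only input."""
    ids = np.asarray(input_ids)
    mask = (ids == IMAGE_TOKEN_ID) | (ids == VIDEO_TOKEN_ID)
    if not mask.any():
        return None  # text-only call
    last_visual = int((np.arange(ids.shape[0]) * mask).max())
    return last_visual + 1
```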

get_partial_input_embeddings: like get_input_embeddings but runs the vision tower only for images[partial_depth:]:

  1. Slices pixel_values[n_cached_patches:] and grid_thw[partial_depth:] to process new images only.
  2. Runs the vision tower on the slice.
  3. Gets text embeddings for the full sequence via embed_tokens.
  4. Finds the start of the first new image block by scanning input_ids.
  5. Overwrites only the new image token positions using masked_scatter.
  6. Sets up multi-modal RoPE position IDs (get_rope_index) for the suffix prefill.

Returns inputs_embeds ready for chunked prefill from the end of the last cached image block. The method is on qwen3_5.Model and inherited by qwen3_5_moe.Model automatically.
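Steps 4–5 above can be sketched in plain NumPy. Names, shapes, and the block-start scan are illustrative; the real implementation operates on mx.array and uses masked_scatter:

```python
import numpy as np

IMAGE_TOKEN_ID = 151655  # placeholder id


def merge_partial_features(text_embeds, new_image_feats, input_ids, partial_depth):
    """Find the start of the first non-cached image block and overwrite only
    the new image token positions with the freshly computed features."""
    ids = np.asarray(input_ids)
    is_image = ids == IMAGE_TOKEN_ID
    # A block start is an image position whose predecessor is not an image token.
    prev = np.concatenate(([False], is_image[:-1]))
    block_starts = np.flatnonzero(is_image & ~prev)
    new_img_start = int(block_starts[partial_depth])
    merged = text_embeds.copy()
    scatter_mask = is_image & (np.arange(ids.shape[0]) >= new_img_start)
    merged[scatter_mask] = new_image_feats  # masked_scatter equivalent
    return merged, new_img_start
```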

Tests

tests/test_image_end_index.py: 15 tests in two unittest.TestCase classes (all pass):

  • TestImageEndIndex (8 tests): image_end_index boundary computation — tokens at start/middle/end, single token, video tokens, mixed, no image, realistic layout.
  • TestNewImgStart (7 tests): partial_depth boundary used by get_partial_input_embeddings — depths 0/1/2, out-of-range, video tokens, realistic multi-image layout.

End-to-end validation

Tested on Qwen3.5-35B-A3B (MoE, 5-bit) via LM Studio with the mlx-engine integration:

[kv-image] partial hit depth=1/2
[kv-image] checkpoint saved depth=2 index=20719
[kv-image] cache hit depth=2

Turn with one new image added: vision tower runs only for the new image (~15s instead of ~40s for full context). Next turn: full cache hit, vision tower skipped entirely (~1s).

Related

Add image_end_index: Optional[int] = None to InputEmbeddingsFeatures in base.py (excluded from to_dict()).

Compute and return it from qwen3_5.get_input_embeddings: last visual token position + 1, derived from input_ids before merging. Covers both image and video tokens; returns None for text-only calls.
…cache

Implement the model-specific side of partial image KV caching: run the vision tower only for images[partial_depth:], merge their features into the full-sequence text embeddings using masked_scatter, and set up multi-modal RoPE position IDs for the suffix prefill.

Called by VisionModelWrapper.prefill_with_partial_cache() in mlx-engine when a new image is appended to a conversation whose earlier images are already in the KV cache.
…dings

Extends test_image_end_index.py with 7 new test cases covering the new_img_start boundary computation used by get_partial_input_embeddings: partial_depth 0/1/2, out-of-range depth, video tokens, and a realistic multi-image prompt layout.
@AirRunner AirRunner changed the title feat(qwen3_5): image_end_index + get_partial_input_embeddings for cross-turn image KV cache feat(qwen3_5): Cross-turn image KV cache Mar 17, 2026
@AirRunner
Author

Two more things worth addressing in the future:

1. image_end_index for other models

The image_end_index field is now in InputEmbeddingsFeatures (base class), but only qwen3_5 computes it. Other models return None implicitly. Any engine-level KV cache that relies on this boundary would need each model to implement it. A follow-up could add it to the other get_input_embeddings implementations.

2. Vision tower VRAM spike

get_input_embeddings currently runs the vision tower in a single forward pass. On long contexts with many image patches, this causes a significant VRAM spike before the prefill even starts. The mlx-engine integration works around this by chunking the prefill after the vision tower, but the spike remains. It can even cause outright crashes on large contexts or images (see #79).

It would be worth exploring chunking the vision tower pass (running it one patch batch at a time) or returning image embeddings one block at a time so the caller can interleave mx.eval + mx.clear_cache() between blocks.
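A sketch of what the chunked variant could look like. vision_tower, eval_fn, and clear_fn are caller-supplied stand-ins (e.g. mx.eval and mx.clear_cache in MLX); a real implementation would also need to respect per-image patch boundaries so attention within an image is not split across chunks:

```python
def run_vision_tower_chunked(vision_tower, pixel_values, chunk_size,
                             eval_fn=None, clear_fn=None):
    """Run the vision tower one patch batch at a time, forcing evaluation
    and releasing intermediate buffers between batches to bound peak memory."""
    outputs = []
    for i in range(0, len(pixel_values), chunk_size):
        out = vision_tower(pixel_values[i:i + chunk_size])
        if eval_fn is not None:
            eval_fn(out)  # materialize this chunk now
        if clear_fn is not None:
            clear_fn()    # free intermediates before the next chunk
        outputs.append(out)
    return outputs
```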

@Blaizzy
Owner

Blaizzy commented Mar 17, 2026

Hey @AirRunner

Thanks for the contribution!

But let's start with an issue containing reproducible examples; then we will investigate and agree on a solution, after which you are free to send a PR.

@Blaizzy Blaizzy closed this Mar 17, 2026