feat(qwen3_5): Cross-turn image KV cache #828
Add `image_end_index: Optional[int] = None` to `InputEmbeddingsFeatures` in `base.py` (excluded from `to_dict()`). Compute and return it from `qwen3_5.get_input_embeddings`: last visual token position + 1, derived from `input_ids` before merging. Covers both image and video tokens; returns `None` for text-only calls.
…cache: Implement the model-specific side of partial image KV caching: run the vision tower only for `images[partial_depth:]`, merge their features into the full-sequence text embeddings using `masked_scatter`, and set up multi-modal RoPE position IDs for the suffix prefill. Called by `VisionModelWrapper.prefill_with_partial_cache()` in mlx-engine when a new image is appended to a conversation whose earlier images are already in the KV cache.
…dings: Extends `test_image_end_index.py` with 7 new test cases covering the `new_img_start` boundary computation used by `get_partial_input_embeddings`: `partial_depth` 0/1/2, out-of-range depth, video tokens, and a realistic multi-image prompt layout.
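The "last visual token position + 1" rule from the first commit can be sketched roughly as follows. This is a minimal illustration, not the mlx-vlm implementation; the placeholder token IDs are hypothetical (the real tokenizer defines its own image/video token IDs):

```python
import numpy as np

# Hypothetical placeholder IDs for this sketch only; the real tokenizer
# defines the actual image/video token IDs.
IMAGE_TOKEN_ID = 151655
VIDEO_TOKEN_ID = 151656

def image_end_index(input_ids):
    """Position just past the last image/video token, or None for text-only."""
    ids = np.asarray(input_ids)
    visual = (ids == IMAGE_TOKEN_ID) | (ids == VIDEO_TOKEN_ID)
    if not visual.any():
        return None  # text-only call
    # The `arange * mask` trick: the max of position*mask is the index of
    # the last visual token; +1 points at the first token after the block.
    return int((np.arange(ids.shape[0]) * visual).max()) + 1
```

For example, `image_end_index([5, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 9])` returns 3, the position of the first non-visual token after the image block.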
`image_end_index` + `get_partial_input_embeddings` for cross-turn image KV cache
Two more things worth addressing in the future:

1. The
2. Vision tower VRAM spike

It would be worth exploring chunking the vision tower pass (patch batch by patch batch) or returning image embeddings one block at a time so the caller can interleave
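One possible shape for the chunked pass floated above: process one image's patch block at a time so peak memory scales with the largest image rather than the whole batch. This is a sketch under stated assumptions — the `vision_tower` callable and the `grid_thw` row layout here are stand-ins, not the actual mlx-vlm API:

```python
import numpy as np

def encode_images_individually(vision_tower, pixel_values, grid_thw):
    """Run the vision tower one image at a time instead of in one big pass.

    Assumes image i occupies t*h*w consecutive rows of pixel_values,
    as described by its (t, h, w) entry in grid_thw (a sketch assumption).
    """
    feats, offset = [], 0
    for t, h, w in grid_thw:
        n = t * h * w
        feats.append(vision_tower(pixel_values[offset:offset + n], [(t, h, w)]))
        offset += n
    # A caller could also consume `feats` lazily, interleaving each image's
    # embeddings with prefill instead of concatenating them all up front.
    return np.concatenate(feats, axis=0)
```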
Hey @AirRunner, thanks for the contribution! But let's start from an issue with reproducible examples; then we will investigate and agree on a solution, after which you are free to send a PR.
This PR adds two small additions to `qwen3_5` that are required by a cross-turn image KV cache implementation in mlx-engine (tracking issue lmstudio-ai/mlx-engine#287).

Currently every conversation turn re-runs the vision tower and re-prefills the full context from scratch, even when the same image was already processed in the previous turn. The mlx-engine fix saves a KV cache checkpoint right after the image tokens and restores it on subsequent turns, but it needs two pieces of information from mlx-vlm that were not previously exposed:
1. Where the image tokens end (`image_end_index`): so the engine knows where to split the saved checkpoint from the text suffix to prefill.
2. A way to embed only the new images (`get_partial_input_embeddings`): so that when a new image is added to a conversation, only that image goes through the vision tower; the KV state for earlier images is reused as-is.

## Changes
### `mlx_vlm/models/base.py`

Adds an optional `image_end_index: int | None` field to the `InputEmbeddingsFeatures` dataclass. The field is excluded from `to_dict()` (it is engine metadata, not a model kwarg).

### `mlx_vlm/models/qwen3_5/qwen3_5.py`

- `get_input_embeddings` computes and returns `image_end_index`: the position of the first non-visual token after the last image/video token block. Uses the existing `arange * mask` pattern already present in the file.
- `get_partial_input_embeddings`: like `get_input_embeddings` but runs the vision tower only for `images[partial_depth:]`:
  - Slices `pixel_values[n_cached_patches:]` and `grid_thw[partial_depth:]` to process new images only.
  - Embeds the full `input_ids` with `embed_tokens` and merges the new image features into the text embeddings via `masked_scatter`.
  - Sets up multi-modal RoPE position IDs (`get_rope_index`) for the suffix prefill.

  Returns `inputs_embeds` ready for chunked prefill from the end of the last cached image block. The method is on `qwen3_5.Model` and inherited by `qwen3_5_moe.Model` automatically.

## Tests
`tests/test_image_end_index.py`: 15 tests in two `unittest.TestCase` classes (all pass):

- `TestImageEndIndex` (8 tests): `image_end_index` boundary computation — tokens at start/middle/end, single token, video tokens, mixed, no image, realistic layout.
- `TestNewImgStart` (7 tests): `partial_depth` boundary used by `get_partial_input_embeddings` — depths 0/1/2, out-of-range, video tokens, realistic multi-image layout.

## End-to-end validation
Tested on Qwen3.5-35B-A3B (MoE, 5-bit) via LM Studio with the mlx-engine integration:
- Turn with one new image added: vision tower runs only for the new image (~15s instead of ~40s for full context).
- Next turn: full cache hit, vision tower skipped entirely (~1s).
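For reference, the `new_img_start` boundary exercised by `TestNewImgStart` can be approximated like this. This is a sketch only: it assumes each image is one contiguous run of a hypothetical image placeholder ID, which is not necessarily the real prompt layout:

```python
import numpy as np

IMAGE_TOKEN_ID = 151655  # hypothetical placeholder ID for this sketch

def new_img_start(input_ids, partial_depth):
    """Start of the first image block not yet covered by the KV cache.

    partial_depth counts whole images already cached; returns None when
    partial_depth is out of range (no newer image to process).
    """
    ids = np.asarray(input_ids)
    mask = ids == IMAGE_TOKEN_ID
    # A block start is an image token whose predecessor is not an image token.
    prev = np.concatenate(([False], mask[:-1]))
    starts = np.flatnonzero(mask & ~prev)
    if partial_depth >= len(starts):
        return None
    return int(starts[partial_depth])
```

With `partial_depth=0` this points at the first image block; with one image already cached it points at the second, matching the depths 0/1/2 and out-of-range cases the tests cover.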
## Related