This PR proposes two small additions to `qwen3_5` that are required by a cross-turn image KV cache implementation in tracking issue lmstudio-ai/mlx-engine#287.
Currently every conversation turn re-runs the vision tower and re-prefills the full context from scratch, even when the same image was already processed in the previous turn. The mlx-engine fix saves a KV cache checkpoint right after the image tokens and restores it on subsequent turns, but it needs two pieces of information from mlx-vlm that were not previously exposed:
- Where does the image block end? (`image_end_index`): so the engine knows where to split the saved checkpoint from the text suffix to prefill.
- Can the vision tower be run for a subset of images? (`get_partial_input_embeddings`): so that when a new image is added to a conversation, only that image goes through the vision tower; the KV state for earlier images is reused as-is.
A working implementation is available at AirRunner/mlx-vlm, branch `feat/image-end-index`.
## Proposed changes

### 1. `mlx_vlm/models/base.py`

Adds an optional `image_end_index: int | None` field to the `InputEmbeddingsFeatures` dataclass. The field is excluded from `to_dict()` (it is engine metadata, not a model kwarg).
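For orientation, a minimal sketch of the shape of the change (the surrounding fields and the `to_dict()` body are simplified stand-ins for the real dataclass, not its exact definition):

```python
from dataclasses import dataclass

import mlx.core as mx


@dataclass
class InputEmbeddingsFeatures:
    inputs_embeds: mx.array
    # ... existing fields that are forwarded to the model as kwargs ...

    # Engine metadata: index of the first non-visual token after the last
    # image/video block, or None if the model does not compute it.
    image_end_index: int | None = None

    def to_dict(self):
        # image_end_index is deliberately excluded: it is metadata for the
        # caller (e.g. an engine-level KV cache), not a model kwarg.
        return {
            k: v
            for k, v in self.__dict__.items()
            if k != "image_end_index" and v is not None
        }
```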
### 2. `mlx_vlm/models/qwen3_5/qwen3_5.py`

`get_input_embeddings` computes and returns `image_end_index`: the position of the first non-visual token after the last image/video token block, using the existing `arange * mask` pattern already present in the file.
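A self-contained sketch of that computation (the function name and explicit token-id arguments are illustrative; the real code lives inside `get_input_embeddings` and uses the model config's token ids):

```python
import mlx.core as mx


def compute_image_end_index(
    input_ids: mx.array,  # (seq_len,) token ids
    image_token_id: int,
    video_token_id: int,
) -> int | None:
    """Index of the first non-visual token after the last image/video block."""
    visual_mask = (input_ids == image_token_id) | (input_ids == video_token_id)
    if visual_mask.sum().item() == 0:
        return None  # no image/video tokens in the prompt
    # arange * mask: non-visual positions contribute 0, so the max is the
    # position of the last visual token.
    positions = mx.arange(input_ids.shape[-1]) * visual_mask.astype(mx.int32)
    return int(positions.max().item()) + 1
```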
`get_partial_input_embeddings`: like `get_input_embeddings`, but runs the vision tower only for `images[partial_depth:]` (a sketch of the splice step follows this list):
- Slices `pixel_values[n_cached_patches:]` and `grid_thw[partial_depth:]` to process new images only.
- Runs the vision tower on the slice.
- Gets text embeddings for the full sequence via `embed_tokens`.
- Finds the start of the first new image block by scanning `input_ids`.
- Overwrites only the new image token positions using `masked_scatter`.
- Sets up multi-modal RoPE position IDs (`get_rope_index`) for the suffix prefill.
Returns `inputs_embeds` ready for chunked prefill from the end of the last cached image block. The method is on `qwen3_5.Model` and inherited by `qwen3_5_moe.Model` automatically.
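To make the overwrite step concrete, here is a small self-contained sketch of the `masked_scatter`-style splice (names and shapes are illustrative, not the PR's exact code):

```python
import mlx.core as mx


def splice_new_image_embeddings(
    text_embeds: mx.array,       # (seq_len, hidden) embed_tokens output for the full sequence
    new_image_embeds: mx.array,  # (n_new_visual_tokens, hidden) tower output for new images
    input_ids: mx.array,         # (seq_len,)
    image_token_id: int,
    new_img_start: int,          # first position of the first *new* image block
) -> mx.array:
    positions = mx.arange(input_ids.shape[-1])
    # Only image tokens at or after the first new block are overwritten;
    # earlier image positions are covered by the restored KV checkpoint.
    new_visual = (input_ids == image_token_id) & (positions >= new_img_start)
    # masked_scatter equivalent: the i-th new visual position receives the
    # i-th row of new_image_embeds.
    dest = mx.cumsum(new_visual.astype(mx.int32)) - 1
    dest = mx.clip(dest, 0, new_image_embeds.shape[0] - 1)
    return mx.where(new_visual[:, None], new_image_embeds[dest], text_embeds)
```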
## Tests

`tests/test_image_end_index.py`: 15 tests in two `unittest.TestCase` classes (all pass):
- `TestImageEndIndex` (8 tests): `image_end_index` boundary computation — tokens at start/middle/end, single token, video tokens, mixed, no image, realistic layout.
- `TestNewImgStart` (7 tests): `partial_depth` boundary used by `get_partial_input_embeddings` — depths 0/1/2, out-of-range, video tokens, realistic multi-image layout.
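For illustration, one boundary case in the style of those tests, written against the `compute_image_end_index` sketch above (the token id is a placeholder; the real suite uses the model's configured ids):

```python
import unittest

import mlx.core as mx

IMG = 151655  # placeholder image token id


class TestImageEndIndexSketch(unittest.TestCase):
    def test_image_block_in_middle(self):
        # positions:     0  1   2    3   4  5
        ids = mx.array([9, 9, IMG, IMG, 9, 9])
        # First non-visual token after the image block sits at index 4.
        self.assertEqual(compute_image_end_index(ids, IMG, video_token_id=-1), 4)

    def test_no_image(self):
        ids = mx.array([9, 9, 9])
        self.assertIsNone(compute_image_end_index(ids, IMG, video_token_id=-1))
```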
## End-to-end validation
Tested on Qwen3.5-35B-A3B (MoE, 5-bit) via LM Studio with the mlx-engine integration.
Scenario: multi-turn conversation where images are added progressively.
- Turn 1 — first image sent (e.g. a screenshot). Full prefill (~22s). KV checkpoint saved at `depth=1` (after the image tokens).
- Turns 2–N — text-only follow-up messages. Cache hits on the text prefix; only the new suffix is prefilled.
- Turn N+1 — a second image is added. Without the fix, the entire conversation is re-prefilled from scratch (~30s). With the fix:

  ```
  [kv-image] partial hit depth=1/2
  [kv-image] checkpoint saved depth=2 index=20719
  ```

  The checkpoint at `depth=1` is restored; the vision tower runs only for the new image. A new checkpoint is saved at `depth=2`.
- Turn N+2 — text-only message.

  ```
  [kv-image] cache hit depth=2
  ```

  Both images are recognised and the vision tower is skipped entirely (~1s).
## Related
- mlx-engine tracking issue: lmstudio-ai/mlx-engine#287
- mlx-vlm issue on the same root cause (closed): how to reuse `pixel_values`? (#755)
## Follow-up topics

### `image_end_index` for other models
The `image_end_index` field is in `InputEmbeddingsFeatures` (base class), but only `qwen3_5` computes it. Other models return `None` implicitly. Any engine-level KV cache that relies on this boundary would need each model to implement it. A follow-up could add it to the other `get_input_embeddings` implementations.
### Vision tower VRAM spike
`get_input_embeddings` currently runs the vision tower in a single forward pass. On long contexts with many image patches, this causes a significant VRAM spike before the prefill even starts. The mlx-engine integration works around this by chunking the prefill after the vision tower, but the vision-tower spike remains and can even cause outright crashes on large contexts or images (see #79).
It would be worth exploring chunking the vision tower pass (one patch batch at a time) or returning image embeddings one block at a time so the caller can interleave `mx.eval` + `mx.clear_cache()` between blocks. A sketch of the first option follows.
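As a starting point, a hedged sketch of a per-image chunked tower pass. It assumes a Qwen-VL-style layout where each `grid_thw` row `(t, h, w)` owns `t*h*w` consecutive rows of `pixel_values`, and that the tower attends only within each image so per-image outputs can be concatenated:

```python
import mlx.core as mx


def run_vision_tower_chunked(vision_tower, pixel_values, grid_thw, images_per_chunk=1):
    """Run the vision tower a few images at a time, evaluating and freeing
    the buffer cache between chunks to cap the peak VRAM spike."""
    embeds, offset = [], 0
    for i in range(0, grid_thw.shape[0], images_per_chunk):
        rows = grid_thw[i : i + images_per_chunk]
        # Number of pixel_values rows owned by this group of images.
        n_patches = int(mx.prod(rows, axis=-1).sum().item())
        chunk = vision_tower(pixel_values[offset : offset + n_patches], rows)
        mx.eval(chunk)    # materialise this chunk now...
        mx.clear_cache()  # ...and release cached buffers before the next one
        embeds.append(chunk)
        offset += n_patches
    return mx.concatenate(embeds, axis=0)
```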