
Cross-turn image KV cache with Qwen3.5: vision tower re-runs on every turn #832

@AirRunner

Description


This issue proposes two small additions to qwen3_5 that are required by the cross-turn image KV cache implementation tracked in lmstudio-ai/mlx-engine#287.

Currently every conversation turn re-runs the vision tower and re-prefills the full context from scratch, even when the same image was already processed in the previous turn. The mlx-engine fix saves a KV cache checkpoint right after the image tokens and restores it on subsequent turns, but it needs two pieces of information from mlx-vlm that were not previously exposed:

  1. Where does the image block end? (image_end_index): so the engine knows where to split the saved checkpoint from the text suffix to prefill.
  2. Can the vision tower be run for a subset of images? (get_partial_input_embeddings): so that when a new image is added to a conversation, only that image goes through the vision tower; the KV state for earlier images is reused as-is.

A working implementation is available at AirRunner/mlx-vlm, branch feat/image-end-index.

Proposed changes

1. mlx_vlm/models/base.py

Adds an optional image_end_index: int | None field to the InputEmbeddingsFeatures dataclass. The field is excluded from to_dict() (it is engine metadata, not a model kwarg).

2. mlx_vlm/models/qwen3_5/qwen3_5.py

get_input_embeddings computes and returns image_end_index: the position of the first non-visual token after the last image/video token block. It reuses the arange * mask pattern already present in the file.
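The boundary computation can be sketched in numpy (the real code operates on mlx arrays and uses the model's configured image/video token ids; the function name and the token id used in the example are illustrative):

```python
import numpy as np


def compute_image_end_index(input_ids, visual_token_ids):
    """Return the index of the first non-visual token after the last
    image/video token block, or None if no visual tokens are present."""
    ids = np.asarray(input_ids)
    mask = np.isin(ids, visual_token_ids)
    if not mask.any():
        return None
    # arange * mask: visual-token positions survive, all others become 0,
    # so the max is the position of the last visual token.
    last_visual = int((np.arange(ids.shape[0]) * mask).max())
    return last_visual + 1
```

For `input_ids = [1, 2, 9, 9, 9, 3, 4]` with visual token id 9, the last visual token sits at position 4, so the boundary is 5: everything up to and including index 4 is covered by the saved checkpoint, and prefill resumes at index 5.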

get_partial_input_embeddings: like get_input_embeddings but runs the vision tower only for images[partial_depth:]:

  1. Slices pixel_values[n_cached_patches:] and grid_thw[partial_depth:] to process new images only.
  2. Runs the vision tower on the slice.
  3. Gets text embeddings for the full sequence via embed_tokens.
  4. Finds the start of the first new image block by scanning input_ids.
  5. Overwrites only the new image token positions using masked_scatter.
  6. Sets up multi-modal RoPE position IDs (get_rope_index) for the suffix prefill.

Returns inputs_embeds ready for chunked prefill from the end of the last cached image block. The method is on qwen3_5.Model and inherited by qwen3_5_moe.Model automatically.
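Steps 4–5 above can be sketched in numpy. Everything here is an illustrative stand-in: the real method works on mlx arrays, derives the number of cached image tokens from grid_thw rather than taking it as an argument, and uses masked_scatter instead of fancy indexing.

```python
import numpy as np


def scatter_new_image_embeddings(text_embeds, input_ids, new_image_embeds,
                                 image_token_id, n_cached_tokens):
    """Overwrite only the image-token positions belonging to *new* images
    with freshly computed vision-tower embeddings, leaving the positions
    of already-cached images untouched."""
    ids = np.asarray(input_ids)
    out = np.array(text_embeds, dtype=float, copy=True)
    # All image-token positions in the full sequence.
    positions = np.flatnonzero(ids == image_token_id)
    # Skip the positions covered by the cached images (step 4),
    # then write the new embeddings into the remaining slots (step 5).
    new_positions = positions[n_cached_tokens:]
    out[new_positions] = new_image_embeds
    return out
```

The cached image positions never need correct embeddings on the partial path, since their KV entries are restored from the checkpoint rather than recomputed.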

Tests

tests/test_image_end_index.py: 15 tests in two unittest.TestCase classes (all pass):

  • TestImageEndIndex (8 tests): image_end_index boundary computation — tokens at start/middle/end, single token, video tokens, mixed, no image, realistic layout.
  • TestNewImgStart (7 tests): partial_depth boundary used by get_partial_input_embeddings — depths 0/1/2, out-of-range, video tokens, realistic multi-image layout.

End-to-end validation

Tested on Qwen3.5-35B-A3B (MoE, 5-bit) via LM Studio with the mlx-engine integration.

Scenario: multi-turn conversation where images are added progressively.

  • Turn 1 — first image sent (e.g. a screenshot). Full prefill (~22s). KV checkpoint saved at depth=1 (after the image tokens).
  • Turns 2–N — text-only follow-up messages. Cache hits on the text prefix, only the new suffix is prefilled.
  • Turn N+1 — a second image is added. Without the fix: the entire conversation is re-prefilled from scratch (~30s). With the fix:
[kv-image] partial hit depth=1/2
[kv-image] checkpoint saved depth=2 index=20719

The checkpoint at depth=1 is restored; the vision tower runs only for the new image. A new checkpoint is saved at depth=2.

  • Turn N+2 — text-only message.
[kv-image] cache hit depth=2

Both images are recognised; the vision tower is skipped entirely (~1s).

Related

Follow-up topics

image_end_index for other models

The image_end_index field is in InputEmbeddingsFeatures (base class), but only qwen3_5 computes it. Other models return None implicitly. Any engine-level KV cache that relies on this boundary would need each model to implement it. A follow-up could add it to the other get_input_embeddings implementations.

Vision tower VRAM spike

get_input_embeddings currently runs the vision tower in a single forward pass. On long contexts with image patches, this causes a significant VRAM spike before the prefill even starts. The mlx-engine integration works around this by chunking the prefill after the vision tower, but there is still a spike. This might even cause outright crashes on large contexts or images (see #79).

It would be worth exploring chunking the vision tower pass (one patch batch at a time) or returning image embeddings one block at a time so the caller can interleave mx.eval + mx.clear_cache() between blocks.
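A minimal shape for such chunking, with a pluggable flush hook standing in for mx.eval + mx.clear_cache() (the function name and structure are illustrative, not mlx-vlm's actual API):

```python
def chunked_vision_tower(vision_tower, patch_batches, flush=None):
    """Run the vision tower one patch batch at a time instead of a single
    large forward pass, letting the caller release intermediates between
    chunks to bound the peak memory spike."""
    embeds = []
    for batch in patch_batches:
        out = vision_tower(batch)
        if flush is not None:
            # In a real mlx implementation: mx.eval(out); mx.clear_cache()
            flush(out)
        embeds.append(out)
    return embeds
```

The trade-off is extra kernel launches and some lost batching efficiency, but on long contexts with many image patches that is likely preferable to the current all-at-once VRAM spike.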
