Improve mlx-vlm type hints for generate() / stream_generate() so Pyright/Pylance can catch API mistakes #825

@jrp2014

Description

When running my various integration tests that use mlx-vlm in VS Code with Pylance/Pyright, I found that the current type surface is not as precise as it could be for the checker to be useful.

There are two practical problems:

  1. False positives:
    stream_generate() yields GenerationResult objects at runtime, but generated typing can make it look like it yields str.

  2. Missed diagnostics:
    generate() and stream_generate() accept many supported kwargs via **kwargs, so the type checker cannot validate misspelled or unsupported arguments, and editor completion is much less helpful than it could be.

Why this matters

Today, code like this is valid at runtime:

for response in stream_generate(model, processor, prompt, image=image):
    print(response.text)

But if the published/generated typing says str | Generator[str, None, None], Pylance/Pyright can infer response as str, which produces noise.
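As a minimal, hypothetical reproduction of the false positive (the stub below stands in for the published typing; all names are invented):

```python
from collections.abc import Generator

# Hypothetical stand-in for the published stub: it claims the yield type is str.
def stream_generate_stub(prompt: str) -> Generator[str, None, None]:
    yield "hello"

chunks = list(stream_generate_stub("hi"))
# Pyright infers each element as `str`, so an access like `chunk.text` would be
# flagged (reportAttributeAccessIssue) even though the real runtime objects are
# GenerationResult instances that do have a `.text` field.
print(chunks)
```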

At the same time, code like this:

generate(
    model=model,
    processor=processor,
    prompt=prompt,
    image=image,
    top_p=0.95,
    min_p=0.05,
    top_k=40,
    prefill_step_size=2048,
)

works in practice, but the checker cannot do much with it because most of the public generation API is hidden behind **kwargs.
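A runnable sketch of why this matters, using invented stand-ins for generate() and generate_step():

```python
# Hypothetical stand-ins: the real options live on the inner function and are
# forwarded through **kwargs, mirroring how mlx_vlm forwards to generate_step().
def generate_step_stub(*, temperature: float = 0.0, top_p: float = 1.0) -> dict:
    return {"temperature": temperature, "top_p": top_p}

def generate_stub(prompt: str, **kwargs):
    return generate_step_stub(**kwargs)

# The typo `temperture` type-checks cleanly against `**kwargs` and only fails
# at runtime, deep inside the forwarding chain:
try:
    generate_stub("hi", temperture=0.5)
except TypeError as err:
    print("caught:", err)
```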

Runtime behavior observed

In mlx_vlm/generate.py, stream_generate() yields GenerationResult objects with fields such as:

  • text
  • token
  • logprobs
  • prompt_tokens
  • generation_tokens
  • total_tokens
  • prompt_tps
  • generation_tps
  • peak_memory

So the effective return type appears to be:

Iterator[GenerationResult]
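A hedged sketch of that shape (the field types below are guesses from the observed values, not taken from the mlx_vlm source):

```python
from dataclasses import dataclass

# Hypothetical approximation of GenerationResult, for illustration only.
@dataclass
class GenerationResultSketch:
    text: str
    token: int
    logprobs: object          # exact type unverified
    prompt_tokens: int
    generation_tokens: int
    total_tokens: int
    prompt_tps: float
    generation_tps: float
    peak_memory: float

r = GenerationResultSketch("hi", 1, None, 3, 2, 5, 100.0, 50.0, 1.2)
print(r.text, r.total_tokens)
```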

Also, several generated signatures use = None defaults without making the annotation optional, for example:

images: str | list[str] = None
audios: str | list[str] = None
prompts: list[str] = None

which should be:

images: str | list[str] | None = None
audios: str | list[str] | None = None
prompts: list[str] | None = None
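To illustrate with an invented helper (not part of mlx_vlm): without `| None`, Pyright flags the default itself (reportArgumentType); with it, both the default and explicit `images=None` calls check cleanly.

```python
from __future__ import annotations

# Hypothetical helper showing the corrected annotation pattern.
def normalize_images(images: str | list[str] | None = None) -> list[str]:
    if images is None:
        return []
    return [images] if isinstance(images, str) else list(images)

print(normalize_images())            # []
print(normalize_images("a.png"))     # ['a.png']
print(normalize_images(["a", "b"]))  # ['a', 'b']
```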

Suggested changes

1. Make stream_generate() return type precise

Current effective signature (the problem):

def stream_generate(
    model: nn.Module,
    processor: PreTrainedTokenizer,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    **kwargs,
) -> str | Generator[str, None, None]: ...

Suggested:

from collections.abc import Iterator

def stream_generate(
    model: nn.Module,
    processor: ProcessorLike,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    **kwargs,
) -> Iterator[GenerationResult]: ...

If you prefer to keep Generator, this would also be fine:

def stream_generate(
    model: nn.Module,
    processor: ProcessorLike,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    **kwargs,
) -> Generator[GenerationResult, None, None]: ...

2. Fix optional defaults throughout public signatures

Suggested pattern:

def batch_generate(
    model,
    processor,
    images: str | list[str] | None = None,
    audios: str | list[str] | None = None,
    prompts: list[str] | None = None,
    max_tokens: int | list[int] = 128,
    verbose: bool = False,
    group_by_shape: bool = True,
    track_image_sizes: bool = True,
    **kwargs,
): ...

3. Expose supported kwargs directly on generate() and stream_generate()

Right now, many supported options live only in generate_step() and are forwarded through **kwargs. That prevents static checkers from catching real mistakes.

A more checkable public signature would look something like:

def generate(
    model: nn.Module,
    processor: ProcessorLike,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    *,
    verbose: bool = False,
    max_tokens: int = 256,
    temperature: float = 0.0,
    repetition_penalty: float | None = None,
    repetition_context_size: int | None = 20,
    top_p: float = 1.0,
    min_p: float = 0.0,
    top_k: int = 0,
    logit_bias: dict[int, float] | None = None,
    prompt_cache: list[Any] | None = None,
    max_kv_size: int | None = None,
    kv_bits: int | None = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    sampler: Callable[[mx.array], mx.array] | None = None,
    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
    prefill_step_size: int | None = None,
    skip_special_tokens: bool = False,
    resize_shape: tuple[int, int] | None = None,
    eos_tokens: list[int] | list[str] | None = None,
    thinking_budget: int | None = None,
    thinking_end_token: str = "</think>",
    thinking_start_token: str | None = None,
    enable_thinking: bool = False,
) -> GenerationResult: ...

And similarly:

def stream_generate(
    model: nn.Module,
    processor: ProcessorLike,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    *,
    max_tokens: int = 256,
    temperature: float = 0.0,
    repetition_penalty: float | None = None,
    repetition_context_size: int | None = 20,
    top_p: float = 1.0,
    min_p: float = 0.0,
    top_k: int = 0,
    logit_bias: dict[int, float] | None = None,
    prompt_cache: list[Any] | None = None,
    max_kv_size: int | None = None,
    kv_bits: int | None = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    sampler: Callable[[mx.array], mx.array] | None = None,
    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
    prefill_step_size: int | None = None,
    skip_special_tokens: bool = False,
    resize_shape: tuple[int, int] | None = None,
    eos_tokens: list[int] | list[str] | None = None,
    thinking_budget: int | None = None,
    thinking_end_token: str = "</think>",
    thinking_start_token: str | None = None,
    enable_thinking: bool = False,
) -> Iterator[GenerationResult]: ...
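With an explicit keyword-only signature, the misspelled-kwarg case from earlier becomes a static error. A runnable sketch with an invented stub:

```python
# Hypothetical stub with explicit keyword-only parameters instead of **kwargs.
def generate_stub(prompt: str, *, temperature: float = 0.0, top_k: int = 0) -> dict:
    return {"temperature": temperature, "top_k": top_k}

# Pyright now reports `temperture` as an unknown parameter at check time
# (reportCallIssue), and Python raises TypeError at the call site rather than
# deep inside a forwarding chain:
try:
    generate_stub("hi", temperture=0.5)
except TypeError as err:
    print("caught:", err)
```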

4. Broaden or protocol-ize the processor type

The runtime contract seems broader than PreTrainedTokenizer. In practice the object behaves like an AutoProcessor-style wrapper and may expose things such as:

  • tokenizer
  • detokenizer
  • possibly chat_template

A protocol would probably describe the real contract better than a single tokenizer class.

For example:

from typing import Protocol, runtime_checkable, Any

@runtime_checkable
class ProcessorLike(Protocol):
    tokenizer: Any
    detokenizer: Any

Even a minimal protocol would be more accurate than a narrow tokenizer annotation if the runtime accepts multiple processor shapes.
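Because the protocol is runtime_checkable, isinstance() checks attribute presence, which also helps in tests. A usage sketch with an invented processor class:

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class ProcessorLike(Protocol):
    tokenizer: Any
    detokenizer: Any

# Hypothetical processor shape; mlx-vlm's real wrappers differ in detail.
class DummyProcessor:
    def __init__(self) -> None:
        self.tokenizer = object()
        self.detokenizer = object()

# isinstance() only checks that the attributes exist, not their types.
print(isinstance(DummyProcessor(), ProcessorLike))  # True
print(isinstance(object(), ProcessorLike))          # False
```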

Expected benefit

This would make Pylance/Pyright and other type checkers much more useful by:

  • avoiding false positives around stream_generate()
  • improving autocomplete for supported kwargs
  • catching misspelled generation parameters
  • helping downstream users avoid maintaining local stub patches
