### Description
While running my integration tests that use mlx-vlm with VS Code + Pylance/Pyright, I found that the current type surface is not as precise as it could be for the checker to be useful.
There are two practical problems:
- **False positives:** `stream_generate()` yields `GenerationResult` objects at runtime, but the generated typing can make it look like it yields `str`.
- **Missed diagnostics:** `generate()` and `stream_generate()` accept many supported kwargs via `**kwargs`, so the type checker cannot validate misspelled or unsupported arguments, and editor completion is much less helpful than it could be.
### Why this matters

Today, code like this is valid at runtime:

```python
for response in stream_generate(model, processor, prompt, image=image):
    print(response.text)
```

But if the published/generated typing says `str | Generator[str, None, None]`, Pylance/Pyright can infer `response` as `str`, which produces noise.
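To make the inference difference concrete, here is a self-contained sketch (this `GenerationResult` is a stand-in dataclass, not the real mlx-vlm class): with a return annotation of `Iterator[GenerationResult]`, a checker infers each `response` as `GenerationResult`, so `response.text` type-checks cleanly.

```python
from collections.abc import Iterator
from dataclasses import dataclass

@dataclass
class GenerationResult:
    # Stand-in for mlx-vlm's result object; only `text` is modeled here.
    text: str

def fake_stream_generate(prompt: str) -> Iterator[GenerationResult]:
    # Stand-in for stream_generate(): yields result objects, never str.
    for word in prompt.split():
        yield GenerationResult(text=word)

chunks = [response.text for response in fake_stream_generate("hello world")]
print(chunks)  # ['hello', 'world']
```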
At the same time, code like this:

```python
generate(
    model=model,
    processor=processor,
    prompt=prompt,
    image=image,
    top_p=0.95,
    min_p=0.05,
    top_k=40,
    prefill_step_size=2048,
)
```

works in practice, but the checker cannot do much with it because most of the public generation API is hidden behind `**kwargs`.
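As a toy illustration of the problem (all names below are stand-ins, not the real API): a `**kwargs` sink silently absorbs a misspelled option, while an explicit keyword-only parameter rejects it at call time and lets a checker flag it before the code even runs.

```python
def generate_with_kwargs(prompt: str, **kwargs) -> float:
    # Options forwarded via **kwargs: unknown keys are silently accepted.
    return kwargs.get("top_p", 1.0)

def generate_explicit(prompt: str, *, top_p: float = 1.0) -> float:
    # Explicit keyword-only parameter: a typo is a TypeError at runtime,
    # and a call-argument diagnostic in Pyright before that.
    return top_p

# The misspelled "topp" is swallowed; top_p silently stays at its default.
print(generate_with_kwargs("hi", topp=0.9))  # 1.0

try:
    generate_explicit("hi", topp=0.9)  # type: ignore[call-arg]
except TypeError:
    print("rejected")  # rejected
```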
### Runtime behavior observed

In `mlx_vlm/generate.py`, `stream_generate()` yields `GenerationResult` objects with fields such as:

- `text`
- `token`
- `logprobs`
- `prompt_tokens`
- `generation_tokens`
- `total_tokens`
- `prompt_tps`
- `generation_tps`
- `peak_memory`

So the effective return type appears to be:

```python
Iterator[GenerationResult]
```

Also, several generated signatures use `= None` defaults without making the annotation optional, for example:

```python
images: str | list[str] = None
audios: str | list[str] = None
prompts: list[str] = None
```

which should be:

```python
images: str | list[str] | None = None
audios: str | list[str] | None = None
prompts: list[str] | None = None
```

### Suggested changes
#### 1. Make `stream_generate()` return type precise

Current effective signature:

```python
def stream_generate(
    model: nn.Module,
    processor: PreTrainedTokenizer,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    **kwargs,
) -> str | Generator[str, None, None]: ...
```

Suggested:
```python
from collections.abc import Iterator

def stream_generate(
    model: nn.Module,
    processor: ProcessorLike,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    **kwargs,
) -> Iterator[GenerationResult]: ...
```

If you prefer to keep `Generator`, this would also be fine:
```python
def stream_generate(
    model: nn.Module,
    processor: ProcessorLike,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    **kwargs,
) -> Generator[GenerationResult, None, None]: ...
```

#### 2. Fix optional defaults throughout public signatures
Suggested pattern:

```python
def batch_generate(
    model,
    processor,
    images: str | list[str] | None = None,
    audios: str | list[str] | None = None,
    prompts: list[str] | None = None,
    max_tokens: int | list[int] = 128,
    verbose: bool = False,
    group_by_shape: bool = True,
    track_image_sizes: bool = True,
    **kwargs,
): ...
```

#### 3. Expose supported kwargs directly on `generate()` and `stream_generate()`
Right now, many supported options live only in `generate_step()` and are forwarded through `**kwargs`. That prevents static checkers from catching real mistakes.

A more checkable public signature would look something like:
```python
def generate(
    model: nn.Module,
    processor: ProcessorLike,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    *,
    verbose: bool = False,
    max_tokens: int = 256,
    temperature: float = 0.0,
    repetition_penalty: float | None = None,
    repetition_context_size: int | None = 20,
    top_p: float = 1.0,
    min_p: float = 0.0,
    top_k: int = 0,
    logit_bias: dict[int, float] | None = None,
    prompt_cache: list[Any] | None = None,
    max_kv_size: int | None = None,
    kv_bits: int | None = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    sampler: Callable[[mx.array], mx.array] | None = None,
    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
    prefill_step_size: int | None = None,
    skip_special_tokens: bool = False,
    resize_shape: tuple[int, int] | None = None,
    eos_tokens: list[int] | list[str] | None = None,
    thinking_budget: int | None = None,
    thinking_end_token: str = "</think>",
    thinking_start_token: str | None = None,
    enable_thinking: bool = False,
) -> GenerationResult: ...
```

And similarly:
```python
def stream_generate(
    model: nn.Module,
    processor: ProcessorLike,
    prompt: str,
    image: str | list[str] | None = None,
    audio: str | list[str] | None = None,
    *,
    max_tokens: int = 256,
    temperature: float = 0.0,
    repetition_penalty: float | None = None,
    repetition_context_size: int | None = 20,
    top_p: float = 1.0,
    min_p: float = 0.0,
    top_k: int = 0,
    logit_bias: dict[int, float] | None = None,
    prompt_cache: list[Any] | None = None,
    max_kv_size: int | None = None,
    kv_bits: int | None = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    sampler: Callable[[mx.array], mx.array] | None = None,
    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
    prefill_step_size: int | None = None,
    skip_special_tokens: bool = False,
    resize_shape: tuple[int, int] | None = None,
    eos_tokens: list[int] | list[str] | None = None,
    thinking_budget: int | None = None,
    thinking_end_token: str = "</think>",
    thinking_start_token: str | None = None,
    enable_thinking: bool = False,
) -> Iterator[GenerationResult]: ...
```

#### 4. Broaden or protocol-ize the processor type
The runtime contract seems broader than `PreTrainedTokenizer`. In practice the object behaves more like an `AutoProcessor`-style wrapper and may expose attributes such as:

- `tokenizer`
- `detokenizer`
- possibly `chat_template`

A protocol would probably describe the real contract better than a single tokenizer class.
For example:

```python
from typing import Protocol, runtime_checkable, Any

@runtime_checkable
class ProcessorLike(Protocol):
    tokenizer: Any
    detokenizer: Any
```

Even a minimal protocol would be more accurate than a narrow tokenizer annotation if the runtime accepts multiple processor shapes.
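To illustrate the structural check (the concrete classes below are hypothetical stand-ins), a `runtime_checkable` protocol matches any object that carries the required attributes, regardless of its base class:

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class ProcessorLike(Protocol):
    tokenizer: Any
    detokenizer: Any

class FakeProcessor:
    # Structurally satisfies ProcessorLike without inheriting from it.
    def __init__(self) -> None:
        self.tokenizer = object()
        self.detokenizer = object()

class TokenizerOnly:
    # Missing `detokenizer`, so it does not satisfy the protocol.
    def __init__(self) -> None:
        self.tokenizer = object()

print(isinstance(FakeProcessor(), ProcessorLike))  # True
print(isinstance(TokenizerOnly(), ProcessorLike))  # False
```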
### Expected benefit

This would make Pylance/Pyright and other type checkers much more useful by:

- avoiding false positives around `stream_generate()`
- improving autocomplete for supported kwargs
- catching misspelled generation parameters
- helping downstream users avoid maintaining local stub patches