
Refresh TeaCache when num_inference_steps=None#2240

Open
alex-jw-brooks wants to merge 2 commits into vllm-project:main from alex-jw-brooks:flux2_tc_fix

Conversation

@alex-jw-brooks
Contributor

Purpose

Related to #2194

The proper fix for the above issue is to merge the sampling params to get the correct num_inference_steps, but this PR adds a short-term workaround for TeaCache, whose refresh doesn't depend on num_inference_steps. It also adds logging if the cache fails to reset, for now, while I work on the more general fix.

This is needed because the warmup initializes TeaCache, which replaces forward() and can cause bad behavior when running text-to-image (TTI) on models that also accept image inputs. E.g., for Flux2Klein:

from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

if __name__ == "__main__":
    omni = Omni(
        model="black-forest-labs/FLUX.2-klein-4B",
        cache_backend="tea_cache",
    )

    prompt = "A cat sitting on a windowsill"

    # If you specify num_inference_steps, you will see the second cache refresh (after warmup),
    # but if you don't pass it, you won't, since refresh won't be called.
    sampling_params = OmniDiffusionSamplingParams(
        # Not specifying num_inference_steps will crash forward
    )

    outputs = omni.generate(prompt, sampling_params)
    outputs[0].images[0].save("meow.png")

Not refreshing before entering the forward pass blows up, because the new modulated inputs don't have an image component while the previous (stale) ones do.
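
The failure mode can be sketched in isolation (hypothetical names, not the real vllm_omni API; the real hook compares a relative distance against a threshold, which is omitted here). The stale state still holds the 8192-token modulated input from warmup (text + image), while the new text-only request produces 4096 tokens, so the elementwise comparison fails exactly like the traceback below:

```python
class TeaCacheState:
    """Toy stand-in for the per-request TeaCache state."""
    def __init__(self):
        self.previous_modulated_input = None  # None == "refreshed"

def should_compute_full_transformer(state, modulated_inp):
    # First timestep after a refresh: nothing to compare against, so
    # we must run the full transformer.
    if state.previous_modulated_input is None:
        state.previous_modulated_input = modulated_inp
        return True
    # Mirrors (modulated_inp - state.previous_modulated_input).abs().mean():
    # the elementwise subtraction requires matching sizes.
    if len(modulated_inp) != len(state.previous_modulated_input):
        raise RuntimeError(
            f"The size of tensor a ({len(modulated_inp)}) must match "
            f"the size of tensor b ({len(state.previous_modulated_input)})"
        )
    state.previous_modulated_input = modulated_inp
    return False  # simplified: the real hook thresholds the distance

state = TeaCacheState()
should_compute_full_transformer(state, [0.0] * 8192)  # warmup: text + image tokens

# Without a refresh, the next text-only request hits the stale state:
try:
    should_compute_full_transformer(state, [0.0] * 4096)
except RuntimeError as e:
    print(e)

# Refreshing (clearing the state) before the new request avoids it.
state.previous_modulated_input = None
assert should_compute_full_transformer(state, [0.0] * 4096)
```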

ERROR 03-26 18:11:45 [diffusion_worker.py:481]   File "/home/alex-jw-brooks/vllm-omni/vllm_omni/diffusion/cache/teacache/hook.py", line 222, in _should_compute_full_transformer
ERROR 03-26 18:11:45 [diffusion_worker.py:481]     (modulated_inp - state.previous_modulated_input).abs().mean()
ERROR 03-26 18:11:45 [diffusion_worker.py:481]      ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ERROR 03-26 18:11:45 [diffusion_worker.py:481] RuntimeError: The size of tensor a (4096) must match the size of tensor b (8192) at non-singleton dimension 1
...
ERROR 03-26 18:11:45 [stage_diffusion_client.py:78]   File "/home/alex-jw-brooks/vllm-omni/vllm_omni/entrypoints/async_omni_diffusion.py", line 309, in generate
ERROR 03-26 18:11:45 [stage_diffusion_client.py:78]     raise RuntimeError(f"Diffusion generation failed: {e}") from e
ERROR 03-26 18:11:45 [stage_diffusion_client.py:78] RuntimeError: Diffusion generation failed: The size of tensor a (4096) must match the size of tensor b (8192) at non-singleton dimension 1

This PR allows TeaCache to refresh in this case, and adds a log if we can't refresh the cache, until the more correct fix lands.

@Gaohan123 @wtomin @fhfuih could you please take a look?

Signed-off-by: Alex Brooks <albrooks@redhat.com>
@fhfuih
Contributor

fhfuih commented Mar 27, 2026

EDIT: Sorry, I actually missed your PR description. My understanding is correct; I jumped right into your code 😂

Thanks for the PR. A quick question: if I understand correctly, this PR is only a quick fix. It force-sets the number of inference steps to 0: not None, but still falsy. This passes the check during cache refreshing, and also yields to pipeline-specific overrides.

And the more complete fix is on your cache_refresh branch.

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Mar 27, 2026

Hey @fhfuih! No worries 😆 but yes. My understanding of the flow is

  • The TeaCache hook gets initialized in load_model, which also creates the StateManager etc. for the cache
  • When we run requests, we run the _WrappedForward, which calls the hook's new forward (here).
  • The new forward for TeaCache (this) runs the extractor, then gets the TeaCache state or creates a new one. After that, it checks the state here to see if it's the first timestep, and compares against the previous modulated state if it isn't.

For TeaCache, the refresh does not depend on the timesteps; it just resets the TeaCache state (i.e., num_inference_steps isn't passed anywhere here). So the value of 0 is just a placeholder I chose because the arg is an int; in the TeaCache case it doesn't matter, since all refresh does is clear the state.

Since refresh isn't currently being called, the state is stale from the last execute_model call, so instead of creating a new one on the first timestep, it gets the old one and we fall through this check.

So this is okay as a short-term fix for the TeaCache behavior, but the other branch will fix it more correctly by passing the actual num_inference_steps value, which we need in order to reset DiTCache correctly 🙂

