
Add Teacache Support for LongCat Image#1487

Open
alex-jw-brooks wants to merge 16 commits intovllm-project:mainfrom
alex-jw-brooks:longcat_teacache

Conversation

@alex-jw-brooks
Contributor

@alex-jw-brooks alex-jw-brooks commented Feb 25, 2026

  • Enables TeaCache support for LongCat Image. The model coefficients and speedups were calculated with the current config, not main (see fix).
  • Includes some fixes to the coefficient estimator to avoid computing gradients and to avoid dtype-casting issues when running bf16 models
  • Updates the docs with notes on estimating coefficients for models whose layers require vLLM's forward context and parallel groups to be set up, since this model needed it
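For context on what the coefficients do: TeaCache-style caching skips a denoising step's transformer blocks when the accumulated, polynomial-rescaled relative L1 change of the first block's modulated input stays under a threshold. A minimal pure-Python sketch of that decision (coefficient values and class/function names here are illustrative, not the ones shipped for LongCat Image):

```python
# Illustrative sketch of a TeaCache-style skip decision. The coefficient
# values and helper names are made up for demonstration purposes.

def rel_l1(curr, prev):
    """Relative L1 distance between two flat lists of floats."""
    num = sum(abs(c - p) for c, p in zip(curr, prev))
    den = sum(abs(p) for p in prev) or 1.0
    return num / den

def poly(coeffs, x):
    """Evaluate the fitted rescaling polynomial at x (Horner's method)."""
    acc = 0.0
    for c in coeffs:
        acc = acc * x + c
    return acc

class TeaCacheDecider:
    def __init__(self, coeffs, rel_l1_thresh=0.2):
        self.coeffs = coeffs
        self.thresh = rel_l1_thresh
        self.prev_mod_input = None
        self.accum = 0.0

    def should_skip(self, mod_input):
        """Return True when the cached residual can be reused this step."""
        if self.prev_mod_input is None:
            self.prev_mod_input = mod_input
            return False  # always compute the first step
        self.accum += poly(self.coeffs, rel_l1(mod_input, self.prev_mod_input))
        self.prev_mod_input = mod_input
        if self.accum < self.thresh:
            return True   # change is still small: reuse cached residual
        self.accum = 0.0  # drifted too far: recompute and reset
        return False
```

With identity coefficients (`poly(x) = x`), small step-to-step changes accumulate slowly and get skipped, while a large change forces a recompute.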

Example Outputs

For both text-to-image and image edit, the TeaCache output is on the left.

$ python text_to_image.py --cache-backend tea_cache --model meituan-longcat/LongCat-Image --output coffee_tc.png
$ python text_to_image.py --model meituan-longcat/LongCat-Image --output coffee.png
(images: coffee_tc, coffee)

For image edit (using the coffee image above):

$ python image_edit.py --model meituan-longcat/LongCat-Image-Edit --image coffee.png --prompt "make the coffee cup transparent"  --cache-backend tea_cache --output edit_coffee_tc.png
$ python image_edit.py --model meituan-longcat/LongCat-Image-Edit --image coffee.png --prompt "make the coffee cup transparent" --output edit_coffee.png
(images: edit_coffee_tc, edit_coffee)

Speed Benchmarks

With a threshold of 0.2, the speedup is ~1.7x on an H100. I didn't benchmark image edit formally, but the speedup looked comparable in a quick check afterwards.

Here is the full script I used for testing, which can be used for reproducing the text-to-image numbers; it reports the average speedup across 3 seeds.

import os
import gc
import time
import torch
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

# Configuration
MODEL_ID = "meituan-longcat/LongCat-Image"
PROMPT = "A cup of coffee sitting on a table."
STEPS = 50
SEEDS = [444, 111, 3919]

TEACACHE_DIR = "cache_results"
NO_CACHE_DIR = "no_cache_results"
os.makedirs(TEACACHE_DIR, exist_ok=True)
os.makedirs(NO_CACHE_DIR, exist_ok=True)


def run_benchmark(use_cache=False):
    print(f"\n{'Testing with TeaCache' if use_cache else 'Testing without TeaCache'}...")
    times = []
    # Configure cache based on requirement
    out_dir = TEACACHE_DIR if use_cache else NO_CACHE_DIR
    cache_config = {
        "rel_l1_thresh": 0.2,
    } if use_cache else {}
    cache_backend = "tea_cache" if use_cache else None

    omni = Omni(
        model=MODEL_ID,
        cache_backend=cache_backend,
        cache_config=cache_config,
        dtype="bfloat16",
    )

    for seed in SEEDS:
        sampling_params = OmniDiffusionSamplingParams(num_inference_steps=STEPS, seed=seed)
        start = time.time()
        outputs = omni.generate(PROMPT, sampling_params)
        end = time.time()
        run_time = end - start
        times.append(run_time)
        # Save the generated image
        image = outputs[0].request_output[0].images[0]
        filename = f"{out_dir}/seed_{seed}.png"
        print(f"Run time: {run_time:.2f}s for seed {seed} [use_cache={use_cache}]")
        image.save(filename)

    avg_time = sum(times) / len(times)
    print(f"Average latency [use_cache={use_cache}]: {avg_time:.2f}s")
    return avg_time


if __name__ == "__main__":
    # Run tests
    time_no_cache = run_benchmark(use_cache=False)
    torch.cuda.empty_cache()
    gc.collect()
    time_with_cache = run_benchmark(use_cache=True)

    print("\nResults:")
    print(f"Speedup: {time_no_cache / time_with_cache:.2f}x")

@alex-jw-brooks alex-jw-brooks changed the title [WIP] [WIP] Add Teacache Support for LongCat Image Feb 25, 2026
@alex-jw-brooks alex-jw-brooks marked this pull request as ready for review February 27, 2026 17:48

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 003d3b9f61


Comment on lines +584 to +587
sp_size = module.parallel_config.sequence_parallel_size
get_forward_context().sequence_parallel_size = sp_size

hidden_states = module.x_embedder(hidden_states)

P1: Preserve sequence-parallel sharding in LongCat extractor

In the SP case (sequence_parallel_size > 1), this code enables SP in the forward context but does not replicate the required LongCat preprocessing (chunking image hidden_states and RoPE by rank, as done in LongCatImageTransformer2DModel.forward). As a result, SP attention paths run on unsharded layouts, which yields invalid coefficient-collection behavior and can break distributed estimation runs.
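For readers unfamiliar with the sharding the reviewer describes: under SP, each rank keeps only its contiguous chunk of the image tokens (and the matching RoPE slice) before attention. A hypothetical pure-Python illustration of per-rank chunking (the real model shards torch tensors along the sequence dimension; this helper is a stand-in, not vLLM-Omni code):

```python
def shard_for_rank(seq, rank, sp_size):
    """Return this rank's contiguous shard of a sequence split into sp_size chunks.

    Illustrative stand-in for chunking image hidden_states / RoPE tensors
    along the sequence dimension before sequence-parallel attention.
    """
    if len(seq) % sp_size != 0:
        raise ValueError("sequence length must be divisible by sp_size")
    shard = len(seq) // sp_size
    return seq[rank * shard : (rank + 1) * shard]
```

Running attention on the full, unsharded sequence while the forward context reports `sequence_parallel_size > 1` is exactly the mismatch the review flags.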


Contributor Author

Useful, but I think there are larger underlying problems in SP for this model at the moment (see #1556). I will investigate the fix for that as well, but I see the same error with and without TeaCache at the moment, so I'm open to direction on how to handle it in this PR.

@alex-jw-brooks alex-jw-brooks changed the title [WIP] Add Teacache Support for LongCat Image Add Teacache Support for LongCat Image Feb 27, 2026
# Explicitly use inference mode to avoid gradients since we
# are not creating the pipeline through the model runner
with torch.inference_mode():
    self.pipeline.forward(req)
Contributor Author

A few small fixes were needed in this script to avoid OOMs in my env from gradients, and to handle bf16, since bf16 tensors can't be converted with .numpy() directly.
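For reference, the bf16 issue is that NumPy has no bfloat16 dtype, so Tensor.numpy() raises on bf16 tensors; upcasting first avoids it. A small sketch of the kind of fix involved (the helper name is illustrative, and this assumes PyTorch is installed):

```python
import torch

def to_numpy_safe(t):
    """Convert a tensor to a NumPy array, upcasting bf16 first.

    NumPy has no bfloat16 dtype, so t.numpy() raises for bf16 tensors;
    upcasting to float32 before conversion sidesteps that.
    """
    if t.dtype == torch.bfloat16:
        t = t.float()
    return t.detach().cpu().numpy()
```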

@lishunyang12 lishunyang12 (Contributor) left a comment

Left a couple comments. The extractor mostly mirrors the model forward correctly, but the first block runs twice on non-cached steps, which seems unintentional.

_, hs = first_block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
Contributor

This runs first_block(...) to get the modulated input, but then run_transformer_blocks() below iterates over all module.transformer_blocks (including [0]) again. So block 0 gets executed twice on every non-cached step.

The other extractors (e.g., qwen) avoid this by extracting the modulated input with just the lightweight norm call (block.img_mod(temb) + block.img_norm1(...)) without running the full block forward. Could you do something similar here, or at least start run_transformer_blocks from module.transformer_blocks[1:]?
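The suggestion above amounts to: compute the cache signal from the block's cheap norm/modulation step alone, never its full forward. A toy sketch with stand-in classes (these are not the actual vLLM-Omni modules, just an illustration of why the lightweight path is preferable):

```python
# Illustrative: extract the modulated input with only the lightweight
# norm/modulation step, rather than running the full (expensive) block.
class FakeBlock:
    """Stand-in for a transformer block; norm1 is cheap, forward is not."""
    def __init__(self, scale, shift):
        self.scale, self.shift = scale, shift
        self.full_forward_calls = 0

    def norm1(self, hidden_states):
        # modulation: scale/shift applied to (a stand-in for) the normed input
        return [h * self.scale + self.shift for h in hidden_states]

    def forward(self, hidden_states):
        # the expensive path: attention/MLP would live here
        self.full_forward_calls += 1
        return [h + 1.0 for h in self.norm1(hidden_states)]

def modulated_input(block, hidden_states):
    """Cache signal from the pre-block hidden states via norm1 only."""
    return block.norm1(hidden_states)
```

The key properties: the signal comes from the pre-block hidden states, and extracting it leaves `full_forward_calls` untouched, so block 0 still runs exactly once per non-cached step.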

Contributor Author

Good catch 😬 Thanks! Fixed the modulated input and reran the coefficient calculations.

pipeline.to(device)
return pipeline


Contributor

Should this also be wrapped with set_default_torch_dtype(od_config.dtype) like BagelAdapter.load_pipeline was updated to do above?

@alex-jw-brooks alex-jw-brooks (Contributor Author) commented Mar 15, 2026

I had actually added a set_default_torch_dtype around the call to the load pipeline on the adapter, instead of just putting it around the one line 🙂. The better way to do this is:

        loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config)
        return loader.load_model(od_config=od_config, load_device=device)

because load_model will handle the device placement, put the model in eval mode, and handle the dtypes from the diffusion config. Updated both to avoid managing default dtypes manually and made sure the bagel one still runs

@hsliuustc0106
Collaborator

Hi @alex-jw-brooks 👋

Checking in on the Teacache support for LongCat Image PR — 12 days since last update. Any progress?

Thanks!

@lishunyang12
Contributor

Hey @alex-jw-brooks — following up on the open threads from 2 weeks ago. The main concern is still the block 0 double execution + modulated input extraction.

Looking more carefully: first_block.norm1(hs, emb=temb)[0] extracts the modulated input from hs (the post-block-0 output), but it should be from the pre-block hidden_states. The Qwen extractor does this correctly — it calls block.img_norm1(hidden_states, img_mod1) on the original hidden states without running the full block. This means the cache decisions here are based on the wrong signal, and the coefficients were estimated with that bug.

Could you take a look?

@alex-jw-brooks
Contributor Author

Hey @hsliuustc0106 @lishunyang12, haven't forgotten about this PR, just paused it for a bit while fixing the sequence parallelism for this model to avoid copying things over here. I'll get back to it this afternoon and work on the comments, thanks for your patience 🙂

@Gaohan123 Gaohan123 added this to the v0.18.0 milestone Mar 14, 2026
@alex-jw-brooks alex-jw-brooks force-pushed the longcat_teacache branch 2 times, most recently from 08396d7 to 53f39f2 on March 15, 2026 at 00:18
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
Contributor Author

Columns for sequence parallel are out of date for LongCat; it does support ring attention and Ulysses SP.

Tested that TeaCache works with both types of SP as part of this PR.

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Mar 15, 2026

Hey @lishunyang12 @hsliuustc0106 thanks for the reviews - took another pass at this and added some additional info since I hadn't tested image edit yet originally. Ready for another look when you've got the bandwidth 🙂

@lishunyang12
Contributor

Thanks for the update. Will re-review this week.

@lishunyang12 lishunyang12 (Contributor) left a comment

The extractor looks correct now — norm1 is called on the pre-block hidden_states, and all blocks run in run_transformer_blocks(). Previous concern is addressed.

CacheContext with all information needed for generic caching
"""
# TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward
from diffusers.models.modeling_outputs import Transformer2DModelOutput
Contributor

This mutates fwd_context.split_text_embed_in_sp but never restores it. If the forward context is reused across timesteps, this side effect persists. Worth a comment or a reset in postprocess().
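A conventional save-and-restore pattern for this kind of context mutation, sketched with a stand-in context object (the class and helper names are illustrative, not the actual vLLM-Omni API):

```python
class FwdContext:
    """Stand-in for the forward context object being mutated."""
    split_text_embed_in_sp = False

def with_split_text_embed(ctx, value, fn):
    """Temporarily set ctx.split_text_embed_in_sp while fn runs.

    The try/finally guarantees the old value is restored even if fn
    raises, so nothing leaks into later timesteps that reuse the context.
    """
    saved = ctx.split_text_embed_in_sp
    ctx.split_text_embed_in_sp = value
    try:
        return fn()
    finally:
        ctx.split_text_embed_in_sp = saved
```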

Contributor Author

Sure, I added a comment. For this model, we don't need to restore it because it should never be True, but IMO we should just remove this attribute from the repo for now, because I don't think it's currently ever expected to be True, or even implemented for the True behavior. Will open a PR to discuss 🙂

Signed-off-by: Alex Brooks <albrooks@redhat.com>
@alex-jw-brooks
Contributor Author

@wtomin @hsliuustc0106 could you please review this PR?

I've refactored a bit to share some of the estimator code, since the LongCat and new Stable Audio coefficient estimators were basically the same. It's ready for a look when you have a moment.
