Add Teacache Support for LongCat Image #1487
alex-jw-brooks wants to merge 16 commits into vllm-project:main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 003d3b9f61
```python
sp_size = module.parallel_config.sequence_parallel_size
get_forward_context().sequence_parallel_size = sp_size

hidden_states = module.x_embedder(hidden_states)
```
Preserve sequence-parallel sharding in LongCat extractor
In the SP case (sequence_parallel_size > 1), this code enables SP in the forward context but does not replicate the required LongCat preprocessing (chunking image hidden_states and RoPE by rank, as done in LongCatImageTransformer2DModel.forward). As a result, SP attention paths run on unsharded layouts, which yields invalid coefficient-collection behavior and can break distributed estimation runs.
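The chunk-by-rank preprocessing this suggestion refers to can be sketched generically. This is a minimal illustration only; the helper name `shard_for_sp` and the `[batch, seq_len, dim]` layout are assumptions for the example, not LongCat's actual code.

```python
import torch

def shard_for_sp(hidden_states: torch.Tensor, rank: int, sp_size: int) -> torch.Tensor:
    """Keep only this SP rank's contiguous slice of the sequence dimension.

    Hypothetical helper mirroring the idea of chunking image hidden_states
    by rank before SP attention; not the PR's actual implementation.
    """
    chunks = torch.chunk(hidden_states, sp_size, dim=1)
    return chunks[rank]

# Example: 2-way SP over an 8-token sequence.
x = torch.randn(1, 8, 4)
shard0 = shard_for_sp(x, rank=0, sp_size=2)
shard1 = shard_for_sp(x, rank=1, sp_size=2)
assert shard0.shape == (1, 4, 4) and shard1.shape == (1, 4, 4)
```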
Useful, but I think there are larger underlying problems with SP for this model at the moment (see #1556). I will investigate the fix for that as well, but I see the same error with and without TeaCache right now, so I'm open to any direction for how to handle it in this PR.
Force-pushed from d35e91d to b0dd147
```python
# Explicitly use inference mode to avoid gradients since we
# are not creating the pipeline through the model runner
with torch.inference_mode():
    self.pipeline.forward(req)
```
A few small fixes were needed in this script to avoid OOMs in my env from gradients, and to handle bf16 since the tensors can't be converted with .numpy() directly.
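The bf16 conversion issue mentioned here can be reproduced in isolation. A minimal sketch, assuming PyTorch, of the upcast-before-`.numpy()` workaround:

```python
import torch

# Tensor.numpy() raises on bfloat16 ("Got unsupported ScalarType BFloat16"),
# so upcast to float32 (and move to CPU) before converting.
t = torch.ones(3, dtype=torch.bfloat16)
arr = t.float().cpu().numpy()
assert arr.dtype.name == "float32"
```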
lishunyang12 left a comment
Left a couple comments. The extractor mostly mirrors the model forward correctly, but the first block runs twice on non-cached steps which seems unintentional.
```python
_, hs = first_block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
```
This runs first_block(...) to get the modulated input, but then run_transformer_blocks() below iterates over all module.transformer_blocks (including [0]) again. So block 0 gets executed twice on every non-cached step.
The other extractors (e.g., qwen) avoid this by extracting the modulated input with just the lightweight norm call (block.img_mod(temb) + block.img_norm1(...)) without running the full block forward. Could you do something similar here, or at least start run_transformer_blocks from module.transformer_blocks[1:]?
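The lightweight-extraction pattern described here can be sketched with a toy block. `TinyBlock`, `img_mod`, and `img_norm1` below are illustrative stand-ins mirroring the naming in the comment, not the real LongCat or Qwen modules:

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Toy stand-in for a transformer block with AdaLN-style modulation."""

    def __init__(self, dim: int):
        super().__init__()
        self.img_mod = nn.Linear(dim, 2 * dim)   # produces shift and scale
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.ffn = nn.Linear(dim, dim)           # rest of the (expensive) block

    def modulated_input(self, hidden_states, temb):
        # Only the cheap norm + modulation; skips attention / FFN, so the
        # full block forward never runs during coefficient extraction.
        shift, scale = self.img_mod(temb).chunk(2, dim=-1)
        return self.img_norm1(hidden_states) * (1 + scale) + shift

block = TinyBlock(dim=8)
hs = torch.randn(2, 5, 8)
temb = torch.randn(2, 1, 8)
mod = block.modulated_input(hs, temb)
assert mod.shape == hs.shape
```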
Good catch 😬 Thanks! Fixed the modulated input and reran the coefficient calculations
```python
pipeline.to(device)
return pipeline
```
Should this also be wrapped with set_default_torch_dtype(od_config.dtype) like BagelAdapter.load_pipeline was updated to do above?
I had actually added a set_default_torch_dtype around the call to load_pipeline on the adapter, instead of just putting it around that one line 🙂 The better way to do this is:
loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config)
return loader.load_model(od_config=od_config, load_device=device)
because load_model will handle the device placement, put the model in eval mode, and handle the dtypes from the diffusion config. Updated both to avoid managing default dtypes manually, and made sure the Bagel one still runs.
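For reference, a default-dtype guard like the one discussed can be sketched as a small context manager. This is a minimal reimplementation for illustration; vLLM's actual `set_default_torch_dtype` helper may differ:

```python
import contextlib
import torch

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily change the default torch dtype, restoring it on exit."""
    old = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old)

# Tensors created inside the block pick up the requested dtype.
with set_default_torch_dtype(torch.bfloat16):
    x = torch.empty(4)
assert x.dtype == torch.bfloat16
assert torch.empty(1).dtype == torch.float32  # default restored
```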
Hi @alex-jw-brooks 👋 Checking in on the Teacache support for LongCat Image PR; it's been 12 days since the last update. Any progress? Thanks!
Hey @alex-jw-brooks, following up on the open threads from 2 weeks ago. The main concern is still the block 0 double execution and the modulated input extraction. Could you take a look?
Hey @hsliuustc0106 @lishunyang12, haven't forgotten about this PR, just paused it for a bit while fixing the sequence parallelism for this model to avoid copying things over here. I'll get back to it this afternoon and work on the comments, thanks for your patience 🙂
Force-pushed from 08396d7 to 53f39f2
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
The sequence parallel columns were out of date for LongCat; it does support ring attention and Ulysses SP. Tested that TeaCache works with both types of SP as part of this PR.
Hey @lishunyang12 @hsliuustc0106, thanks for the reviews. Took another pass at this and added some additional info, since I hadn't originally tested image edit. Ready for another look when you've got the bandwidth 🙂
Force-pushed from 53f39f2 to 5284249
Thanks for the update. Will re-review this week.
lishunyang12 left a comment
The extractor looks correct now — norm1 is called on the pre-block hidden_states, and all blocks run in run_transformer_blocks(). Previous concern is addressed.
```python
    CacheContext with all information needed for generic caching
    """
    # TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward
    from diffusers.models.modeling_outputs import Transformer2DModelOutput
```
This mutates fwd_context.split_text_embed_in_sp but never restores it. If the forward context is reused across timesteps, this side effect persists. Worth a comment or a reset in postprocess().
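The restore-on-exit idea raised here can be expressed as a generic save/restore guard. This is a sketch of the pattern only, not code from the PR (which instead documents why no reset is needed for this model); the `Ctx` class and attribute usage are illustrative:

```python
import contextlib

@contextlib.contextmanager
def temporarily_set(obj, attr: str, value):
    """Set obj.attr for the duration of the block, then restore the old value."""
    old = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        setattr(obj, attr, old)

class Ctx:
    """Toy stand-in for a forward context carrying a mutable flag."""
    split_text_embed_in_sp = True

ctx = Ctx()
with temporarily_set(ctx, "split_text_embed_in_sp", False):
    assert ctx.split_text_embed_in_sp is False
assert ctx.split_text_embed_in_sp is True  # side effect does not persist
```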
Sure, I added a comment. For this model we don't need to restore it, because it should never be True, but IMO we should just remove this attribute from the repo for now, since I don't think the True behavior is currently expected or even implemented. Will open a PR to discuss 🙂
Force-pushed from e485ebb to a25f48f
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Force-pushed from adbe997 to 9c0fe26
@wtomin @hsliuustc0106, could you please review this PR? I've refactored a bit to share some of the code for the estimators, since LongCat and the new Stable Audio coefficients were basically the same, but it's ready for a look when you have a moment.
main (see fix).

Example Outputs
For both text to image and image edit, teacache is the left one.
For Image edit (using the coffee image above):
Speed Benchmarks
With a thresh of .2, the speedup is ~1.7x on an H100; didn't benchmark edit, but the speedup looked comparable when I ran a quick check. Here is the full script I used for testing, which can be used for reproduction for tti; it'll report the average speedup of 3 images.
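The average-speedup bookkeeping that the script reports can be sketched deterministically. The per-image timings below are made-up placeholders for illustration, not measured numbers from the PR:

```python
# Compute the average speedup from per-image wall-clock times
# (placeholder values; the real script times actual generations).
baseline_s = [10.2, 10.0, 9.8]   # without TeaCache, seconds per image
teacache_s = [6.0, 5.9, 5.8]     # with TeaCache, seconds per image

speedups = [b / t for b, t in zip(baseline_s, teacache_s)]
avg_speedup = sum(speedups) / len(speedups)
print(f"average speedup over {len(speedups)} images: {avg_speedup:.2f}x")
```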