
[Feature][Diffusion] Plugin-based sparse attention interface for DiT modules #2231

Open

zzhang-fr wants to merge 4 commits into vllm-project:main from zzhang-fr:main

Conversation


@zzhang-fr zzhang-fr commented Mar 26, 2026

Purpose

Build plugin-based sparse attention interface for DiT modules.
Resolves #1568 (attention backend extension discussion).

Design

Unified sparse attention framework with plugin-based backend discovery.
vllm-omni owns only the protocol layer and registry. All kernel
implementations live in external pip-installable packages that
register via Python entry_points — zero vendor backends in
vllm-omni core.

Core (merged into vllm-omni)

  • SparseAttentionBackend ABC, SparsePatternSpec, SparseAttention
    module
  • Plugin registry: entry_points + explicit registration + class path
  • DiTSparseAttentionAdapter with sparsity schedules
  • DiffusionSparseAttnConfig + CLI
    (--sparse-attn-backend / --sparse-attn-topk / --sparse-attn-schedule)
  • DenseFallbackBackend for correctness comparison in tests

Wan 2.2 integration (merged into vllm-omni)

  • enable_sparse_attention() on WanTransformer3DModel
  • Pipeline wiring from OmniDiffusionConfig
  • No-op fallback when no backend is installed

Contrib plugin packages (NOT merged — reference only)

contrib/flashinfer-vllm-omni/ and contrib/sparge-vllm-omni/ are
included solely to demonstrate how external packages implement the
interface. They serve as a concrete example for reviewers to validate
that the ABC design is actually usable. Neither package will be merged
into vllm-omni; they are intended to live in separate repositories
(flashinfer-vllm-omni, sparge-vllm-omni) maintained outside this
repo and installed independently via pip install.

A user who wants sparse attention does:

pip install sparge-vllm-omni   # external package, not in this repo
vllm-omni serve Wan2.2 --sparse-attn-backend spargeattn --sparse-attn-topk 0.5

No changes to vllm-omni core are needed to add a new backend.

Known Limitations (to be fixed before merge)

  • Sparse attention bypasses SP/ring pre/post communication in
    multi-GPU runs — will add explicit SP/ring incompatibility check
    and fallback
  • --sparse-attn-backend none incorrectly falls back to
    DenseFallbackBackend — will fix in enable_sparse_attention()
    guard
  • schedule config is parsed but not yet wired to step-level
    topk_ratio updates — will be addressed in follow-up PR

Test Plan

Benchmark on A100 80GB (Wan 2.2, 480x832, 81f, 40 steps) with
SpargeAttn via sparge-vllm-omni entry_point plugin.

Test Result

Backend               Total Time   Per-step   Speedup
Dense (baseline)         785s        19.6s     1.00×
SpargeAttn topk=0.7      763s        19.1s     1.03×
SpargeAttn topk=0.5      677s        16.9s     1.16×
SpargeAttn topk=0.3      590s        14.8s     1.33×

98 tests (97 passed, 1 skipped).

Dense:    https://github.com/user-attachments/assets/62a7abfd-758f-4ec0-958b-fa3de672f972
Topk 0.7: https://github.com/user-attachments/assets/a175f027-424b-413f-b6c4-0733c354476f
Topk 0.5: https://github.com/user-attachments/assets/016c8463-f0bc-4a69-a891-b07fbc784a68
Topk 0.3: https://github.com/user-attachments/assets/17d2d464-0205-4d2c-8683-bd2e768c11f3


Unified sparse attention framework with plugin-based backend discovery.
External packages register via Python entry_points — no vendor backends
in vllm-omni. Users just install and configure.

Core:
- SparseAttentionBackend ABC, SparsePatternSpec, SparseAttention module
- Plugin registry: entry_points + explicit registration + class path
- DiTSparseAttentionAdapter with sparsity schedules
- DiffusionSparseAttnConfig + CLI (--sparse-attn-backend/topk/schedule)
- DenseFallbackBackend for test-only correctness comparison

Wan 2.2 integration:
- enable_sparse_attention() on WanTransformer3DModel
- Pipeline wiring from OmniDiffusionConfig
- No-op fallback when no backend installed

Contrib plugin packages:
- contrib/flashinfer-vllm-omni/ (FlashInfer BSR block-sparse)
- contrib/sparge-vllm-omni/ (SpargeAttn dynamic top-k)

Benchmark on A100 80GB (Wan 2.2, 480x832, 81f, 40 steps):
  dense:     785s (19.6s/step)
  topk=0.7:  763s (1.03x)
  topk=0.5:  677s (1.16x)
  topk=0.3:  590s (1.33x)

98 tests (97 passed, 1 skipped).

Signed-off-by: Zhen Zhang <zhen.zhang.fr@huawei.com>

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 058884a921


Comment on lines +431 to +432
if self.sparse_attn is not None and attn_mask is None:
hidden_states = self.sparse_attn(query, key, value)


P1: Route sparse self-attention through SP/ring communication

This direct call bypasses the existing Attention wrapper, which is where sequence/ring-parallel pre/post communication is performed (strategy.pre_attention/post_attention in attention/layer.py). When sparse attention is enabled in distributed runs (e.g., SP active or ring_degree > 1), attention is computed on local shards only, so outputs can be wrong even though the dense path is correct. Sparse mode should either reuse the same parallel strategy path or be explicitly disabled for those configs.
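The explicit incompatibility check this review asks for could look roughly like the following. This is a minimal sketch under stated assumptions: the "sp_degree"/"ring_degree" key names are guesses at the parallel config's shape, shown here as a plain dict for simplicity.

```python
# Hypothetical guard sketch; the config key names are assumptions about
# vllm-omni's parallel config, not its actual attributes.
def sparse_attn_compatible(parallel_config: dict) -> bool:
    """Sparse attention bypasses the SP/ring pre/post communication done
    in the Attention wrapper, so refuse to enable it whenever sequence
    or ring parallelism is active; callers then fall back to dense."""
    sp = parallel_config.get("sp_degree", 1)
    ring = parallel_config.get("ring_degree", 1)
    return sp <= 1 and ring <= 1
```

The alternative, reusing the same strategy.pre_attention/post_attention path for the sparse branch, avoids the fallback entirely but is a larger change.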


Comment on lines +54 to +58
backend_cls = get_sparse_attn_backend(sparse_attn_config)
if backend_cls is None:
from vllm_omni.diffusion.sparse_attn.backends.dense import DenseFallbackBackend

backend_cls = DenseFallbackBackend

P2: Honor 'none' backend by skipping dense fallback substitution

get_sparse_attn_backend documents/implements 'none' as a disable signal, but this branch converts any None backend into DenseFallbackBackend. Because the pipeline enables sparse attention whenever sparse_attn config exists, --sparse-attn-backend none still replaces Wan self-attention layers instead of leaving the native Attention path untouched. That makes the disable option ineffective and can unexpectedly change runtime behavior.
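A corrected guard could distinguish the explicit 'none' disable signal from a merely unresolved backend. In this sketch the registry is reduced to a plain dict and the function signature is an assumption; only the 'none' semantics and DenseFallbackBackend's test-only role come from the snippet above.

```python
# Hypothetical fix sketch; the registry is simplified to a dict and the
# signature is an assumption, not the PR's exact code.
def resolve_backend(backend_name, registry, dense_fallback):
    """Treat 'none' as an explicit disable signal: return None so the
    caller leaves the native Attention path untouched, instead of
    silently substituting the dense fallback."""
    if backend_name == "none":
        return None
    cls = registry.get(backend_name)
    if cls is None:
        return dense_fallback  # test-only correctness baseline
    return cls
```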


Comment on lines +105 to +106
meta = self._get_metadata(query, key, query.device)
return self._impl.forward(query, key, value, meta)

P2: Apply denoising schedule when running SparseAttention

The CLI/config now accept sparse_attn.schedule, but this forward path always uses cached metadata from a fixed _pattern_spec and never performs step-aware updates. No production call site wires DiTSparseAttentionAdapter.begin_step, so schedule choices like conservative/aggressive currently have no effect during denoising. This silently breaks expected experiment behavior for users tuning schedule.
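One way the schedule could drive step-level updates is a simple per-step anneal of the top-k ratio. The linear interpolation and parameter names below are assumptions; only begin_step and topk_ratio as concepts come from the discussion, where begin_step(step) would push the computed ratio into the adapter's pattern spec each denoising step.

```python
# Hypothetical schedule sketch; the linear anneal and parameter names are
# assumptions about what a "conservative" schedule might mean.
class ScheduledSparsity:
    def __init__(self, start_ratio: float = 1.0, end_ratio: float = 0.3,
                 num_steps: int = 40):
        self.start = start_ratio  # dense (or near-dense) early steps
        self.end = end_ratio      # aggressive sparsity in late steps
        self.num_steps = num_steps

    def topk_ratio(self, step: int) -> float:
        """Linearly anneal the top-k ratio over the denoising trajectory."""
        frac = step / max(self.num_steps - 1, 1)
        return self.start + (self.end - self.start) * frac
```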



import torch

from vllm_omni.diffusion.sparse_attn.base import (
Contributor

Hey @zzhang-fr, thanks for the PR - is there a reason this belongs in Omni and not in vLLM, which already has its own patterns for attention backends?

IMO we should minimize new attention types that are Omni specific, especially if they need changes on a per model basis, since it is hard to maintain, and keeping cross feature compatibility is already challenging

Author

@alex-jw-brooks Thanks — I'll address both points together in my reply to your second comment below.


[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"
Contributor

Similarly, I don't think it will be easy to maintain a contrib package with additional plugins; part of the reason to have plugins is that they are externally maintained to begin with, so in many ways, imo adding it directly here and integrating it into models defeats the purpose of it being a plugin.

Another way to approach enabling this could be to contribute to the plugin architecture in vLLM to allow this type of thing to be registered with no code changes in attention layers for individual models, so that it's usable if people choose to install it? cc @tzhouam @hsliuustc0106 @Gaohan123 in case any of you have thoughts

Author

@alex-jw-brooks Thanks for the review and the thoughtful questions.

Why not vLLM's existing attention backend

DiT attention doesn't go through vLLM's attention stack at all — it runs entirely inside the
diffusion pipeline, which has no paged KV cache, no causal masking, and no chunked prefill.
Beyond the interface mismatch, sparse attention for diffusion has a lifecycle that has no
parallel in the LLM path: sparsity ratio adapts per denoising step, patterns are derived from
the H×W spatial layout of patches, and the same model may want dense attention in early steps
and aggressive sparsity later. None of this maps onto AttentionMetadata. This is documented
in the "Alternatives considered" section of RFC #2233 if you'd like the fuller reasoning.

Where you're right

Two of your points I agree with completely and will fix before merge:

The enable_sparse_attention() call in wan2_2_transformer.py is a per-model hook, which is
exactly the maintenance surface you're worried about. The better path is to wire this through
the existing pipeline hook system the same way cache_backend is applied today — a single call
in GPUWorker.load_model(), no model-specific code required.

The contrib/ directory should not be in this repo. It was included here to make the protocol
easier to review — having a concrete backend alongside the ABC lets reviewers validate that
the interface is actually usable. But they should live in separate external repos
(flashinfer-vllm-omni, sparge-vllm-omni) that install via entry_points, which is exactly
the plugin model this PR is designed around. I'll drop contrib/ from this PR.

Invitation to the RFC

The broader design question you're raising — where this belongs, whether contrib/ makes sense,
and how models should opt in — is exactly what RFC #2233 is meant to discuss. The Feedback
Period is open for two weeks and the first question listed is specifically about contrib/ and
separate repos. I'd really value your input there, especially given your concern about
cross-feature compatibility. cc @tzhouam @hsliuustc0106 @Gaohan123

@zzhang-fr zzhang-fr changed the title [Feature][Backend] Plugin-based sparse attention for DiT modules [Feature][Diffusion] Plugin-based sparse attention interface for DiT modules Mar 26, 2026
Contributor

lishunyang12 commented Mar 27, 2026

I think you should wait for feedback from maintainers to discuss the value and maintainability of this feature before opening a PR. This is common practice when the attention backend is being altered. Otherwise, it will be quite hard to get this accepted.

Collaborator

@gcanlin gcanlin left a comment

I don't think we should introduce a contrib/ directory, and we have already implemented registration for the Attention backend. This PR introduces an independent registry for sparse attention, and it tries to make sparse the default when users have it installed. But in most scenarios, sparse attention can lead to accuracy regression. So it's better to make it optional instead of default.

Could we make SparseAttentionBackend inherit AttentionBackend?


# Sparse attention args
parser.add_argument(
"--sparse-attn-backend",
Collaborator

vllm-omni serve Wan2.2 --sparse-attn-backend spargeattn --sparse-attn-topk 0.5
I don't see the necessity of adding --sparse-attn-backend; the previous design vllm-omni serve Wan2.2 --attn-backend spargeattn seems cleaner to me

Author

@SamitHuang I've addressed this in my reply to @gcanlin's comment.
The short version: removing --sparse-attn-backend and routing through --attn-backend is
cleaner, but introduces a cross-attention scoping problem that needs to be
resolved first. Details in that thread.

@zzhang-fr
Author

I don't think we should introduce a contrib/ directory, and we have already implemented registration for the Attention backend. This PR introduces an independent registry for sparse attention, and it tries to make sparse the default when users have it installed. But in most scenarios, sparse attention can lead to accuracy regression. So it's better to make it optional instead of default.

Could we make SparseAttentionBackend inherit AttentionBackend?

@gcanlin Thanks for the detailed review — these are all valid points and I want to engage with them carefully.

On inheriting AttentionBackend

This is architecturally sound and I've prototyped it. The inheritance chain works:

SparseAttentionBackend(AttentionBackend)
SparseAttentionImpl(AttentionImpl[SparseAttentionMetadata])
SparseAttentionMetadata(AttentionMetadata)

VIDEO_SPARSE_ATTN becomes one line in DiffusionAttentionBackendEnum, the independent
registry is deleted, and --sparse-attn-backend is replaced by --attn-backend.

However, inheriting introduces a new problem the current design avoids

--attn-backend is a global setting. get_attn_backend() returns the same backend class
for every Attention instance in the model — including WanCrossAttention.attn (line 531
in wan2_2_transformer.py).

Cross-attention (Q from patch tokens, KV from text encoder) must remain fully dense.
Applying sparse attention there is semantically wrong and will degrade quality.

The current per-model enable_sparse_attention() hook exists precisely because it
targets only self-attention layers, leaving cross-attention untouched. With the
inheritance approach, we need an equivalent mechanism to achieve the same protection —
for example, passing a module_type parameter to every Attention(...) constructor call
and checking it against the backend's supported_module_types().

The complexity is equivalent; it just moves from one place to another:

Current PR: one enable_sparse_attention() per model (gcanlin's concern)
Inheritance: one module_type= per Attention() call (same surface, different form)
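The module_type scoping could look roughly like this. Both supported_module_types() and the selection helper below are proposals from this thread, not existing vllm-omni API; the backend stub is purely illustrative.

```python
# Hypothetical sketch of per-layer backend scoping; names are proposals,
# not existing vllm-omni API.
class SparseBackendStub:
    @staticmethod
    def supported_module_types():
        # Cross-attention (Q from patch tokens, KV from text encoder)
        # must stay dense, so only self-attention is declared.
        return {"self_attention"}


def select_backend(module_type, sparse_backend, dense_backend):
    """Per-layer backend choice: use the sparse backend only where it
    declares support, so WanCrossAttention keeps the dense path."""
    if module_type in sparse_backend.supported_module_types():
        return sparse_backend
    return dense_backend
```

Each Attention(...) constructor call would then pass its module_type, which is the "same surface, different form" trade-off described above.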

My proposal

I'm happy to go either direction, but I'd like to resolve this in RFC #2233 before
reworking the implementation. Specifically:

  1. Should we go with inheritance + module_type parameter?
  2. Or keep the current separate interface but address the other concerns
    (remove independent registry, remove --sparse-attn-backend, make sparse
    strictly opt-in via explicit config)?

On "sparse should not be default"

Completely agreed. The _auto_select() behavior in registry.py is wrong —
it activates sparse just because a plugin is installed. I'll fix this regardless
of which architectural direction we take: sparse attention must require explicit
opt-in via --attn-backend video_sparse_attn (or equivalent), never auto-activated.

On contrib/

Will remove from this PR. Reference implementations will move to separate repos.

Happy to hear from @tzhouam @hsliuustc0106 @SamitHuang on the module_type question
before I revise the implementation.



Development

Successfully merging this pull request may close these issues.

[RFC]: Discussing the extension of attention backend.

5 participants