Description
Plugin-based Sparse Attention Interface for DiT Modules
Motivation
vLLM-Omni currently runs DiT module attention (self-attention among latent patches, and cross-attention between patches and condition embeddings) with a standard full-attention kernel. As model resolution and video length scale up, the quadratic cost of full attention becomes the dominant bottleneck: a 720p video DiT with 8192 patch tokens spends more than 60% of step latency in attention alone.
Several sparse attention libraries have emerged that can cut this cost substantially:
- SpargeAttn (thu-ml/SpargeAttn, ICML 2025): training-free, block-level top-k prediction. Demonstrated 2–3× speedup on FLUX and HunyuanVideo at <0.1% quality degradation.
- FlashInfer sparse kernels: `BlockSparseAttentionWrapper` with BSR-format masks, already present as an optional vLLM dependency. Supports per-block and sub-block masks; CUDA-graph-safe.
- RainFusion (arXiv 2505.21036): adaptive head-level classification combined with spatiotemporal token permutation, targeting video DiT workloads.
These libraries have incompatible calling conventions, and none accounts for DiT-specific structure: fixed sequence lengths, multi-step denoising loops, spatial patch layout, and joint MMDiT attention patterns.
This RFC proposes a plugin-based architecture where vLLM-Omni provides only a unified interface and registry. All sparse attention backends — including FlashInfer and SpargeAttn — live in external pip-installable packages that register via Python entry points.
Implementation Status
This RFC has been implemented and benchmarked; see PR #2231.
Proposed Change
Overview
The design is organized around three layers. vLLM-Omni owns only the protocol and registry. All backends are external plugins.
```
┌─────────────────────────────────────────────────────────────────────┐
│ External plugin packages                                            │
│ flashinfer-vllm-omni · sparge-vllm-omni · custom-backend            │
│ (pip install, auto-register via entry_points)                       │
└────────────────────────────────┬────────────────────────────────────┘
                                 │ implements SparseAttentionBackend ABC
┌────────────────────────────────▼────────────────────────────────────┐
│ vllm-omni core protocol layer                                       │
│ SparsePatternSpec · SparseAttentionBackend (ABC)                    │
│ SparseAttentionImpl (ABC) · SparseMetadataBuilder (ABC)             │
│ SparseAttentionMetadata · DiffusionSparseAttnConfig                 │
│ register_sparse_attn_backend() · get_sparse_attn_backend()          │
│ SparseAttention(nn.Module) · DiTSparseAttentionAdapter              │
└────────────────────────────────┬────────────────────────────────────┘
                                 │ integrates with
┌────────────────────────────────▼────────────────────────────────────┐
│ vllm-omni model layer                                               │
│ WanTransformer3DModel.enable_sparse_attention()                     │
│ Wan22Pipeline (wires od_config.sparse_attn → transformer)           │
│ CLI: --sparse-attn-backend / --sparse-attn-topk / --schedule        │
└─────────────────────────────────────────────────────────────────────┘
```
1. Sparse pattern descriptor
SparsePatternSpec is a backend-agnostic dataclass describing the sparsity pattern. Each backend's metadata builder translates it to its native format.
```python
@dataclass
class SparsePatternSpec:
    pattern_type: SparsePatternType  # BLOCK_SPARSE | SLIDING_WINDOW |
                                     # SPATIAL_AWARE | DYNAMIC_TOPK |
                                     # ARROW | COMPOSITE | DENSE
    # Block-sparse fields
    block_size_q: int = 128
    block_size_kv: int = 64
    block_mask: Optional[torch.Tensor] = None      # (H, Qb, KVb)
    bsr_indptr: Optional[torch.Tensor] = None
    bsr_indices: Optional[torch.Tensor] = None
    sub_block_mask: Optional[torch.Tensor] = None  # (nnz, Bq, Bkv)
    # Sliding window
    window_size: Optional[int] = None
    # Spatial-aware (DiT patch grid)
    spatial_layout: Optional[SpatialLayout] = None
    spatial_radius: Optional[int] = None
    # Dynamic top-k
    topk_ratio: Optional[float] = None
    # MMDiT arrow attention
    num_image_tokens: Optional[int] = None
    num_text_tokens: Optional[int] = None
    i2i_pattern: Optional[SparsePatternSpec] = None
    cross_dense: bool = True
    # Denoising step context (set by DiTSparseAttentionAdapter)
    current_step: Optional[int] = None
    total_steps: Optional[int] = None
```

`SpatialLayout` captures the patch grid geometry:
```python
@dataclass(frozen=True)
class SpatialLayout:
    height: int
    width: int
    frames: int = 1
    patch_size: int = 2
    ordering: Literal["raster", "hilbert", "z_curve"] = "raster"
```

File: `vllm_omni/diffusion/sparse_attn/pattern.py`
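For intuition, the sequence length a layout implies follows directly from the patch grid. The sketch below restates `SpatialLayout` minimally and adds a hypothetical `num_tokens` helper (not part of the proposed API) to show the geometry:

```python
from dataclasses import dataclass
from typing import Literal

# Minimal restatement of SpatialLayout, for illustration only.
@dataclass(frozen=True)
class SpatialLayout:
    height: int
    width: int
    frames: int = 1
    patch_size: int = 2
    ordering: Literal["raster", "hilbert", "z_curve"] = "raster"

def num_tokens(layout: SpatialLayout) -> int:
    """Patch tokens = frames * (height / patch_size) * (width / patch_size)."""
    per_frame = (layout.height // layout.patch_size) * (layout.width // layout.patch_size)
    return layout.frames * per_frame

# A 60x104 latent grid with patch_size=2 gives 30*52 = 1560 tokens per frame.
print(num_tokens(SpatialLayout(height=60, width=104, frames=5)))  # → 7800
```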
2. Backend protocol (ABC)
Each backend is a factory class with static methods. The plan → forward lifecycle maps to FlashInfer's plan() / run() and is compatible with SpargeAttn's stateless functional calls.
```python
class SparseAttentionBackend(ABC):
    @staticmethod
    @abstractmethod
    def get_name() -> str: ...

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type[SparseAttentionImpl]: ...

    @staticmethod
    @abstractmethod
    def get_metadata_builder_cls() -> Type[SparseMetadataBuilder]: ...

    @staticmethod
    @abstractmethod
    def supported_patterns() -> set[SparsePatternType]: ...

    @staticmethod
    @abstractmethod
    def supported_module_types() -> set[AttentionModuleType]: ...

    @staticmethod
    def validate_config(config: "DiffusionSparseAttnConfig") -> None:
        pass  # default: no-op


class SparseAttentionImpl(ABC):
    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,  # (B, S, H, D) layout
        key: torch.Tensor,
        value: torch.Tensor,
        metadata: SparseAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor: ...


class SparseMetadataBuilder(ABC):
    @abstractmethod
    def plan(
        self,
        pattern_spec: SparsePatternSpec,
        module_type: AttentionModuleType,
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim: int,
        seq_len_q: int,
        seq_len_kv: int,
        batch_size: int = 1,
        dtype: torch.dtype = torch.float16,
        device: torch.device = torch.device("cuda"),
    ) -> SparseAttentionMetadata: ...

    def update_step(
        self, metadata: SparseAttentionMetadata,
        step: int, total_steps: int,
    ) -> SparseAttentionMetadata:
        return metadata  # default: no-op
```

File: `vllm_omni/diffusion/sparse_attn/base.py`
3. Plugin-based backend registry
Backends are discovered via three mechanisms, in priority order:
1. Entry points — external packages declare in `pyproject.toml`:

   ```toml
   [project.entry-points."vllm_omni.sparse_attn"]
   flashinfer = "flashinfer_vllm_omni.backend:FlashInferSparseBackend"
   ```

2. Explicit registration — call at import time:

   ```python
   from vllm_omni.diffusion.sparse_attn.registry import register_sparse_attn_backend

   register_sparse_attn_backend("mybackend", "my_pkg.module:MyBackend")
   ```

3. Full class path — pass directly in config/CLI:

   ```
   --sparse-attn-backend "my_pkg.module.MyBackend"
   ```
The registry has no hardcoded external backends. The only built-in is DenseFallbackBackend for test correctness comparisons.
```python
def get_sparse_attn_backend(
    sparse_attn_config: Optional[DiffusionSparseAttnConfig],
) -> Optional[Type[SparseAttentionBackend]]:
    """Resolution priority:
    1. DIFFUSION_SPARSE_ATTN_BACKEND env var
    2. sparse_attn_config.backend
    3. Auto-select from installed entry_point backends
    4. None (use default dense Attention layer)
    """
```

File: `vllm_omni/diffusion/sparse_attn/registry.py`
4. Configuration
DiffusionSparseAttnConfig mirrors the existing DiffusionCacheConfig pattern:
```python
@dataclass
class DiffusionSparseAttnConfig:
    backend: str = "auto"
    pattern_type: str = "dynamic_topk"
    topk_ratio: float = 0.5
    block_size_q: int = 128
    block_size_kv: int = 64
    window_size: int | None = None
    spatial_radius: int | None = None
    schedule: str = "constant"  # "constant" | "conservative" | "aggressive"
```

Integrated into `OmniDiffusionConfig`:

```python
class OmniDiffusionConfig:
    sparse_attn_backend: str | None = None  # shorthand
    sparse_attn: DiffusionSparseAttnConfig | dict | None = None
```

CLI arguments (in `vllm_omni/entrypoints/cli/serve.py`):
```
--sparse-attn-backend    Backend name or full class path
--sparse-attn-topk       Top-k ratio (0–1)
--sparse-attn-schedule   Sparsity schedule: constant | conservative | aggressive
```

Environment variable fallback: `DIFFUSION_SPARSE_ATTN_BACKEND`

Files: `vllm_omni/diffusion/data.py`, `vllm_omni/entrypoints/cli/serve.py`, `vllm_omni/engine/async_omni_engine.py`
5. DiT-specific adapter
DiTSparseAttentionAdapter manages denoising-step-aware sparsity scheduling:
```python
class DiTSparseAttentionAdapter:
    def __init__(
        self,
        backend_cls: Type[SparseAttentionBackend],
        module_type: AttentionModuleType,
        num_heads: int, num_kv_heads: int, head_dim: int,
        sparse_attn_config: DiffusionSparseAttnConfig,
    ): ...

    def begin_step(self, step: int, total_steps: int) -> None: ...
    def forward(self, query, key, value) -> torch.Tensor: ...
    def reset(self) -> None: ...
```

Built-in sparsity schedules:
| Schedule | Behavior |
|---|---|
| `constant` | Uses `topk_ratio` from config at every step |
| `conservative` | Full attention early (t < 0.2), ramp to 70% sparsity |
| `aggressive` | 80% sparsity early, decrease to 50% |
Custom schedules can be registered via register_sparsity_schedule(name, fn).
File: vllm_omni/diffusion/sparse_attn/dit_adapter.py
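The semantics of the built-in schedules can be sketched as plain functions mapping a normalized denoising progress t in [0, 1] to a sparsity ratio. The linear ramp shapes below are assumptions; the table only fixes the endpoints:

```python
def constant(t: float, topk_ratio: float = 0.5) -> float:
    """Sparsity fixed by config at every step (sparsity = 1 - topk_ratio)."""
    return 1.0 - topk_ratio

def conservative(t: float) -> float:
    """Full attention for t < 0.2, then ramp linearly up to 70% sparsity."""
    if t < 0.2:
        return 0.0
    return 0.7 * (t - 0.2) / 0.8

def aggressive(t: float) -> float:
    """80% sparsity early, decreasing linearly to 50% at the final step."""
    return 0.8 - 0.3 * t

# t = step / (total_steps - 1): 0.0 at the first step, 1.0 at the last.
for fn in (conservative, aggressive):
    print(fn.__name__, fn(0.0), fn(1.0))
```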
6. Drop-in SparseAttention module
```python
class SparseAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        num_kv_heads: Optional[int] = None,
        module_type: AttentionModuleType = AttentionModuleType.DIT_SELF,
        sparse_attn_config: Optional[DiffusionSparseAttnConfig] = None,
        dtype: torch.dtype = torch.float16,
    ): ...

    def forward(self, query, key, value) -> torch.Tensor: ...
```

Metadata is cached by a `(batch_size, seq_len_q, seq_len_kv, device)` key to avoid re-planning across layers with the same shape.
File: vllm_omni/diffusion/sparse_attn/attention.py
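The shape-keyed caching amounts to memoizing `plan()` results. A simplified sketch (the real cache lives inside `SparseAttention` and plans via the backend's metadata builder; `PlanCache` is a hypothetical name):

```python
class PlanCache:
    """Cache plan() results keyed by shape so that transformer layers with
    identical shapes reuse one metadata object instead of re-planning."""

    def __init__(self, plan_fn):
        self._plan_fn = plan_fn
        self._cache = {}
        self.plan_calls = 0  # instrumentation for the demo

    def get(self, batch_size, seq_len_q, seq_len_kv, device="cuda"):
        key = (batch_size, seq_len_q, seq_len_kv, device)
        if key not in self._cache:
            self.plan_calls += 1
            self._cache[key] = self._plan_fn(*key)
        return self._cache[key]

cache = PlanCache(lambda b, sq, skv, dev: {"shape": (b, sq, skv)})
for _layer in range(40):   # 40 transformer blocks, all with the same shape
    cache.get(1, 6630, 6630)
print(cache.plan_calls)    # → 1: planned once, reused 39 times
```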
7. Wan 2.2 model integration
WanTransformer3DModel.enable_sparse_attention() replaces self-attention layers post-construction:
```python
def enable_sparse_attention(self, sparse_attn_config):
    backend_cls = get_sparse_attn_backend(sparse_attn_config)
    if backend_cls is None:
        return  # no-op: use default dense Attention layer
    for block in self.blocks:
        block.attn1.sparse_attn = SparseAttention(
            num_heads=..., head_dim=...,
            sparse_attn_config=sparse_attn_config,
        )
```

`WanSelfAttention.forward()` dispatches:
```python
if self.sparse_attn is not None and attn_mask is None:
    output = self.sparse_attn(query, key, value)
else:
    output = self.attn(query, key, value, attn_metadata)  # existing dense path
```

When `attn_mask` is present (e.g., SP auto-padding), sparse attention is bypassed with a logged warning.
The pipeline calls enable_sparse_attention() after transformer construction:
```python
# In Wan22Pipeline.__init__():
if od_config.sparse_attn is not None:
    self.transformer.enable_sparse_attention(od_config.sparse_attn)
```

Files: `vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py`, `vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py`
8. File layout
```
vllm_omni/diffusion/
  sparse_attn/
    __init__.py                  # Public API exports
    pattern.py                   # SparsePatternSpec, SparsePatternType, SpatialLayout
    base.py                      # SparseAttentionBackend, SparseAttentionImpl,
                                 #   SparseMetadataBuilder, SparseAttentionMetadata,
                                 #   AttentionModuleType
    registry.py                  # Plugin registry: entry_points + register + resolve
    attention.py                 # SparseAttention(nn.Module)
    dit_adapter.py               # DiTSparseAttentionAdapter + sparsity schedules
    backends/
      dense.py                   # DenseFallbackBackend (test-only correctness oracle)
  data.py                        # DiffusionSparseAttnConfig (added)
  models/wan2_2/
    wan2_2_transformer.py        # enable_sparse_attention() (added)
    pipeline_wan2_2.py           # Pipeline wiring (added)
contrib/                         # External plugin packages (reference implementations)
  flashinfer-vllm-omni/          # FlashInfer BSR block-sparse backend
  sparge-vllm-omni/              # SpargeAttn dynamic top-k backend
tests/diffusion/sparse_attn/
  test_protocol.py               # ABC, config, registry tests
  test_config_cli.py             # CLI wiring, env var, priority resolution
  test_dit_adapter.py            # Schedule functions, adapter lifecycle
  test_wan22_sparse.py           # Wan 2.2 model integration (CPU, mocked)
  test_wan22_gpu_integration.py  # GPU tests with FlashInfer plugin
benchmarks/
  sparse_attn_benchmark.py       # Kernel-level micro-benchmark
  wan22_sparse_benchmark.py      # End-to-end Wan 2.2 benchmark
examples/offline_inference/text_to_video/
  wan2_2_sparse.py               # Wan 2.2 sparse attention example
```
9. Creating a plugin backend
External packages implement SparseAttentionBackend and register via entry_points:
`pyproject.toml`:

```toml
[project.entry-points."vllm_omni.sparse_attn"]
mybackend = "my_package.backend:MyBackendClass"
```

`my_package/backend.py`:

```python
from vllm_omni.diffusion.sparse_attn.base import (
    SparseAttentionBackend, SparseAttentionImpl, SparseMetadataBuilder,
)

class MyImpl(SparseAttentionImpl):
    def forward(self, query, key, value, metadata, output=None):
        # Input: (B, S, H, D); implement your kernel here
        ...

class MyBuilder(SparseMetadataBuilder):
    def plan(self, pattern_spec, module_type, num_qo_heads, ...):
        ...

class MyBackendClass(SparseAttentionBackend):
    @staticmethod
    def get_name(): return "mybackend"
    @staticmethod
    def get_impl_cls(): return MyImpl
    @staticmethod
    def get_metadata_builder_cls(): return MyBuilder
    @staticmethod
    def supported_patterns(): return {SparsePatternType.DYNAMIC_TOPK}
    @staticmethod
    def supported_module_types(): return {AttentionModuleType.DIT_SELF}
```

Usage:
```
pip install my-sparse-backend
vllm-omni serve Model --omni --sparse-attn-backend mybackend --sparse-attn-topk 0.5
```

10. Usage examples
CLI (serving):
```
vllm-omni serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni \
  --sparse-attn-backend spargeattn \
  --sparse-attn-topk 0.5 \
  --sparse-attn-schedule conservative
```

Offline inference:
```python
from vllm_omni.diffusion.data import DiffusionSparseAttnConfig
from vllm_omni.entrypoints.omni import Omni

omni = Omni(
    model="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
    sparse_attn=DiffusionSparseAttnConfig(
        backend="spargeattn",
        topk_ratio=0.5,
        schedule="conservative",
    ),
)
result = omni.generate({"prompt": "A cat playing piano"}, sampling_params)
```

Environment variable:
```
export DIFFUSION_SPARSE_ATTN_BACKEND=spargeattn
# Backend auto-discovered from installed entry_points
```

Full class path (development):

```
--sparse-attn-backend "my_dev_pkg.sparse.MyExperimentalBackend"
```

What does NOT change
- The AR/LLM backbone attention path (`vllm_omni/attention/`) is untouched. This RFC covers DiT modules only.
- Existing full-attention DiT models continue to work unchanged. When no sparse backend is installed, `enable_sparse_attention()` is a no-op and the existing `Attention` layer handles dense attention with full SP, ring attention, and platform dispatch.
- vLLM-Omni does not take a hard dependency on any sparse attention library. All backends are external packages discovered via entry_points.
Alternatives considered
- In-tree vendor backends: rejected. Tying kernel implementations to vLLM-Omni means every new kernel requires a core PR and breaks when upstream APIs change. The plugin architecture keeps core clean.
- Reusing vLLM's existing `AttentionBackend`: DiT modules do not use paged KV cache, causal masking, or chunked prefill. The existing attention system is tightly coupled to dense patterns with SP, ring attention, and platform dispatch. A separate `sparse_attn` module avoids contaminating the existing attention path.
- FlashInfer-only implementation: FlashInfer does not support SpargeAttn's dynamic top-k prediction or RainFusion's head classification, which are the primary accuracy-efficiency levers for video DiT workloads.
Benchmark Results
Environment
- GPU: NVIDIA A100 80GB PCIe
- Model: Wan-AI/Wan2.2-T2V-A14B-Diffusers (14B parameters)
- Backend: SpargeAttn (via the `sparge-vllm-omni` entry_point plugin)
End-to-end (480x832, 81 frames, 40 denoising steps)
| Run | Total Time | Per-step | Speedup |
|---|---|---|---|
| Dense FA3 (baseline) | 785s | 19.6s | 1.00x |
| SpargeAttn topk=0.7 | 763s | 19.1s | 1.03x |
| SpargeAttn topk=0.5 | 677s | 16.9s | 1.16x |
| SpargeAttn topk=0.3 | 590s | 14.7s | 1.33x |
Kernel-level micro-benchmark (attention-only, 40 heads × 128 head_dim)
| Config | SeqLen | Dense FA3 | Sparge topk=0.5 | Sparge topk=0.3 | Sparge topk=0.2 |
|---|---|---|---|---|---|
| 480p 17f | 6,630 | 10.8ms | 10.1ms (1.1x) | 7.6ms (1.4x) | 6.4ms (1.7x) |
| 480p 33f | 12,870 | 39.5ms | 31.5ms (1.3x) | 22.1ms (1.8x) | 17.1ms (2.3x) |
| 720p 33f | 29,700 | 206.1ms | 146.7ms (1.4x) | 95.6ms (2.2x) | 70.0ms (2.9x) |
Testing
98 tests (97 passed, 1 skipped when SpargeAttn not installed):
| Test File | Tests | Coverage |
|---|---|---|
| `test_protocol.py` | 32 | Pattern spec, config, registry, SparseAttention module |
| `test_config_cli.py` | 19 | CLI args, env var, priority resolution, end-to-end flow |
| `test_dit_adapter.py` | 25 | Schedules, adapter lifecycle, GPU forward, metadata caching |
| `test_wan22_sparse.py` | 7 | Wan 2.2 model integration, plugin registration |
| `test_wan22_gpu_integration.py` | 5+1 | FlashInfer GPU tests, SpargeAttn placeholder |
Feedback Period
Two weeks from the date this issue is opened.
Key questions for reviewers:
- Should the `contrib/` plugin packages ship in the vllm-omni repo, or move to separate repositories?
- Is the entry_point group name `vllm_omni.sparse_attn` appropriate, or should it be scoped differently?
- Should `DiTSparseAttentionAdapter` be exposed as a first-class public API, or remain internal?
- Should the `DenseFallbackBackend` be kept for test correctness comparisons, or removed entirely?
CC List.
@wtomin @ZJY0516 @hsliuustc0106 @jiangmengyu18 @gglorian @gcanlin
Any Other Things.
Related issues and PRs:
- Issue #1568: [RFC] Discussing the extension of attention backend (this RFC addresses all three problems raised there)
- PR #888: [Model] Add SpargeAttentionBackend for Wan 2.2 (migrated and generalized by Phase 2)
- Issue #886: [RFC] vLLM-Omni NPU 2026 Q1 Roadmap (Phase 4 unblocks the MindIE-SD sparse attention item)
- Issue #814: [RFC] Diffusion Models Features Supports Plan
- Issue #158: [RFC] Diffusion Acceleration API design
- Issue #243: [Diffusion][Attention] Sage attention backend (merged; unaffected by this RFC)
Related work and references:
- SpargeAttn: https://github.com/thu-ml/SpargeAttn (ICML 2025)
- FlashInfer sparse API: https://docs.flashinfer.ai/api/sparse.html
- RainFusion: https://arxiv.org/abs/2505.21036
- vLLM plugin architecture RFC: [RFC]: Enhancing vLLM Plugin Architecture vllm#19161