[RFC]: Plugin-based Sparse Attention Interface for DiT Modules #2233

@zzhang-fr

Description
Plugin-based Sparse Attention Interface for DiT Modules

Motivation

vLLM-Omni currently runs DiT module attention (self-attention among latent patches, and cross-attention between patches and condition embeddings) with a standard full-attention kernel. As model resolution and video length scale up, the quadratic cost of full attention becomes the dominant bottleneck: a 720p video DiT with 8192 patch tokens spends more than 60% of step latency in attention alone.

Several sparse attention libraries have emerged that can cut this cost substantially:

  • SpargeAttn (thu-ml/SpargeAttn, ICML 2025): training-free, block-level top-k prediction. Demonstrated 2–3× speedup on FLUX and HunyuanVideo at <0.1% quality degradation.
  • FlashInfer sparse kernels: BlockSparseAttentionWrapper with BSR-format masks, already present as an optional vLLM dependency. Supports per-block and sub-block masks, CUDA-graph-safe.
  • RainFusion (arXiv 2505.21036): adaptive head-level classification combined with spatiotemporal token permutation, targeting video DiT workloads.

These libraries have incompatible calling conventions, and none accounts for DiT-specific structure: fixed sequence lengths, multi-step denoising loops, spatial patch layout, and joint MMDiT attention patterns.

This RFC proposes a plugin-based architecture where vLLM-Omni provides only a unified interface and registry. All sparse attention backends — including FlashInfer and SpargeAttn — live in external pip-installable packages that register via Python entry points.

Implementation Status

This RFC has been implemented and benchmarked; see PR #2231.

Proposed Change

Overview

The design is organized around three layers. vLLM-Omni owns only the protocol and registry. All backends are external plugins.

┌─────────────────────────────────────────────────────────────────────┐
│                   External plugin packages                          │
│   flashinfer-vllm-omni  ·  sparge-vllm-omni  ·  custom-backend     │
│   (pip install, auto-register via entry_points)                     │
└────────────────────────────────┬────────────────────────────────────┘
                                 │ implements SparseAttentionBackend ABC
┌────────────────────────────────▼────────────────────────────────────┐
│                  vllm-omni core protocol layer                      │
│   SparsePatternSpec · SparseAttentionBackend (ABC)                  │
│   SparseAttentionImpl (ABC) · SparseMetadataBuilder (ABC)           │
│   SparseAttentionMetadata · DiffusionSparseAttnConfig               │
│   register_sparse_attn_backend() · get_sparse_attn_backend()        │
│   SparseAttention(nn.Module) · DiTSparseAttentionAdapter            │
└─────────────────────────────────────────────────────────────────────┘
                                 │ integrates with
┌────────────────────────────────▼────────────────────────────────────┐
│                  vllm-omni model layer                              │
│   WanTransformer3DModel.enable_sparse_attention()                   │
│   Wan22Pipeline (wires od_config.sparse_attn → transformer)         │
│   CLI: --sparse-attn-backend / --sparse-attn-topk / --schedule      │
└─────────────────────────────────────────────────────────────────────┘

1. Sparse pattern descriptor

SparsePatternSpec is a backend-agnostic dataclass describing the sparsity pattern. Each backend's metadata builder translates it to its native format.

@dataclass
class SparsePatternSpec:
    pattern_type: SparsePatternType   # BLOCK_SPARSE | SLIDING_WINDOW |
                                      # SPATIAL_AWARE | DYNAMIC_TOPK |
                                      # ARROW | COMPOSITE | DENSE

    # Block-sparse fields
    block_size_q: int = 128
    block_size_kv: int = 64
    block_mask: Optional[torch.Tensor] = None      # (H, Qb, KVb)
    bsr_indptr: Optional[torch.Tensor] = None
    bsr_indices: Optional[torch.Tensor] = None
    sub_block_mask: Optional[torch.Tensor] = None   # (nnz, Bq, Bkv)

    # Sliding window
    window_size: Optional[int] = None

    # Spatial-aware (DiT patch grid)
    spatial_layout: Optional[SpatialLayout] = None
    spatial_radius: Optional[int] = None

    # Dynamic top-k
    topk_ratio: Optional[float] = None

    # MMDiT arrow attention
    num_image_tokens: Optional[int] = None
    num_text_tokens: Optional[int] = None
    i2i_pattern: Optional[SparsePatternSpec] = None
    cross_dense: bool = True

    # Denoising step context (set by DiTSparseAttentionAdapter)
    current_step: Optional[int] = None
    total_steps: Optional[int] = None

SpatialLayout captures the patch grid geometry:

@dataclass(frozen=True)
class SpatialLayout:
    height: int
    width: int
    frames: int = 1
    patch_size: int = 2
    ordering: Literal["raster", "hilbert", "z_curve"] = "raster"
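
To make the spatial-aware pattern concrete, here is a pure-Python sketch of how a backend's metadata builder might interpret spatial_radius on a raster-ordered single-frame grid. The neighborhood definition (Chebyshev distance in patch units) and the helper name spatial_neighbors are assumptions for illustration; a real builder would emit a block mask or BSR arrays instead of index lists:

```python
def spatial_neighbors(height: int, width: int, radius: int):
    """For each raster-ordered patch token, the token indices within
    Chebyshev distance `radius` on the (height x width) patch grid."""
    idx = lambda r, c: r * width + c
    neighbors = []
    for r in range(height):
        for c in range(width):
            near = [
                idx(rr, cc)
                for rr in range(max(0, r - radius), min(height, r + radius + 1))
                for cc in range(max(0, c - radius), min(width, c + radius + 1))
            ]
            neighbors.append(near)
    return neighbors

n = spatial_neighbors(4, 4, 1)
# An interior token (row 1, col 1 -> index 5) sees a full 3x3 neighborhood:
# len(n[5]) == 9; a corner token sees only 4.
```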

File: vllm_omni/diffusion/sparse_attn/pattern.py

2. Backend protocol (ABC)

Each backend is a factory class with static methods. The plan → forward lifecycle maps to FlashInfer's plan() / run() and is compatible with SpargeAttn's stateless functional calls.

class SparseAttentionBackend(ABC):
    @staticmethod
    @abstractmethod
    def get_name() -> str: ...

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type[SparseAttentionImpl]: ...

    @staticmethod
    @abstractmethod
    def get_metadata_builder_cls() -> Type[SparseMetadataBuilder]: ...

    @staticmethod
    @abstractmethod
    def supported_patterns() -> set[SparsePatternType]: ...

    @staticmethod
    @abstractmethod
    def supported_module_types() -> set[AttentionModuleType]: ...

    @staticmethod
    def validate_config(config: "DiffusionSparseAttnConfig") -> None:
        pass  # default: no-op


class SparseAttentionImpl(ABC):
    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,   # (B, S, H, D) layout
        key: torch.Tensor,
        value: torch.Tensor,
        metadata: SparseAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor: ...


class SparseMetadataBuilder(ABC):
    @abstractmethod
    def plan(
        self,
        pattern_spec: SparsePatternSpec,
        module_type: AttentionModuleType,
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim: int,
        seq_len_q: int,
        seq_len_kv: int,
        batch_size: int = 1,
        dtype: torch.dtype = torch.float16,
        device: torch.device = torch.device("cuda"),
    ) -> SparseAttentionMetadata: ...

    def update_step(
        self, metadata: SparseAttentionMetadata,
        step: int, total_steps: int,
    ) -> SparseAttentionMetadata:
        return metadata   # default: no-op

File: vllm_omni/diffusion/sparse_attn/base.py

3. Plugin-based backend registry

Backends are discovered via three mechanisms, in priority order:

  1. Entry points — external packages declare in pyproject.toml:

    [project.entry-points."vllm_omni.sparse_attn"]
    flashinfer = "flashinfer_vllm_omni.backend:FlashInferSparseBackend"
  2. Explicit registration — call at import time:

    from vllm_omni.diffusion.sparse_attn.registry import register_sparse_attn_backend
    register_sparse_attn_backend("mybackend", "my_pkg.module:MyBackend")
  3. Full class path — pass directly in config/CLI:

    --sparse-attn-backend "my_pkg.module.MyBackend"

The registry has no hardcoded external backends. The only built-in is DenseFallbackBackend for test correctness comparisons.

def get_sparse_attn_backend(
    sparse_attn_config: Optional[DiffusionSparseAttnConfig],
) -> Optional[Type[SparseAttentionBackend]]:
    """Resolution priority:
    1. DIFFUSION_SPARSE_ATTN_BACKEND env var
    2. sparse_attn_config.backend
    3. Auto-select from installed entry_point backends
    4. None (use default dense Attention layer)
    """

File: vllm_omni/diffusion/sparse_attn/registry.py

4. Configuration

DiffusionSparseAttnConfig mirrors the existing DiffusionCacheConfig pattern:

@dataclass
class DiffusionSparseAttnConfig:
    backend: str = "auto"
    pattern_type: str = "dynamic_topk"
    topk_ratio: float = 0.5
    block_size_q: int = 128
    block_size_kv: int = 64
    window_size: int | None = None
    spatial_radius: int | None = None
    schedule: str = "constant"   # "constant" | "conservative" | "aggressive"

Integrated into OmniDiffusionConfig:

class OmniDiffusionConfig:
    sparse_attn_backend: str | None = None   # shorthand
    sparse_attn: DiffusionSparseAttnConfig | dict | None = None

CLI arguments (in vllm_omni/entrypoints/cli/serve.py):

--sparse-attn-backend   Backend name or full class path
--sparse-attn-topk      Top-k ratio (0-1)
--sparse-attn-schedule  Sparsity schedule: constant | conservative | aggressive

Environment variable fallback: DIFFUSION_SPARSE_ATTN_BACKEND
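
The resolution priority from the registry docstring can be sketched as a small pure function (resolve_backend_name is a hypothetical helper, and "auto" here picks the first discovered backend in sorted order, which is an assumption about the tie-break):

```python
import os

def resolve_backend_name(config_backend, discovered):
    """Priority: env var > explicit config > auto-select > None (dense)."""
    env = os.environ.get("DIFFUSION_SPARSE_ATTN_BACKEND")
    if env:
        return env
    if config_backend not in (None, "auto"):
        return config_backend
    return next(iter(sorted(discovered)), None)

os.environ.pop("DIFFUSION_SPARSE_ATTN_BACKEND", None)
print(resolve_backend_name("auto", {"flashinfer"}))  # → flashinfer
```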

Files: vllm_omni/diffusion/data.py, vllm_omni/entrypoints/cli/serve.py, vllm_omni/engine/async_omni_engine.py

5. DiT-specific adapter

DiTSparseAttentionAdapter manages denoising-step-aware sparsity scheduling:

class DiTSparseAttentionAdapter:
    def __init__(
        self,
        backend_cls: Type[SparseAttentionBackend],
        module_type: AttentionModuleType,
        num_heads: int, num_kv_heads: int, head_dim: int,
        sparse_attn_config: DiffusionSparseAttnConfig,
    ): ...

    def begin_step(self, step: int, total_steps: int) -> None: ...
    def forward(self, query, key, value) -> torch.Tensor: ...
    def reset(self) -> None: ...

Built-in sparsity schedules:

Schedule       Behavior
constant       Uses topk_ratio from config at every step
conservative   Full attention early (t < 0.2), then ramps to 70% sparsity
aggressive     80% sparsity early, decreasing to 50%

Custom schedules can be registered via register_sparsity_schedule(name, fn).

File: vllm_omni/diffusion/sparse_attn/dit_adapter.py

6. Drop-in SparseAttention module

class SparseAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        num_kv_heads: Optional[int] = None,
        module_type: AttentionModuleType = AttentionModuleType.DIT_SELF,
        sparse_attn_config: Optional[DiffusionSparseAttnConfig] = None,
        dtype: torch.dtype = torch.float16,
    ): ...

    def forward(self, query, key, value) -> torch.Tensor: ...

Metadata is cached by (batch_size, seq_len_q, seq_len_kv, device) key to avoid re-planning across layers with the same shape.
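
Since every DiT block sees the same (batch, seq_len) shapes within a step, this caching collapses per-layer planning to a single plan() call. A simplified stand-in (PlanCache is illustrative, not the module's actual implementation):

```python
class PlanCache:
    """Cache plan() results keyed by shape so that layers sharing a
    shape trigger exactly one planning call."""

    def __init__(self, plan_fn):
        self._plan_fn = plan_fn
        self._cache = {}
        self.misses = 0

    def get(self, batch_size, seq_len_q, seq_len_kv, device="cuda"):
        key = (batch_size, seq_len_q, seq_len_kv, str(device))
        if key not in self._cache:
            self.misses += 1
            self._cache[key] = self._plan_fn(*key)
        return self._cache[key]

cache = PlanCache(lambda *key: {"planned_for": key})
for _ in range(30):              # e.g. 30 identical transformer blocks
    cache.get(1, 12870, 12870)
print(cache.misses)  # → 1
```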

File: vllm_omni/diffusion/sparse_attn/attention.py

7. Wan 2.2 model integration

WanTransformer3DModel.enable_sparse_attention() replaces self-attention layers post-construction:

def enable_sparse_attention(self, sparse_attn_config):
    backend_cls = get_sparse_attn_backend(sparse_attn_config)
    if backend_cls is None:
        return  # no-op: use default dense Attention layer

    for block in self.blocks:
        block.attn1.sparse_attn = SparseAttention(
            num_heads=..., head_dim=...,
            sparse_attn_config=sparse_attn_config,
        )

WanSelfAttention.forward() dispatches:

if self.sparse_attn is not None and attn_mask is None:
    output = self.sparse_attn(query, key, value)
else:
    output = self.attn(query, key, value, attn_metadata)  # existing dense

When attn_mask is present (e.g., SP auto-padding), sparse attention is bypassed with a logged warning.

The pipeline calls enable_sparse_attention() after transformer construction:

# In Wan22Pipeline.__init__():
if od_config.sparse_attn is not None:
    self.transformer.enable_sparse_attention(od_config.sparse_attn)

Files: vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py, vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py

8. File layout

vllm_omni/diffusion/
  sparse_attn/
    __init__.py          # Public API exports
    pattern.py           # SparsePatternSpec, SparsePatternType, SpatialLayout
    base.py              # SparseAttentionBackend, SparseAttentionImpl,
                         # SparseMetadataBuilder, SparseAttentionMetadata,
                         # AttentionModuleType
    registry.py          # Plugin registry: entry_points + register + resolve
    attention.py         # SparseAttention(nn.Module)
    dit_adapter.py       # DiTSparseAttentionAdapter + sparsity schedules
    backends/
      dense.py           # DenseFallbackBackend (test-only correctness oracle)
  data.py                # DiffusionSparseAttnConfig (added)
  models/wan2_2/
    wan2_2_transformer.py  # enable_sparse_attention() (added)
    pipeline_wan2_2.py     # Pipeline wiring (added)

contrib/                   # External plugin packages (reference implementations)
  flashinfer-vllm-omni/    # FlashInfer BSR block-sparse backend
  sparge-vllm-omni/        # SpargeAttn dynamic top-k backend

tests/diffusion/sparse_attn/
  test_protocol.py         # ABC, config, registry tests
  test_config_cli.py       # CLI wiring, env var, priority resolution
  test_dit_adapter.py      # Schedule functions, adapter lifecycle
  test_wan22_sparse.py     # Wan 2.2 model integration (CPU, mocked)
  test_wan22_gpu_integration.py  # GPU tests with FlashInfer plugin

benchmarks/
  sparse_attn_benchmark.py     # Kernel-level micro-benchmark
  wan22_sparse_benchmark.py    # End-to-end Wan 2.2 benchmark

examples/offline_inference/text_to_video/
  wan2_2_sparse.py             # Wan 2.2 sparse attention example

9. Creating a plugin backend

External packages implement SparseAttentionBackend and register via entry_points:

pyproject.toml:

[project.entry-points."vllm_omni.sparse_attn"]
mybackend = "my_package.backend:MyBackendClass"

my_package/backend.py:

from vllm_omni.diffusion.sparse_attn.base import (
    SparseAttentionBackend, SparseAttentionImpl, SparseMetadataBuilder,
)

class MyImpl(SparseAttentionImpl):
    def forward(self, query, key, value, metadata, output=None):
        # Input: (B, S, H, D), implement your kernel here
        ...

class MyBuilder(SparseMetadataBuilder):
    def plan(self, pattern_spec, module_type, num_qo_heads, ...):
        ...

class MyBackendClass(SparseAttentionBackend):
    @staticmethod
    def get_name(): return "mybackend"
    @staticmethod
    def get_impl_cls(): return MyImpl
    @staticmethod
    def get_metadata_builder_cls(): return MyBuilder
    @staticmethod
    def supported_patterns(): return {SparsePatternType.DYNAMIC_TOPK}
    @staticmethod
    def supported_module_types(): return {AttentionModuleType.DIT_SELF}

Usage:

pip install my-sparse-backend
vllm-omni serve Model --omni --sparse-attn-backend mybackend --sparse-attn-topk 0.5

10. Usage examples

CLI (serving):

vllm-omni serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni \
    --sparse-attn-backend spargeattn \
    --sparse-attn-topk 0.5 \
    --sparse-attn-schedule conservative

Offline inference:

from vllm_omni.diffusion.data import DiffusionSparseAttnConfig
from vllm_omni.entrypoints.omni import Omni

omni = Omni(
    model="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
    sparse_attn=DiffusionSparseAttnConfig(
        backend="spargeattn",
        topk_ratio=0.5,
        schedule="conservative",
    ),
)
result = omni.generate({"prompt": "A cat playing piano"}, sampling_params)

Environment variable:

export DIFFUSION_SPARSE_ATTN_BACKEND=spargeattn
# Backend auto-discovered from installed entry_points

Full class path (development):

--sparse-attn-backend "my_dev_pkg.sparse.MyExperimentalBackend"

What does NOT change

  • The AR/LLM backbone attention path (vllm_omni/attention/) is untouched. This RFC covers DiT modules only.
  • Existing full-attention DiT models continue to work unchanged. When no sparse backend is installed, enable_sparse_attention() is a no-op and the existing Attention layer handles dense attention with full SP, ring attention, and platform dispatch.
  • vLLM-Omni does not take a hard dependency on any sparse attention library. All backends are external packages discovered via entry_points.

Alternatives considered

  1. In-tree vendor backends: rejected. Tying kernel implementations to vLLM-Omni means every new kernel requires a core PR and breaks when upstream APIs change. Plugin architecture keeps core clean.
  2. Reusing vLLM's existing AttentionBackend: DiT modules do not use paged KV cache, causal masking, or chunked prefill. The existing attention system is tightly coupled to dense patterns with SP, ring attention, and platform dispatch. A separate sparse_attn module avoids contaminating the existing attention path.
  3. FlashInfer-only implementation: FlashInfer does not support SpargeAttn's dynamic top-k prediction or RainFusion's head classification, which are the primary accuracy-efficiency levers for video DiT workloads.

Benchmark Results

Environment

  • GPU: NVIDIA A100 80GB PCIe
  • Model: Wan-AI/Wan2.2-T2V-A14B-Diffusers (14B parameters)
  • Backend: SpargeAttn (via sparge-vllm-omni entry_point plugin)

End-to-end (480x832, 81 frames, 40 denoising steps)

Run                    Total Time   Per-step   Speedup
Dense FA3 (baseline)   785s         19.6s      1.00x
SpargeAttn topk=0.7    763s         19.1s      1.03x
SpargeAttn topk=0.5    677s         16.9s      1.16x
SpargeAttn topk=0.3    590s         14.7s      1.33x

Kernel-level micro-benchmark (attention-only, 40 heads × 128 head_dim)

Config     SeqLen   Dense FA3   Sparge topk=0.5   Sparge topk=0.3   Sparge topk=0.2
480p 17f   6,630    10.8ms      10.1ms (1.1x)     7.6ms (1.4x)      6.4ms (1.7x)
480p 33f   12,870   39.5ms      31.5ms (1.3x)     22.1ms (1.8x)     17.1ms (2.3x)
720p 33f   29,700   206.1ms     146.7ms (1.4x)    95.6ms (2.2x)     70.0ms (2.9x)

Testing

98 tests (97 passed, 1 skipped when SpargeAttn not installed):

Test File                       Tests   Coverage
test_protocol.py                32      Pattern spec, config, registry, SparseAttention module
test_config_cli.py              19      CLI args, env var, priority resolution, end-to-end flow
test_dit_adapter.py             25      Schedules, adapter lifecycle, GPU forward, metadata caching
test_wan22_sparse.py            7       Wan 2.2 model integration, plugin registration
test_wan22_gpu_integration.py   5+1     FlashInfer GPU tests, SpargeAttn placeholder

Feedback Period

Two weeks from the date this issue is opened.

Key questions for reviewers:

  1. Should the contrib/ plugin packages ship in the vllm-omni repo, or move to separate repositories?
  2. Is the entry_point group name vllm_omni.sparse_attn appropriate, or should it be scoped differently?
  3. Should DiTSparseAttentionAdapter be exposed as a first-class public API, or remain internal?
  4. Should the DenseFallbackBackend be kept for test correctness comparisons, or removed entirely?

CC List.

@wtomin @ZJY0516 @hsliuustc0106 @jiangmengyu18 @gglorian @gcanlin
