
Optimize Qwen3 RoPE: precompute cos/sin cache for static rope_type #45748

Draft
RobTand wants to merge 1 commit into huggingface:main from RobTand:qwen3-rope-precompute-cache

Conversation


@RobTand RobTand commented May 2, 2026

Summary

Precompute the RoPE cos/sin tables for rope_type == "default" at module init and index into them with index_select in forward, replacing the per-call inv_freq @ position_ids BMM. Dynamic rope types (NTK-aware scaling, etc.) continue to use the existing per-call path via @dynamic_rope_update.
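
To make the idea concrete, here is a minimal, self-contained sketch of the two equivalent computations (illustrative shapes and variable names, not the actual diff): the existing per-call BMM versus a torch.outer table built once and gathered with index_select.

```python
import torch

head_dim, max_pos = 128, 4096
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(17, dtype=torch.long)[None, :]  # (batch=1, seq_len)

# Existing per-call path: batched matmul between inv_freq and the positions.
freqs_bmm = (inv_freq[None, :, None] @ position_ids[:, None, :].float()).transpose(1, 2)
cos_bmm = torch.cat((freqs_bmm, freqs_bmm), dim=-1).cos()

# Precomputed path: build the full table once, then gather the requested rows.
freqs_table = torch.outer(torch.arange(max_pos).float(), inv_freq)   # (max_pos, head_dim // 2)
cos_table = torch.cat((freqs_table, freqs_table), dim=-1).cos()      # (max_pos, head_dim)
cos_lookup = cos_table.index_select(0, position_ids.reshape(-1)).view(*position_ids.shape, -1)

# Both paths do the same single fp32 multiply per element, so the results are bit-identical.
torch.testing.assert_close(cos_bmm, cos_lookup, rtol=0, atol=0)
```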

Motivation

  1. Performance: eliminates a small per-call cuBLAS BMM and an FP32 autocast block on every forward. Most useful during decode, where every step pays the latency.
  2. Robustness: the BMM hits a cuBLAS deterministic-mode edge case with bf16 and the :4096:8 workspace config that produces NaNs; the precomputed cache side-steps it. (The underlying issue is arguably in PyTorch/cuBLAS, but precomputation is a useful workaround that also helps performance.)
  3. Standard pattern: older RoPE implementations always cached cos/sin; the dynamic recompute was added specifically to support changing inv_freq for long-context scaling.

What changed

  • src/transformers/models/qwen3/modeling_qwen3.py:
    • Qwen3RotaryEmbedding.__init__ now precomputes cos_cached/sin_cached of shape (max_position_embeddings, head_dim) when rope_type == "default", using torch.outer instead of a BMM.
    • Qwen3RotaryEmbedding.forward indexes into the cache when it is available and the requested positions fit within it; it falls back to the existing BMM path for dynamic rope types and long-position extrapolation.
    • The @dynamic_rope_update decorator is preserved so NTK paths still get their inv_freq updates.
    • The cache buffers are non-persistent derived buffers, matching existing rotary-buffer practice and avoiding checkpoint bloat. They are rebuilt in _init_weights so meta/from_pretrained loading initializes them after inv_freq is restored. A condensed sketch of the resulting module follows this list.
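
The sketch below shows the shape of the change under simplified plumbing; the real Qwen3RotaryEmbedding keeps the config wiring, attention_scaling, and the @dynamic_rope_update decorator, and its fallback branch is the existing BMM implementation rather than the recompute shown here.

```python
import torch
from torch import nn


class PrecomputedRotaryEmbedding(nn.Module):
    """Condensed sketch of the approach; not the exact PR diff."""

    def __init__(self, head_dim: int, max_position_embeddings: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Static ("default") rope_type: build the full table once with torch.outer.
        positions = torch.arange(max_position_embeddings, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)                  # (max_pos, head_dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)                   # (max_pos, head_dim)
        # Non-persistent derived buffers: never serialized, rebuilt at init time.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        if position_ids.max() < self.cos_cached.shape[0]:
            # Fast path: gather precomputed rows instead of a per-call BMM.
            flat = position_ids.reshape(-1)
            cos = self.cos_cached.index_select(0, flat).view(*position_ids.shape, -1)
            sin = self.sin_cached.index_select(0, flat).view(*position_ids.shape, -1)
        else:
            # Fallback for positions beyond the static cache: recompute on the fly
            # (the actual PR falls back to the existing BMM path here).
            freqs = torch.outer(position_ids.reshape(-1).float(), self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().view(*position_ids.shape, -1)
            sin = emb.sin().view(*position_ids.shape, -1)
        return cos.to(x.dtype), sin.to(x.dtype)
```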

Math equivalence

Verified bit-identical to the existing forward path via torch.testing.assert_close(rtol=0, atol=0) on representative position vectors; see the new test test_rope_precomputed_cache_matches_legacy_path.
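
For reference, a test along these lines could look like the sketch below. The test name comes from this PR; the body here is an assumed illustration against the public Qwen3Config / Qwen3RotaryEmbedding API, not the committed test.

```python
import torch
from transformers import Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding


def test_rope_precomputed_cache_matches_legacy_path():
    config = Qwen3Config(max_position_embeddings=512)
    rope = Qwen3RotaryEmbedding(config=config)
    position_ids = torch.arange(16)[None, :].expand(2, -1)
    x = torch.randn(2, 16, 8)  # forward only reads dtype/device from x

    cos, sin = rope(x, position_ids)

    # Legacy reference: the per-call BMM formulation of the current forward.
    freqs = (rope.inv_freq[None, :, None].float() @ position_ids[:, None, :].float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    torch.testing.assert_close(cos, emb.cos() * rope.attention_scaling, rtol=0, atol=0)
    torch.testing.assert_close(sin, emb.sin() * rope.attention_scaling, rtol=0, atol=0)
```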

Memory

Adds two FP32 buffers of shape (max_position_embeddings, head_dim) per Qwen3RotaryEmbedding instance. For Qwen3 with max_position_embeddings = 128k and head_dim = 128, that is about 128 MiB per RoPE module. There is typically one shared rotary instance per model, so the overhead is modest relative to the model weights.
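
Back-of-envelope check of that figure (assumed exact sizes, with "128k" taken as 131072):

```python
max_position_embeddings = 131_072   # "128k"
head_dim = 128
fp32_bytes = 4
cache_bytes = 2 * max_position_embeddings * head_dim * fp32_bytes  # cos + sin buffers
print(f"{cache_bytes / 2**20:.0f} MiB")  # -> 128 MiB per Qwen3RotaryEmbedding instance
```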

Status

DRAFT - opened for early feedback on the approach before broader test/benchmark work and before considering whether to mirror this change to other model families (Llama, Mistral, DeepSeek-V3, Gemma 3, etc., which all use a similar dynamic-RoPE forward pattern).

Tests

  • A new unit test verifies that the cached and legacy paths produce bit-identical cos/sin.
  • Existing Qwen3 tests pass: pytest tests/models/qwen3/.

Precompute Qwen3 rotary cos/sin tables at module init for rope_type="default" and use index_select in forward instead of the per-call BMM path.

Keep dynamic and long-position paths on the existing recompute implementation, and rebuild the non-persistent derived cache during weight initialization so meta/from_pretrained loading remains valid.

github-actions Bot commented May 2, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen3

