Optimize Qwen3 RoPE: precompute cos/sin cache for static rope_type #45748
Draft
RobTand wants to merge 1 commit into huggingface:main
Conversation
Precompute Qwen3 rotary cos/sin tables at module init for `rope_type="default"` and use `index_select` in `forward` instead of the per-call BMM path. Keep dynamic and long-position paths on the existing recompute implementation, and rebuild the non-persistent derived cache during weight initialization so meta/`from_pretrained` loading remains valid.
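A minimal sketch of the approach, with hypothetical names (`PrecomputedRotaryEmbedding`, `cos_cached`/`sin_cached`), an illustrative base, and a simplified fallback; the actual diff may differ:

```python
import torch
from torch import nn


class PrecomputedRotaryEmbedding(nn.Module):
    """Sketch of the static-RoPE cache described in this PR (not the actual diff)."""

    def __init__(self, head_dim: int, max_position_embeddings: int, base: float = 1e6):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Static cache: one (max_position_embeddings, head_dim) table each for
        # cos and sin, built once with torch.outer instead of a per-call BMM.
        positions = torch.arange(max_position_embeddings, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)          # (P, head_dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)           # (P, head_dim)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.LongTensor):
        if position_ids.max() < self.cos_cached.shape[0]:
            # Fast path: a gather into the precomputed tables.
            cos = self.cos_cached.index_select(0, position_ids.reshape(-1))
            sin = self.sin_cached.index_select(0, position_ids.reshape(-1))
        else:
            # Fallback for positions past the static cache; in the real diff the
            # existing recompute path runs here (simplified to torch.outer).
            freqs = torch.outer(position_ids.reshape(-1).float(), self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos, sin = emb.cos(), emb.sin()
        cos = cos.view(*position_ids.shape, -1)
        sin = sin.view(*position_ids.shape, -1)
        return cos.to(x.dtype), sin.to(x.dtype)


# Example: a single shared instance per model, queried per forward pass.
rope = PrecomputedRotaryEmbedding(head_dim=128, max_position_embeddings=4096)
cos, sin = rope(torch.empty(1, dtype=torch.float16), torch.arange(16)[None])
```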
Contributor
[For maintainers] Suggested jobs to run (before merge): run-slow: qwen3
Summary
Precompute the RoPE cos/sin tables for `rope_type == "default"` at module init and `index_select` them in `forward`, replacing the per-call `inv_freq @ position_ids` BMM. Dynamic rope types (NTK-aware scaling, etc.) continue to use the existing per-call path via `@dynamic_rope_update`.
Motivation

- The per-call BMM produces NaN under a `:4096:8` cuBLAS workspace config (`CUBLAS_WORKSPACE_CONFIG=:4096:8`); precompute side-steps it. (The underlying issue is arguably PyTorch/cuBLAS, but precompute is a useful workaround that also helps perf.)
- Dynamic rope types still need to recompute `inv_freq` for long-context scaling, so that path is kept as-is.
What changed

In `src/transformers/models/qwen3/modeling_qwen3.py`:

- `Qwen3RotaryEmbedding.__init__` now precomputes `cos_cached`/`sin_cached` of shape `max_position_embeddings` x `head_dim` when `rope_type == "default"`, using `torch.outer` instead of BMM.
- `Qwen3RotaryEmbedding.forward` indexes into the cache when it is available and the requested positions fit the static cache; it falls back to the existing BMM path for dynamic rope types and long-position extrapolation.
- The `@dynamic_rope_update` decorator is preserved so NTK paths still get their `inv_freq` updates.
- The non-persistent derived caches are rebuilt in `_init_weights` so meta/`from_pretrained` loading initializes them after `inv_freq` is restored (see the sketch after this list).
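A hedged sketch of that rebuild step, using a hypothetical `rebuild_rope_cache` helper (the PR wires the equivalent logic into `_init_weights`; the exact hook and names may differ):

```python
import torch
from torch import nn

def rebuild_rope_cache(rope: nn.Module) -> None:
    # Non-persistent buffers are absent from checkpoints, and under meta-device
    # loading cos_cached/sin_cached come back as shape-only meta tensors, so
    # recompute them from the restored inv_freq.
    max_positions = rope.cos_cached.shape[0]
    positions = torch.arange(max_positions, dtype=torch.float32, device=rope.inv_freq.device)
    freqs = torch.outer(positions, rope.inv_freq.float())
    emb = torch.cat((freqs, freqs), dim=-1)
    rope.cos_cached = emb.cos()  # assigning a tensor replaces the registered buffer
    rope.sin_cached = emb.sin()
```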
Math equivalence

Verified bit-identical to the existing forward via `torch.testing.assert_close(rtol=0, atol=0)` on representative position vectors. See the new test `test_rope_precomputed_cache_matches_legacy_path`.
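A minimal sketch of that check (the real test is `test_rope_precomputed_cache_matches_legacy_path` in the PR; the shapes and base value here are illustrative):

```python
import torch

def check_precompute_matches_bmm():
    head_dim, positions = 128, torch.arange(4096)
    inv_freq = 1.0 / (1e6 ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # Legacy per-call BMM path.
    freqs_bmm = (inv_freq[None, :, None] @ positions[None, None, :].float()).transpose(1, 2)[0]
    # Precomputed path.
    freqs_outer = torch.outer(positions.float(), inv_freq)

    # Exact equality is plausible because each element is a single fp32
    # multiply in both paths (the inner matmul dimension is 1).
    torch.testing.assert_close(freqs_outer, freqs_bmm, rtol=0, atol=0)
```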
Memory

Adds two FP32 buffers of shape `(max_position_embeddings, head_dim)` per `Qwen3RotaryEmbedding` instance. For Qwen3 with max_position=128k and head_dim=128, that's about 128 MB per RoPE module. There's typically one shared rotary instance per model, so the overhead is modest relative to model weights.
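The arithmetic behind the ~128 MB figure, assuming 128k means 131072 positions:

```python
# Two fp32 tables of shape (max_position_embeddings, head_dim).
max_pos, head_dim, fp32_bytes, n_tables = 131072, 128, 4, 2
total = max_pos * head_dim * fp32_bytes * n_tables
print(total / 2**20)  # 128.0 MiB
```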
Status

DRAFT - opened for early feedback on the approach before broader test/benchmark work, and before considering whether to mirror this change to other model families (Llama, Mistral, DeepSeek-V3, Gemma 3, etc., which all use a similar dynamic-RoPE forward pattern).
Test

```
pytest tests/models/qwen3/
```