[MoE Training] Add tensorwise FP8 grouped GEMM with fused Triton quantization kernels #4179
Open
alex-minooka wants to merge 3 commits into pytorch:main from
Conversation
Summary
Add tensorwise FP8 scaling path for MoE grouped GEMM training, complementing the existing rowwise path. Includes fused Triton quantization kernels that collapse ~15 sequential ATen kernel launches into 2 passes, eliminating HIP dispatch overhead on AMD ROCm.
+18.3% TPS vs upstream rowwise FP8 on DeepSeek-V3 16B MoE training (8x MI325X).
Motivation
The upstream rowwise FP8 path for MoE grouped GEMMs uses per-row (axiswise) scaling. Tensorwise scaling — one scale per tensor for activations, one scale per expert for weights — is a simpler quantization strategy that avoids the per-row scale computation overhead while maintaining training convergence for MoE architectures like DeepSeek-V3.
A naive PyTorch implementation of tensorwise quantization decomposes each quantize call into ~15 separate ATen kernels:
On AMD ROCm, each HIP kernel dispatch costs ~60μs. With 338 quantize calls per training iteration, the dispatch overhead alone adds ~300ms — nearly 20% of iteration time. Additionally, the BF16→F32 promotion kernel is the single most expensive per-call GPU cost (~350μs for a 98432×2048 tensor), yet is unnecessary since the Triton kernels can operate in BF16 and upcast to FP32 only in registers.
Approach
Tensorwise grouped GEMM (fp8_tensorwise_grouped_mm.py)
Autograd function implementing the tensorwise scaling contract: one scale per activation tensor, one scale per expert weight.
Three quantization functions are registered as custom_op to prevent Inductor from fusing them into Triton kernels that crash during autotuning on ROCm.

Fused Triton quantization kernels
Two-pass design for each quantization path:
Pass 1: nan_to_num + abs + block-local max, then atomic_max into a global/group amax buffer.
Pass 2: nan_to_num + scale + clamp + cast.

Both kernels stay in BF16 throughout, avoiding the large BF16→F32 copy. This collapses ~15 kernel launches into 2, saving ~720μs of dispatch overhead per call.
fp8_tensorwise_2d.py — flat 1D grid over M×K elements for 2D row-major tensors.
fp8_tensorwise_per_group.py — flat 1D grid with in-kernel group lookup via a 16-iteration scan of the (small) offsets array; no Python-side precomputation needed.
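The in-kernel group lookup can be illustrated in scalar Python. This hypothetical helper mirrors the fixed-length scan; the kernel performs the same accumulation per element with the offsets array held in registers:

```python
import torch

MAX_GROUPS = 16  # the kernel uses a fixed 16-iteration scan

def group_of(row: int, offsets: torch.Tensor) -> int:
    """Find which expert group a flat row index belongs to.
    offsets[i] is the exclusive end row of group i (cumsum of group sizes),
    padded to MAX_GROUPS with the final offset so extra iterations are no-ops."""
    padded = torch.full((MAX_GROUPS,), int(offsets[-1]), dtype=torch.int32)
    padded[: offsets.numel()] = offsets
    g = 0
    for i in range(MAX_GROUPS):          # fixed trip count, no data-dependent branch
        g += int(row >= int(padded[i]))  # counts groups that end at or before `row`
    return g
```

Because the trip count is fixed and the accumulation is branch-free, every element performs identical work, which keeps the Triton kernel free of warp divergence regardless of how tokens are distributed across experts.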
Config and dispatch (config.py, utils.py)
Added an FP8_TENSORWISE recipe and a scaling_granularity field to Float8TrainingOpConfig. When ScalingGranularity.TENSORWISE is configured, the dispatch path routes to the tensorwise grouped GEMM implementation.

Bug fixes (float8/)
Independent fixes for pre-existing bugs discovered during development:
config.py:292 — fix validation that checked cc1 twice instead of cc1 and cc2.
float8_ops.py:176-179 — fix == (comparison) used where = (assignment) was intended in the float8_transpose axiswise dim update.
float8_ops.py:483 — add an is_contiguous() guard before the unconditional .contiguous() in allgather_fp8.
float8_ops.py:543 — fix the wrong first argument to _assert_tensorwise_scale in index_put_fp8.

Performance
Benchmarked on 8x MI325X, DeepSeek-V3 16B MoE, seq_len=4096, batch=8/GPU, EP=8, torch.compile enabled:
Loss trajectory is healthy across all configurations, converging from 12.0 to ~4.9 over 100 steps.
Test plan
test/float8/test_float8_utils.py — 8 passed
test/float8/test_base.py — 64 passed (covers the transpose, allgather, and index_put fixes)
test/float8/test_compile.py — 4 passed (compile/eager numeric parity)
test/prototype/moe_training/test_fp8_grouped_mm.py — 8 passed (rowwise path unaffected)

🤖 Generated with Claude Code