
[MoE Training] Add tensorwise FP8 grouped GEMM with fused Triton quantization kernels#4179

Open
alex-minooka wants to merge 3 commits into pytorch:main from alex-minooka:tensorwise-fp8ggemm

Conversation


alex-minooka (Contributor) commented Mar 26, 2026

Summary

Add tensorwise FP8 scaling path for MoE grouped GEMM training, complementing the existing rowwise path. Includes fused Triton quantization kernels that collapse ~15 sequential ATen kernel launches into 2 passes, eliminating HIP dispatch overhead on AMD ROCm.

+18.3% TPS vs upstream rowwise FP8 on DeepSeek-V3 16B MoE training (8x MI325X).

Motivation

The upstream rowwise FP8 path for MoE grouped GEMMs uses per-row (axiswise) scaling. Tensorwise scaling — one scale per tensor for activations, one scale per expert for weights — is a simpler quantization strategy that avoids the per-row scale computation overhead while maintaining training convergence for MoE architectures like DeepSeek-V3.

A naive PyTorch implementation of tensorwise quantization decomposes each quantize call into ~15 separate ATen kernels.

On AMD ROCm, each HIP kernel dispatch costs ~60μs. With 338 quantize calls per training iteration, the dispatch overhead alone adds ~300ms — nearly 20% of iteration time. Additionally, the BF16→F32 promotion kernel is the single most expensive per-call GPU cost (~350μs for a 98432×2048 tensor), yet is unnecessary since the Triton kernels can operate in BF16 and upcast to FP32 only in registers.

Approach

Tensorwise grouped GEMM (fp8_tensorwise_grouped_mm.py)

Autograd function implementing the tensorwise scaling contract:

  • Forward: one scale for all of A, one scale per expert for B_t
  • Backward (grad_A): one scale for grad_output, one scale per expert for B
  • Backward (grad_B): one scale per group for both grad_output and A
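The scale layout implied by this contract can be sketched as follows (function name, shapes, and the clamp epsilon are illustrative assumptions, not the PR's actual code):

```python
import torch

FP8_MAX = 448.0  # torch.float8_e4m3fn max

def forward_scales(A: torch.Tensor, B_t: torch.Tensor):
    # Forward contract (illustrative shapes):
    #   A:   (total_tokens, K) activations   -> a single scalar scale
    #   B_t: (num_experts, K, N) weights     -> one scale per expert
    a_scale = FP8_MAX / A.abs().amax().clamp(min=1e-12)
    b_scales = FP8_MAX / B_t.abs().amax(dim=(1, 2)).clamp(min=1e-12)
    return a_scale, b_scales
```

The backward passes follow the same pattern, swapping in grad_output for A and per-group reductions for the grad_B path.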

Three quantization functions are registered via custom_op to prevent inductor from fusing them into generated Triton kernels that crash during autotuning on ROCm.

Fused Triton quantization kernels

Two-pass design for each quantization path:

  • Pass 1 (amax): single sequential read of the input, fusing nan_to_num + abs + block-local max, then atomic_max into a global/group amax buffer
  • Pass 2 (quantize): single sequential read of the input, computing scale inline from the amax result, then writing FP8 output with fused nan_to_num + scale + clamp + cast

Both kernels stay in BF16 throughout, avoiding the large BF16→F32 copy. This collapses ~15 kernel launches into 2, saving ~720μs of dispatch overhead per call.

fp8_tensorwise_2d.py — flat 1D grid over M×K elements for 2D row-major tensors.

fp8_tensorwise_per_group.py — flat 1D grid with in-kernel group lookup via a 16-iteration scan of the (small) offsets array. No Python-side precomputation needed.
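The in-kernel lookup amounts to counting how many group-end offsets precede an element's row. An illustrative Python emulation (the kernel does this as a fixed-length scan over the offsets array rather than a data-dependent loop):

```python
def group_index(row: int, offsets: list[int]) -> int:
    # offsets[g] is the exclusive end row of group g, as in grouped-GEMM
    # layouts; the group of `row` is the count of boundaries at or below it.
    g = 0
    for end in offsets:  # small array: a bounded linear scan suffices
        if row >= end:
            g += 1
    return g
```

A bounded scan beats binary search here because the offsets array is tiny and a fixed trip count keeps the Triton kernel's control flow uniform across the warp.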

Config and dispatch (config.py, utils.py)

Added FP8_TENSORWISE recipe and scaling_granularity field to Float8TrainingOpConfig. When ScalingGranularity.TENSORWISE is configured, the dispatch path routes to the tensorwise grouped GEMM implementation.
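The dispatch shape reduces to roughly the following (an illustrative subset: the enum value, field default, and function names are assumptions, not the PR's exact definitions):

```python
from dataclasses import dataclass
from enum import Enum, auto

class ScalingGranularity(Enum):
    ROWWISE = auto()
    TENSORWISE = auto()

@dataclass
class Float8TrainingOpConfig:
    # Illustrative subset of the config; default is assumed.
    scaling_granularity: ScalingGranularity = ScalingGranularity.ROWWISE

def pick_grouped_mm(cfg: Float8TrainingOpConfig) -> str:
    # Route to the tensorwise implementation when configured, else rowwise.
    if cfg.scaling_granularity is ScalingGranularity.TENSORWISE:
        return "fp8_tensorwise_grouped_mm"
    return "fp8_rowwise_grouped_mm"
```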

Bug fixes (float8/)

Independent fixes for pre-existing bugs discovered during development:

  • config.py:292 — Fix validation checking cc1 twice instead of cc1 and cc2
  • float8_ops.py:176-179 — Fix use of == (comparison) where = (assignment) was intended in the float8_transpose axiswise dim update
  • float8_ops.py:483 — Add is_contiguous() guard before unconditional .contiguous() in allgather_fp8
  • float8_ops.py:543 — Fix wrong first argument to _assert_tensorwise_scale in index_put_fp8
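For example, the allgather guard reduces to the following (a simplified illustration, not the PR's exact code):

```python
import torch

def maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
    # .contiguous() already returns self for a contiguous tensor, but the
    # explicit guard skips the op dispatch entirely on the hot path.
    return t if t.is_contiguous() else t.contiguous()
```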

Performance

Benchmarked on 8x MI325X, DeepSeek-V3 16B MoE, seq_len=4096, batch=8/GPU, 8 EP, Torch Compile Enabled:

| Configuration            | Steady TPS | vs Rowwise |
|--------------------------|------------|------------|
| BF16 baseline (no FP8)   | 14,233     | +21.6%     |
| FP8 rowwise (upstream)   | 11,705     | (baseline) |
| FP8 tensorwise (this PR) | 13,847     | +18.3%     |

Loss trajectory is healthy across all configurations, converging from 12.0 to ~4.9 over 100 steps.

Test plan

  • Triton quantize kernels produce bitwise-identical FP8 output vs PyTorch ATen fallback
  • test/float8/test_float8_utils.py — 8 passed
  • test/float8/test_base.py — 64 passed (covers transpose, allgather, index_put fixes)
  • test/float8/test_compile.py — 4 passed (compile/eager numeric parity)
  • test/prototype/moe_training/test_fp8_grouped_mm.py — 8 passed (rowwise path unaffected)
  • End-to-end training convergence verified across all FP8 configurations

🤖 Generated with Claude Code

pytorch-bot commented Mar 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4179

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label, Mar 26, 2026
alex-minooka marked this pull request as ready for review, March 27, 2026 18:01