
[mxfp8 training] update triton_to_mxfp8_dim0 nan handling#4201

Open
danielvegamyhre wants to merge 1 commit into main from danielvegamyhre/stack/162

Conversation

@danielvegamyhre
Contributor

@danielvegamyhre danielvegamyhre commented Mar 30, 2026

Summary

  • Received reports of NaN loss during MXFP8 training that resolved when opting out of the Triton kernel for dim0 quantization (triton_to_mxfp8_dim0)
  • Added unit tests covering special values (NaN, inf, -inf, subnormals, extremely large/small normal values, etc.) to find discrepancies between the torch reference implementation and the Triton kernel
  • In short, I became suspicious that the torch reference implementation was also mishandling certain cases. To get Triton to match it, I had to do special sweeps for NaN, then inf, then -inf (in that order), which killed performance — and the CUDA code needs none of this (suspicious). So I updated both the torch and Triton implementations to follow the same TE RCEIL logic, which requires no special handling.
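The RCEIL behavior referenced above can be sketched in plain torch. This is a hypothetical illustration based on the description, not the PR's actual kernel code; the name `to_e8m0_rceil` and the use of the float8_e4m3 max (448.0) are assumptions:

```python
import torch

F8E4M3_MAX = 448.0  # largest finite magnitude in float8_e4m3fn

def to_e8m0_rceil(max_abs: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: choose the per-block scale as
    # 2^ceil(log2(max_abs / dtype_max)) and store it as a biased e8m0
    # exponent, rounding the exponent up ("RCEIL") so scaled values
    # never overflow the fp8 range.
    descale = max_abs / F8E4M3_MAX
    exponent = torch.ceil(torch.log2(descale))
    # clamp to the representable e8m0 exponent range [-127, 127];
    # real implementations also route inf/NaN descales to the saturated
    # encodings (254 / 255), omitted here for brevity
    exponent = torch.clamp(exponent, min=-127.0, max=127.0)
    return (exponent + 127.0).to(torch.uint8)  # biased e8m0 encoding

print(to_e8m0_rceil(torch.tensor([448.0, 100.0, 1.0])))
# tensor([127, 125, 119], dtype=torch.uint8)
```

Because the exponent is rounded up, quantized values can only shrink relative to the fp8 range, which is why this mode needs no per-value special-case sweeps for overflow.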

Changes

Benchmarks

(torch) dev@gpu-dev-8951ebdf:~/ao$ CUDA_VISIBLE_DEVICES=1 PYTHONPATH=/home/$USER/ao:$PYTHONPATH python benchmarks/mx_formats/cast_bench.py --mode dim0_mxfp8_triton_rceil --M 32768 --K 7168
M 32768 K 7168 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.11.0+cu130
triton version: 3.6.0
mode: dim0_mxfp8_triton_rceil
time_us 123.9359974861145
mem_bw_gbps 5744.76438195262

danielvegamyhre added a commit that referenced this pull request Mar 30, 2026
…rch reference

stack-info: PR: #4201, branch: danielvegamyhre/stack/162
@danielvegamyhre force-pushed the danielvegamyhre/stack/162 branch from f0b651a to 74a1db5 on March 30, 2026 at 20:14
@pytorch-bot

pytorch-bot bot commented Mar 30, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4201

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 8 Unrelated Failures

As of commit 82e7fc9 with merge base 3ad1067:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors must sign the CLA before a PR can be reviewed) Mar 30, 2026
@danielvegamyhre added the mx, module: training, quantize_ api, and training flow labels Mar 30, 2026
@danielvegamyhre marked this pull request as draft March 30, 2026 20:24
@danielvegamyhre
Contributor Author

danielvegamyhre commented Mar 30, 2026

edit: nevermind, managed to recover performance

@vkuzo
Contributor

vkuzo commented Mar 31, 2026

@claude review this PR

@claude

claude bot commented Mar 31, 2026

Claude finished @vkuzo's task in 2m 16s


Code Review: [mxfp8 training] update triton_to_mxfp8_dim0 nan handling

  • Read changed files and understand the diff
  • Review test changes (test_kernels.py)
  • Review kernel changes (kernels.py)
  • Review mx_tensor.py changes
  • Post review feedback

Overall this is a well-motivated change that aligns both the Triton kernel and the torch reference implementation with TransformerEngine's CUDA behavior for handling special values (NaN, inf) during MXFP8 quantization. The approach of matching TE's float_to_e8m0 and exp2f_rcp semantics is sound. A few items to consider:

Issues

1. Dim1 kernel still uses division instead of multiplication (kernels.py:716)

The dim0 kernel (to_mxfp8_dim0_kernel) was updated to use x_block_r * descale_broadcast (line 823), but the dim1 kernel (to_mxfp8_dim1_kernel) at line 716 still uses division:

col_normalized_t_r = x_block_t_r / col_scale_r[:, None]

Since _triton_calculate_scale_rceil now returns a reciprocal (descale_fp) and _triton_calculate_scale_floor was also changed to return a reciprocal (line 594: descale_fp = 1.0 / scale_fp), the dim1 kernel should be using multiplication too. Dividing by a reciprocal would produce x * scale instead of x / scale, which would give wrong results. This looks like a bug — the dim1 kernel's scaling is now inverted for both RCEIL and FLOOR modes.
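A minimal numeric repro of the inversion described above (a hypothetical standalone example, not the actual kernel code):

```python
import torch

# If a helper switches from returning the scale to returning its
# reciprocal (descale = 1/scale), a caller that still divides ends up
# multiplying by the scale instead of dividing by it.
x = torch.tensor([8.0])
scale = torch.tensor([4.0])
descale = 1.0 / scale

correct = x * descale  # intended x / scale -> 2.0
buggy = x / descale    # actually x * scale -> 32.0, off by scale**2
print(correct.item(), buggy.item())  # 2.0 32.0
```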


2. Triton tl.max NaN propagation vs. torch reference NaN-ignoring (kernels.py:483)

The torch reference (mx_tensor.py:216-219) now explicitly ignores NaN when computing max_abs:

abs_vals = torch.abs(data_hp)
max_abs = torch.amax(
    torch.where(torch.isnan(abs_vals), float("-inf"), abs_vals), -1
).unsqueeze(-1)

But the Triton kernel at line 483 uses bare tl.max(x, axis=axis) where x is tl.abs(x_block_r). Since abs(NaN) = NaN and tl.max will propagate NaN, the Triton kernel will produce a NaN max_abs when any element in the block is NaN. This means scale_input will be NaN, is_nan will be true, and the scale will be set to 0xFF (NaN).

In the torch reference, NaN elements are ignored in the max, so the scale is computed from the non-NaN elements, and then individual NaN elements are preserved post-scaling (line 173).

These two paths produce different results for blocks containing a mix of NaN and non-NaN values. The tests pass because test_triton_mxfp8_dim0_special_values only checks that NaN patterns and finite values match between torch and triton — but it's comparing the triton kernel output against the torch reference which calls to_mx() (which now ignores NaN in max). If a block has [nan, 100.0, 1.0, ...], the torch ref would compute scale from 100.0 and quantize the non-NaN values normally, while the triton kernel would set the scale to NaN (0xFF), making all outputs NaN.

This might be intentional if the expectation is that NaN never appears alongside valid data in practice (i.e., either the whole block is NaN or none of it is), but it's worth documenting this discrepancy or adding a test case that exercises a mixed NaN/non-NaN block where the behavior difference would be visible.
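The divergence can be demonstrated with a small standalone check, using plain torch stand-ins for both reductions (the Triton side is emulated with a NaN-propagating `amax`, per the behavior the review attributes to `tl.max`):

```python
import torch

block = torch.tensor([float("nan"), 100.0, 1.0, 2.0])
abs_vals = block.abs()

# torch reference style: mask NaN to -inf before the reduction,
# so the block max comes from the non-NaN elements
neg_inf = torch.tensor(float("-inf"))
nan_ignoring = torch.amax(torch.where(torch.isnan(abs_vals), neg_inf, abs_vals))

# tl.max-style: torch.amax propagates NaN, so one NaN poisons the block max
nan_propagating = torch.amax(abs_vals)

print(nan_ignoring.item())                  # 100.0
print(torch.isnan(nan_propagating).item())  # True
```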

3. torch.clamp behavior with NaN (mx_tensor.py:175)

After line 173 re-introduces NaN for NaN input elements, line 175 does torch.clamp(data_lp, min=-max_pos, max=max_pos). In PyTorch, torch.clamp propagates NaN (NaN stays NaN), so this is fine. However, this is a subtle invariant worth a brief comment, since the CUDA __saturatef-style clamping typically maps NaN to 0 in some contexts.
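A quick standalone check of the clamp invariant noted above (illustrative values only, using the e4m3 max of 448):

```python
import torch

x = torch.tensor([float("nan"), 1e9, -1e9, 3.0])
y = torch.clamp(x, min=-448.0, max=448.0)
# torch.clamp saturates the finite out-of-range values to +/-448
# but leaves NaN as NaN, so NaNs re-introduced before the clamp survive it
print(y)
```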

Minor Nits

4. Typo (kernels.py:807)

# Calcculate the reciprocal fp32 scale

Should be "Calculate".

5. All-NaN block edge case (mx_tensor.py:217-219)

If every element in a block is NaN, torch.where(torch.isnan(abs_vals), float("-inf"), abs_vals) produces all -inf, and torch.amax returns -inf. Then descale = -inf / max_pos = -inf, which hits the torch.isinf(descale) branch and sets exponent = 254. This gives rcp_fp32 = 2^-127, and all values are nan * 2^-127 = nan, which seems reasonable. But the scale being 254 (inf) rather than 255 (NaN) might be surprising — TE's CUDA code would get NaN from __hmax on all-NaN input and produce scale 255. This is a minor inconsistency.
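The all-NaN arithmetic above can be traced in a few lines (a standalone sketch of the masking path, not the actual to_mx() code):

```python
import torch

abs_vals = torch.full((4,), float("nan"))
neg_inf = torch.tensor(float("-inf"))

# every element is masked to -inf, so the block max is -inf ...
max_abs = torch.amax(torch.where(torch.isnan(abs_vals), neg_inf, abs_vals))
# ... and the descale is -inf rather than NaN, which is why the
# isinf branch (exponent = 254) fires instead of the NaN branch (255)
descale = max_abs / 448.0
print(max_abs.item(), descale.item())  # -inf -inf
```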

6. Test coverage gap: no dim1 special value tests

The new special-value tests only cover dim0. Given the dim1 kernel shares the same _triton_calculate_scale_rceil function and has the division/multiplication issue mentioned above, adding at least a basic special-values test for dim1 would be valuable.

7. Variable naming consistency

The rename from scale_fp to descale_fp in _triton_calculate_scale_rceil is good for clarity. Consider also renaming col_scale_r in the dim1 kernel to col_descale_r (and similarly col_scale_e8m0_r) to match the new semantics, since these now hold reciprocals.

Summary

The core logic for matching TE's float_to_e8m0 + exp2f_rcp pattern is well-structured and well-documented with CUDA reference code in comments. The main concern is the dim1 kernel division bug (issue #1), which appears to invert the scaling. The NaN handling discrepancy between Triton and torch reference (issue #2) is worth addressing or at least documenting.

