[mxfp8 training] update triton_to_mxfp8_dim0 nan handling#4201
[mxfp8 training] update triton_to_mxfp8_dim0 nan handling#4201danielvegamyhre wants to merge 1 commit intomainfrom
Conversation
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
f0b651a to
74a1db5
Compare
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4201
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 8 Unrelated FailuresAs of commit 82e7fc9 with merge base 3ad1067 ( NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
74a1db5 to
e813474
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
e813474 to
346fcca
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
346fcca to
8df8fea
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
8df8fea to
b7c40f1
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
b7c40f1 to
46eec1e
Compare
|
edit: nevermind, managed to recover performance |
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
46eec1e to
43a59a1
Compare
cf861ec to
c38b55c
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
c38b55c to
bf5918a
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
bf5918a to
c231e2b
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
c231e2b to
39aae18
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
39aae18 to
8858c92
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
8858c92 to
c4bf474
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
c4bf474 to
5c6b807
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
5c6b807 to
7dfe87c
Compare
…rch reference stack-info: PR: #4201, branch: danielvegamyhre/stack/162
|
@claude review this PR |
|
Claude finished @vkuzo's task in 2m 16s —— View job Code Review: [mxfp8 training] update triton_to_mxfp8_dim0 nan handling
Overall this is a well-motivated change that aligns both the Triton kernel and the torch reference implementation with TransformerEngine's CUDA behavior for handling special values (NaN, inf) during MXFP8 quantization. The approach of matching TE's Issues1. Dim1 kernel still uses division instead of multiplication ( The dim0 kernel ( col_normalized_t_r = x_block_t_r / col_scale_r[:, None]Since 2. Triton The torch reference ( abs_vals = torch.abs(data_hp)
max_abs = torch.amax(
torch.where(torch.isnan(abs_vals), float("-inf"), abs_vals), -1
).unsqueeze(-1)But the Triton kernel at line 483 uses bare In the torch reference, NaN elements are ignored in the max, so the scale is computed from the non-NaN elements, and then individual NaN elements are preserved post-scaling (line 173). These two paths produce different results for blocks containing a mix of NaN and non-NaN values. The tests pass because This might be intentional if the expectation is that NaN never appears alongside valid data in practice (i.e., either the whole block is NaN or none of it is), but it's worth documenting this discrepancy or adding a test case that exercises a mixed NaN/non-NaN block where the behavior difference would be visible. 3. After line 173 re-introduces NaN for NaN input elements, line 175 does Minor Nits4. Typo ( # Calcculate the reciprocal fp32 scaleShould be "Calculate". 5. All-NaN block edge case ( If every element in a block is NaN, 6. Test coverage gap: no dim1 special value tests The new special-value tests only cover dim0. Given the dim1 kernel shares the same 7. Variable naming consistency The rename from SummaryThe core logic for matching TE's | |
Summary
triton_to_mxfp8_dim0)Changes
__hmaxwhich treats nans as missing data and returns the numeric value.Benchmarks