Skip to content

Commit c38b55c

Browse files
[mxfp8 training] triton_to_mxfp8_dim0 nan handling consistent with torch reference
stack-info: PR: #4201, branch: danielvegamyhre/stack/162
1 parent ce07646 commit c38b55c

File tree

3 files changed

+322
-63
lines changed

3 files changed

+322
-63
lines changed

test/prototype/mx_formats/test_kernels.py

Lines changed: 213 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,7 @@ def test_triton_mxfp8_dim1_randn(M, K, scaling_mode):
472472
)
473473
@pytest.mark.parametrize("M", (128, 256))
474474
@pytest.mark.parametrize("K", (128, 256))
475-
@pytest.mark.parametrize(
476-
"scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
477-
)
475+
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
478476
def test_triton_mxfp8_dim0_randn(M, K, scaling_mode):
479477
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
480478
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
@@ -625,3 +623,215 @@ def test_cuda_mx_dim0_not_supported():
625623
rowwise=True,
626624
colwise=False,
627625
)
626+
627+
628+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
629+
@pytest.mark.skipif(
630+
not is_sm_at_least_100() and not is_MI350(),
631+
reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
632+
)
633+
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
634+
def test_triton_mxfp8_dim0_special_values(scaling_mode: ScaleCalculationMode):
635+
# Create tensor with special values - make it compatible with block_size=32
636+
block_size = 32
637+
special_vals = torch.zeros(2, block_size, dtype=torch.bfloat16, device="cuda")
638+
639+
# Fill first few elements of each row with special values
640+
special_vals[0, :4] = torch.tensor(
641+
[float("inf"), -float("inf"), float("nan"), 0.0], dtype=torch.bfloat16
642+
)
643+
special_vals[1, :4] = torch.tensor(
644+
[
645+
torch.finfo(torch.float32).max,
646+
torch.finfo(torch.float32).min,
647+
torch.finfo(torch.float32).tiny,
648+
-torch.finfo(torch.float32).tiny,
649+
],
650+
dtype=torch.bfloat16,
651+
)
652+
653+
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
654+
special_vals, block_size=block_size, scaling_mode=scaling_mode
655+
)
656+
x_mx_t, x_s_t = triton_to_mxfp8_dim0(
657+
special_vals,
658+
inner_block_size=block_size,
659+
scaling_mode=scaling_mode.value.lower(),
660+
)
661+
x_mx_t = x_mx_t.to(torch.float32)
662+
x_s_t = x_s_t.to(torch.uint8)
663+
x_mx_ref = x_mx_ref.to(torch.float32)
664+
x_s_ref = x_s_ref.to(torch.uint8)
665+
666+
# Check for NaNs in output (allow NaNs if input had NaNs, but check scales)
667+
input_has_nan = special_vals.isnan().any()
668+
if not input_has_nan:
669+
assert not x_mx_t.isnan().any(), (
670+
"quantized tensor should not contain NaNs when input has no NaNs"
671+
)
672+
assert not x_s_t.isnan().any(), (
673+
"scales should not contain NaNs when input has no NaNs"
674+
)
675+
676+
# Use NaN-aware comparison to handle nan != nan case properly
677+
# Check NaN patterns match
678+
nan_ref = torch.isnan(x_mx_ref)
679+
nan_triton = torch.isnan(x_mx_t)
680+
assert torch.equal(nan_ref, nan_triton), (
681+
"NaN pattern mismatch between reference and triton"
682+
)
683+
684+
# Check finite values
685+
finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
686+
if finite_mask.any():
687+
assert torch.equal(x_mx_ref[finite_mask], x_mx_t[finite_mask]), (
688+
"Finite values mismatch"
689+
)
690+
691+
# Check infinity patterns
692+
inf_ref = torch.isinf(x_mx_ref)
693+
inf_triton = torch.isinf(x_mx_t)
694+
assert torch.equal(inf_ref, inf_triton), (
695+
"Infinity pattern mismatch between reference and triton"
696+
)
697+
if inf_ref.any():
698+
assert torch.equal(x_mx_ref[inf_ref], x_mx_t[inf_ref]), (
699+
"Infinity values mismatch"
700+
)
701+
702+
# Check scales using exact comparison
703+
x_s_ref_uint8 = x_s_ref.to(torch.uint8)
704+
x_s_t_uint8 = x_s_t.to(torch.uint8)
705+
assert torch.equal(x_s_t_uint8, x_s_ref_uint8), (
706+
"Scale values mismatch between reference and triton"
707+
)
708+
709+
710+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
711+
@pytest.mark.skipif(
712+
not is_sm_at_least_100() and not is_MI350(),
713+
reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
714+
)
715+
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
716+
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode):
717+
"""Test with values near overflow and underflow thresholds."""
718+
# Values near float8_e4m3fn limits
719+
f8_max = torch.finfo(torch.float8_e4m3fn).max # ~448
720+
f8_min = torch.finfo(torch.float8_e4m3fn).tiny # ~1.95e-06
721+
block_size = 32
722+
723+
overflow_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")
724+
725+
# Fill first few elements of each row with overflow/underflow values
726+
overflow_vals[0, :4] = torch.tensor(
727+
[f8_max * 0.9, f8_max * 1.1, f8_max * 2.0, f8_max * 10.0], dtype=torch.bfloat16
728+
)
729+
overflow_vals[1, :4] = torch.tensor(
730+
[-f8_max * 0.9, -f8_max * 1.1, -f8_max * 2.0, -f8_max * 10.0],
731+
dtype=torch.bfloat16,
732+
)
733+
overflow_vals[2, :4] = torch.tensor(
734+
[f8_min * 0.1, f8_min * 0.5, f8_min * 2.0, f8_min * 10.0], dtype=torch.bfloat16
735+
)
736+
overflow_vals[3, :4] = torch.tensor(
737+
[-f8_min * 0.1, -f8_min * 0.5, -f8_min * 2.0, -f8_min * 10.0],
738+
dtype=torch.bfloat16,
739+
)
740+
741+
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
742+
overflow_vals, block_size=block_size, scaling_mode=scaling_mode
743+
)
744+
x_mx_t, x_s_t = triton_to_mxfp8_dim0(
745+
overflow_vals,
746+
inner_block_size=block_size,
747+
scaling_mode=scaling_mode.value.lower(),
748+
)
749+
750+
assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
751+
assert not x_s_t.isnan().any(), "scales should not contain NaNs"
752+
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
753+
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
754+
755+
756+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
757+
@pytest.mark.skipif(
758+
not is_sm_at_least_100() and not is_MI350(),
759+
reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
760+
)
761+
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
762+
def test_triton_mxfp8_dim0_extreme_range(scaling_mode):
763+
"""Test with tensors containing both very large and very small values."""
764+
# Mix of extreme values in same tensor to test scaling edge cases
765+
block_size = 32
766+
extreme_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")
767+
768+
# Fill first few elements with extreme values
769+
extreme_vals[0, :4] = torch.tensor([1e30, 1e-30, 1e20, 1e-20], dtype=torch.bfloat16)
770+
extreme_vals[1, :4] = torch.tensor(
771+
[-1e30, -1e-30, -1e20, -1e-20], dtype=torch.bfloat16
772+
)
773+
extreme_vals[2, :4] = torch.tensor(
774+
[torch.finfo(torch.float32).max, torch.finfo(torch.float32).tiny, 1.0, -1.0],
775+
dtype=torch.bfloat16,
776+
)
777+
extreme_vals[3, :4] = torch.tensor([0.0, 1e-40, 1e40, -1e40], dtype=torch.bfloat16)
778+
779+
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
780+
extreme_vals, block_size=block_size, scaling_mode=scaling_mode
781+
)
782+
x_mx_t, x_s_t = triton_to_mxfp8_dim0(
783+
extreme_vals,
784+
inner_block_size=block_size,
785+
scaling_mode=scaling_mode.value.lower(),
786+
)
787+
788+
assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
789+
assert not x_s_t.isnan().any(), "scales should not contain NaNs"
790+
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
791+
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
792+
793+
794+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
795+
@pytest.mark.skipif(
796+
not is_sm_at_least_100() and not is_MI350(),
797+
reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
798+
)
799+
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
800+
def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode):
801+
"""Test with denormal/subnormal values that might cause precision issues."""
802+
# Create values in the denormal range
803+
bf16_tiny = torch.finfo(torch.bfloat16).tiny
804+
f32_tiny = torch.finfo(torch.float32).tiny
805+
block_size = 32
806+
807+
denormal_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")
808+
809+
# Fill first few elements with denormal values
810+
denormal_vals[0, :4] = torch.tensor(
811+
[bf16_tiny, bf16_tiny * 0.5, bf16_tiny * 0.1, bf16_tiny * 2.0],
812+
dtype=torch.bfloat16,
813+
)
814+
denormal_vals[1, :4] = torch.tensor(
815+
[f32_tiny, f32_tiny * 0.5, f32_tiny * 0.1, f32_tiny * 2.0], dtype=torch.bfloat16
816+
)
817+
denormal_vals[2, :4] = torch.tensor(
818+
[-bf16_tiny, -bf16_tiny * 0.5, -bf16_tiny * 0.1, -bf16_tiny * 2.0],
819+
dtype=torch.bfloat16,
820+
)
821+
denormal_vals[3, :4] = torch.tensor(
822+
[1e-40, 1e-38, 1e-36, 1e-34], dtype=torch.bfloat16
823+
) # Very small values
824+
825+
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
826+
denormal_vals, block_size=block_size, scaling_mode=scaling_mode
827+
)
828+
x_mx_t, x_s_t = triton_to_mxfp8_dim0(
829+
denormal_vals,
830+
inner_block_size=block_size,
831+
scaling_mode=scaling_mode.value.lower(),
832+
)
833+
834+
assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
835+
assert not x_s_t.isnan().any(), "scales should not contain NaNs"
836+
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
837+
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)

torchao/prototype/mx_formats/kernels.py

Lines changed: 63 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -472,27 +472,29 @@ def triton_mxfp8_dequant_dim0(
472472

473473
@triton.jit
474474
def _triton_calculate_scale_rceil(x, axis, USE_PTX: tl.constexpr):
475+
"""
476+
Calculates and returns reciprocal scale using RCEIL rounding mode
477+
"""
475478
# There is no good support for accessing globals from a jit'ed triton
476479
# function, so we redefine them here. Since this is prototype code which
477480
# we plan to remove after torch.compile catches up, this is fine.
478481
e8m0_exponent_bias = 127
479-
fp32_mbits = 23
480482

481483
# Find the maximum absolute value for each row
482484
max_abs = tl.max(x, axis=axis)
483485

484486
F8E4M3_MAX_RCP: tl.constexpr = 1.0 / 448.0
485487

488+
# Calculate scale input like CUDA: amax * max_norm_rcp
489+
scale_input = max_abs * F8E4M3_MAX_RCP
490+
491+
# Handle special values at scale calculation level (like CUDA float_to_e8m0)
492+
# Ref: https://github.com/NVIDIA/TransformerEngine/blob/b7598aa887eb7d619d64c90692980009669379bf/transformer_engine/common/util/ptx.cuh#L332-L341
493+
is_nan = scale_input != scale_input # NaN check
494+
is_inf = tl.abs(scale_input) == float("inf") # Inf check
495+
486496
if USE_PTX:
487-
# RCEIL scaling mode using PTX instruction supported on sm100.
488-
# The input should be: amax / 448.0
489-
# where 448.0 is the max representable value in FP8 E4M3 format.
490-
scale_input = max_abs.to(tl.float32) * F8E4M3_MAX_RCP
491-
492-
# The PTX instruction outputs a packed uint16 where:
493-
# - high byte = E8M0 of first input (0.0 in our case)
494-
# - low byte = E8M0 of second input (scale_input)
495-
# Casting uint16 to uint8 naturally truncates to the low byte.
497+
# Use PTX instruction for normal values
496498
scale_e8m0_biased = tl.inline_asm_elementwise(
497499
asm="cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, $1;",
498500
constraints="=h,r",
@@ -502,35 +504,53 @@ def _triton_calculate_scale_rceil(x, axis, USE_PTX: tl.constexpr):
502504
pack=1,
503505
).to(tl.uint8)
504506
else:
505-
# Original rceil implementation described in https://docs.nvidia.com/cuda/cublas/#d-block-quantization
506-
descale = max_abs * F8E4M3_MAX_RCP
507-
508-
# Clamp to exponents that can be represented in e8m0
507+
# Fallback implementation
509508
scale_e8m0_unbiased = tl.clamp(
510-
tl.ceil(tl.log2(descale)),
509+
tl.ceil(tl.log2(scale_input)),
511510
min=-1 * e8m0_exponent_bias,
512511
max=e8m0_exponent_bias,
513512
)
513+
scale_e8m0_biased = (scale_e8m0_unbiased + 127).to(tl.uint8)
514514

515-
# Create the biased e8m0 representation and cast it to 8 bits
516-
# Set NaN values to 0xFF
517-
is_nan = descale != descale
518-
scale_e8m0_biased = tl.where(is_nan, 0xFF, scale_e8m0_unbiased + 127)
519-
scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)
515+
# Apply special value overrides (like CUDA)
516+
# Ref: https://github.com/NVIDIA/TransformerEngine/blob/b7598aa887eb7d619d64c90692980009669379bf/transformer_engine/common/util/ptx.cuh#L332-L341
517+
scale_e8m0_biased = tl.where(is_nan, 255, scale_e8m0_biased) # 0xFF for NaN
518+
scale_e8m0_biased = tl.where(is_inf, 254, scale_e8m0_biased) # 0xFE for inf
520519

521-
# TODO(future PR): add NaN handling here,
522-
# https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
523-
# get proper NaN propagation working
524-
# Calculate the scale in floating point.
525-
scale_fp = (scale_e8m0_biased.to(tl.int32) << fp32_mbits).to(
526-
tl.float32, bitcast=True
527-
)
520+
# Efficient reciprocal calculation (like CUDA exp2f_rcp)
521+
FP32_MANTISSA_BITS: tl.constexpr = 23
528522

529-
fp32_exp_bias = 127.0
530-
fp32_min_normal = tl.exp2(-fp32_exp_bias + 1)
531-
scale_fp = tl.clamp(scale_fp, min=fp32_min_normal, max=float("inf"))
523+
# Equivalent CUDA per-thread code is more readable, copying here as documentation:
524+
#
525+
# __device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
526+
# // Handle the special case of NaN.
527+
# if (biased_exp == 255) return __int_as_float(0x7fffffff);
528+
#
529+
# // Handle the special case where the unbiased exponent is 127, so the reciprocal is 2^-127 which needs the first bit of
530+
# // the mantissa to be 1, which can't be obtained by shifting `FP32_MANTISSA_BITS` bits to the left.
531+
# if (biased_exp == 254) return __int_as_float(0x00400000);
532+
#
533+
# // Fast calculation when the unbiased exp is in [-126, 126], and only the exponent part is used to express the reciprocal.
534+
# return __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS);
535+
# }
536+
descale_fp = tl.where(
537+
scale_e8m0_biased == 255, # NaN case -> return NaN
538+
float("nan"),
539+
tl.where(
540+
scale_e8m0_biased == 254, # Inf case -> return 2^-127
541+
2**-127,
542+
tl.where(
543+
scale_e8m0_biased == 0, # Zero case -> return 1.0 (no scaling)
544+
1.0,
545+
# Normal case: fast bit manipulation (254 - biased_exp) << 23
546+
((254 - scale_e8m0_biased).to(tl.int32) << FP32_MANTISSA_BITS).to(
547+
tl.float32, bitcast=True
548+
),
549+
),
550+
),
551+
)
532552

533-
return scale_fp, scale_e8m0_biased
553+
return descale_fp, scale_e8m0_biased
534554

535555
@triton.jit
536556
def _triton_calculate_scale_floor(
@@ -793,25 +813,23 @@ def to_mxfp8_dim0_kernel(
793813
# Find the maximum absolute value for each row (across columns)
794814
# shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
795815
if SCALING_MODE == "rceil":
796-
scale_fp32_r, scale_e8m0_r = _triton_calculate_scale_rceil(
816+
descale_fp32_r, scale_e8m0_r = _triton_calculate_scale_rceil(
797817
x_block_abs_r,
798818
axis=1,
799819
USE_PTX=not IS_ROCM,
800820
)
801821
else:
802822
tl.static_assert(SCALING_MODE == "floor")
803-
scale_fp32_r, scale_e8m0_r = _triton_calculate_scale_floor(
823+
descale_fp32_r, scale_e8m0_r = _triton_calculate_scale_floor(
804824
x_block_abs_r,
805825
axis=1,
806826
)
807827

808-
# Divide each row by scale
809-
# Broadcasting scale to match x_block's shape
810-
# x_block_r shape:
811-
# (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE)
812-
# scale[:, None] shape:
813-
# (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
814-
scaled_data_r = x_block_r / scale_fp32_r[:, None]
828+
# Broadcast descale to match x_block's shape
829+
descale_broadcast = descale_fp32_r[:, None]
830+
831+
# Scale the data
832+
scaled_data_r = x_block_r * descale_broadcast
815833

816834
# Reshape back to original tile size
817835
e4m3_data_2d = tl.reshape(scaled_data_r, ROW_TILE_SIZE, COL_TILE_SIZE).to(
@@ -821,8 +839,10 @@ def to_mxfp8_dim0_kernel(
821839
# Store the row-normalized result in row-major format
822840
tl.store(output_ptr + row_major_offsets, e4m3_data_2d, mask=mask)
823841

824-
# Calculate scale offsets to write to
842+
# Store e8m0 scales
825843
scales_per_row = n_cols // SCALE_BLOCK_SIZE
844+
845+
# Calculate scale storage offsets and mask
826846
scale_row_indices = (
827847
pid_row * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
828848
)
@@ -831,9 +851,9 @@ def to_mxfp8_dim0_kernel(
831851
+ tl.arange(0, SCALE_BLOCKS_PER_COL_TILE)[None, :]
832852
)
833853
scale_offsets = scale_row_indices * scales_per_row + scale_col_indices
834-
835-
# Store e8m0 scales
836854
scale_mask = (scale_row_indices < n_rows) & (scale_col_indices < scales_per_row)
855+
856+
# Reshape scale values to 2D and store
837857
scale_e8m0_2d = scale_e8m0_r.reshape(ROW_TILE_SIZE, SCALE_BLOCKS_PER_COL_TILE)
838858
tl.store(scale_ptr + scale_offsets, scale_e8m0_2d, mask=scale_mask)
839859

0 commit comments

Comments
 (0)