Skip to content

Commit f0b651a

Browse files
[mxfp8 training] triton_to_mxfp8_dim0 nan handling consistent with torch reference
1 parent ce07646 commit f0b651a

File tree

2 files changed

+390
-10
lines changed

2 files changed

+390
-10
lines changed

test/prototype/mx_formats/test_kernels.py

Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,374 @@ def test_cuda_mx_dim0_not_supported():
625625
rowwise=True,
626626
colwise=False,
627627
)
628+
629+
630+
# Additional comprehensive tests for triton_to_mxfp8_dim0 to debug NaN issues
631+
632+
633+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_special_values(scaling_mode):
    """Test with special IEEE 754 values that commonly cause NaN issues."""
    # Tensor shape is chosen so each row is one full block (block_size=32);
    # only the first 4 elements of each row hold special values, rest are zero.
    block_size = 32
    x = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    special_rows = [
        [float("inf"), -float("inf"), float("nan"), 0.0],
        [1.0, -1.0, 2.0, -2.0],
        [1e10, -1e10, 1e-10, -1e-10],
        [
            torch.finfo(torch.float32).max,
            torch.finfo(torch.float32).min,
            torch.finfo(torch.float32).tiny,
            -torch.finfo(torch.float32).tiny,
        ],
    ]
    for row_idx, vals in enumerate(special_rows):
        x[row_idx, :4] = torch.tensor(vals, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )
    # Normalize dtypes before comparing.
    x_mx_t = x_mx_t.to(torch.float32)
    x_s_t = x_s_t.to(torch.uint8)
    x_mx_ref = x_mx_ref.to(torch.float32)
    x_s_ref = x_s_ref.to(torch.uint8)

    # NaN-free input must produce NaN-free output; NaN input may propagate NaNs.
    if not torch.isnan(x).any():
        assert not x_mx_t.isnan().any(), (
            "quantized tensor should not contain NaNs when input has no NaNs"
        )
        assert not x_s_t.isnan().any(), (
            "scales should not contain NaNs when input has no NaNs"
        )

    # Compare only positions where both reference and triton outputs are finite.
    finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
    if finite_mask.any():
        torch.testing.assert_close(
            x_mx_t[finite_mask], x_mx_ref[finite_mask], rtol=0, atol=0
        )

    scale_finite_mask = torch.isfinite(x_s_ref) & torch.isfinite(x_s_t)
    if scale_finite_mask.any():
        torch.testing.assert_close(
            x_s_t[scale_finite_mask], x_s_ref[scale_finite_mask], rtol=0, atol=0
        )
700+
701+
702+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode):
    """Test with values near overflow and underflow thresholds."""
    # Values around the float8_e4m3fn representable range.
    f8_max = torch.finfo(torch.float8_e4m3fn).max  # largest finite e4m3fn value
    f8_min = torch.finfo(torch.float8_e4m3fn).tiny  # smallest positive normal value
    block_size = 32

    x = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Each row: one base value scaled around the overflow/underflow boundary.
    cases = (
        (f8_max, (0.9, 1.1, 2.0, 10.0)),
        (-f8_max, (0.9, 1.1, 2.0, 10.0)),
        (f8_min, (0.1, 0.5, 2.0, 10.0)),
        (-f8_min, (0.1, 0.5, 2.0, 10.0)),
    )
    for row_idx, (base, factors) in enumerate(cases):
        x[row_idx, :4] = torch.tensor(
            [base * f for f in factors], dtype=torch.bfloat16
        )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
748+
749+
750+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_extreme_range(scaling_mode):
    """Test with tensors containing both very large and very small values."""
    # Mix huge and tiny magnitudes in the same tensor to stress scale selection.
    block_size = 32
    x = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    extreme_rows = [
        [1e30, 1e-30, 1e20, 1e-20],
        [-1e30, -1e-30, -1e20, -1e-20],
        [torch.finfo(torch.float32).max, torch.finfo(torch.float32).tiny, 1.0, -1.0],
        [0.0, 1e-40, 1e40, -1e40],
    ]
    for row_idx, vals in enumerate(extreme_rows):
        x[row_idx, :4] = torch.tensor(vals, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
788+
789+
790+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("block_size", (1, 2, 4, 8, 16, 32, 64))
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_edge_block_sizes(block_size, scaling_mode):
    """Test with various block sizes that might expose edge cases."""
    # Dimensions are multiples of block_size so no padding path is exercised here.
    n_rows = max(64, block_size * 2)
    n_cols = max(64, block_size * 4)
    x = torch.randn(n_rows, n_cols, dtype=torch.bfloat16, device="cuda")

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), (
        f"quantized tensor should not contain NaNs with block_size={block_size}"
    )
    assert not x_s_t.isnan().any(), (
        f"scales should not contain NaNs with block_size={block_size}"
    )
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
824+
825+
826+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "shape", [(1, 32), (32, 1), (1, 1), (7, 13), (31, 17), (33, 31)]
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_odd_shapes(shape, scaling_mode):
    """Test with odd tensor shapes that might not align well with block sizes."""
    M, K = shape
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    # Shrink the block size when the inner dim is smaller than 32.
    block_size = min(32, K)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), (
        f"quantized tensor should not contain NaNs with shape={shape}"
    )
    assert not x_s_t.isnan().any(), f"scales should not contain NaNs with shape={shape}"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
858+
859+
860+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode):
    """Test with denormal/subnormal values that might cause precision issues."""
    # Smallest positive normal values for the two source dtypes.
    bf16_tiny = torch.finfo(torch.bfloat16).tiny
    f32_tiny = torch.finfo(torch.float32).tiny
    block_size = 32

    x = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    denormal_rows = [
        [bf16_tiny, bf16_tiny * 0.5, bf16_tiny * 0.1, bf16_tiny * 2.0],
        [f32_tiny, f32_tiny * 0.5, f32_tiny * 0.1, f32_tiny * 2.0],
        [-bf16_tiny, -bf16_tiny * 0.5, -bf16_tiny * 0.1, -bf16_tiny * 2.0],
        [1e-40, 1e-38, 1e-36, 1e-34],  # very small values
    ]
    for row_idx, vals in enumerate(denormal_rows):
        x[row_idx, :4] = torch.tensor(vals, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
906+
907+
908+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_constant_values(scaling_mode):
    """Test with tensors of constant values to check scale calculation edge cases."""
    f8_info = torch.finfo(torch.float8_e4m3fn)
    test_values = (
        1.0,
        -1.0,
        0.5,
        -0.5,
        2.0,
        -2.0,
        100.0,
        -100.0,
        0.01,
        -0.01,
        f8_info.max,
        -f8_info.max,
        f8_info.tiny,
        -f8_info.tiny,
    )

    for val in test_values:
        # A full tensor of one constant exercises a single-scale-per-block path.
        x = torch.full((64, 128), val, dtype=torch.bfloat16, device="cuda")

        x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
            x, block_size=32, scaling_mode=scaling_mode
        )
        x_mx_t, x_s_t = triton_to_mxfp8_dim0(
            x,
            inner_block_size=32,
            scaling_mode=scaling_mode.value.lower(),
        )

        assert not x_mx_t.isnan().any(), (
            f"quantized tensor should not contain NaNs for constant value {val}"
        )
        assert not x_s_t.isnan().any(), (
            f"scales should not contain NaNs for constant value {val}"
        )
        torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
955+
956+
957+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_alternating_signs(scaling_mode):
    """Test with alternating positive/negative patterns that might cause scaling issues."""
    M, K = 64, 128

    # Pattern 1: every other row is negated.
    alternating = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").abs()
    alternating[::2] *= -1

    # Pattern 2: checkerboard of signs.
    signs = torch.ones(M, K, dtype=torch.bfloat16, device="cuda")
    signs[::2, ::2] *= -1
    signs[1::2, 1::2] *= -1
    checkerboard = (
        torch.randn(M, K, dtype=torch.bfloat16, device="cuda").abs() * signs
    )

    for x_test, name in ((alternating, "alternating_rows"), (checkerboard, "checkerboard")):
        x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
            x_test, block_size=32, scaling_mode=scaling_mode
        )
        x_mx_t, x_s_t = triton_to_mxfp8_dim0(
            x_test,
            inner_block_size=32,
            scaling_mode=scaling_mode.value.lower(),
        )

        assert not x_mx_t.isnan().any(), (
            f"quantized tensor should not contain NaNs for {name} pattern"
        )
        assert not x_s_t.isnan().any(), (
            f"scales should not contain NaNs for {name} pattern"
        )
        torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)

0 commit comments

Comments
 (0)