Skip to content

Commit 8df8fea

Browse files
[mxfp8 training] triton_to_mxfp8_dim0 nan handling consistent with torch reference
stack-info: PR: #4201, branch: danielvegamyhre/stack/162
1 parent ce07646 commit 8df8fea

File tree

2 files changed

+469
-43
lines changed

2 files changed

+469
-43
lines changed

test/prototype/mx_formats/test_kernels.py

Lines changed: 390 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,393 @@ def test_cuda_mx_dim0_not_supported():
625625
rowwise=True,
626626
colwise=False,
627627
)
628+
629+
630+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
def test_triton_mxfp8_dim0_special_values():
    """Test with special IEEE 754 values that commonly cause NaN issues.

    Quantizes a tensor containing inf/-inf/NaN/zero, unit values, extreme
    magnitudes, and float32 limits with both the triton kernel and the
    PyTorch reference, then checks that the NaN, finite, and infinity
    patterns of the fp8 data and the e8m0 scales match exactly.
    """
    # Test only RCEIL mode to match canonical PyTorch behavior
    scaling_mode = ScaleCalculationMode.RCEIL

    # Create tensor with special values - make it compatible with block_size=32
    block_size = 32
    special_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Fill first few elements of each row with special values
    special_vals[0, :4] = torch.tensor(
        [float("inf"), -float("inf"), float("nan"), 0.0], dtype=torch.bfloat16
    )
    special_vals[1, :4] = torch.tensor([1.0, -1.0, 2.0, -2.0], dtype=torch.bfloat16)
    special_vals[2, :4] = torch.tensor(
        [1e10, -1e10, 1e-10, -1e-10], dtype=torch.bfloat16
    )
    special_vals[3, :4] = torch.tensor(
        [
            torch.finfo(torch.float32).max,
            torch.finfo(torch.float32).min,
            torch.finfo(torch.float32).tiny,
            -torch.finfo(torch.float32).tiny,
        ],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        special_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        special_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )
    # Upcast the fp8 data to float32 so isnan/isinf/isfinite checks work;
    # scales are compared bitwise as uint8.
    x_mx_t = x_mx_t.to(torch.float32)
    x_s_t = x_s_t.to(torch.uint8)
    x_mx_ref = x_mx_ref.to(torch.float32)
    x_s_ref = x_s_ref.to(torch.uint8)

    # Check for NaNs in output (allow NaNs if input had NaNs, but check scales)
    # NOTE(review): row 0 intentionally contains NaN, so this guard is
    # currently always False; it is kept so the assertions activate if the
    # test inputs ever change.
    input_has_nan = special_vals.isnan().any()
    if not input_has_nan:
        assert not x_mx_t.isnan().any(), (
            "quantized tensor should not contain NaNs when input has no NaNs"
        )
        assert not x_s_t.isnan().any(), (
            "scales should not contain NaNs when input has no NaNs"
        )

    # Use NaN-aware comparison to handle nan != nan case properly

    # Check NaN patterns match
    nan_ref = torch.isnan(x_mx_ref)
    nan_triton = torch.isnan(x_mx_t)
    assert torch.equal(nan_ref, nan_triton), (
        "NaN pattern mismatch between reference and triton"
    )

    # Check finite values
    finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
    if finite_mask.any():
        assert torch.equal(x_mx_ref[finite_mask], x_mx_t[finite_mask]), (
            "Finite values mismatch"
        )

    # Check infinity patterns
    inf_ref = torch.isinf(x_mx_ref)
    inf_triton = torch.isinf(x_mx_t)
    assert torch.equal(inf_ref, inf_triton), (
        "Infinity pattern mismatch between reference and triton"
    )

    if inf_ref.any():
        assert torch.equal(x_mx_ref[inf_ref], x_mx_t[inf_ref]), (
            "Infinity values mismatch"
        )

    # Check scales using exact comparison. Both tensors were already cast to
    # uint8 above; the original's extra .to(torch.uint8) round-trips were
    # redundant no-ops and have been removed.
    assert torch.equal(x_s_t, x_s_ref), (
        "Scale values mismatch between reference and triton"
    )
719+
720+
721+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode):
    """Test with values near overflow and underflow thresholds."""
    # float8_e4m3fn representable limits drive the test values.
    f8_max = torch.finfo(torch.float8_e4m3fn).max  # ~448
    f8_min = torch.finfo(torch.float8_e4m3fn).tiny  # ~1.95e-06
    block_size = 32

    overflow_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Each row gets four values straddling the fp8 limits; the rest of the
    # row stays zero so the tensor remains block_size-aligned.
    rows = [
        [f8_max * 0.9, f8_max * 1.1, f8_max * 2.0, f8_max * 10.0],
        [-f8_max * 0.9, -f8_max * 1.1, -f8_max * 2.0, -f8_max * 10.0],
        [f8_min * 0.1, f8_min * 0.5, f8_min * 2.0, f8_min * 10.0],
        [-f8_min * 0.1, -f8_min * 0.5, -f8_min * 2.0, -f8_min * 10.0],
    ]
    for row_idx, row_vals in enumerate(rows):
        overflow_vals[row_idx, :4] = torch.tensor(row_vals, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        overflow_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        overflow_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
767+
768+
769+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_extreme_range(scaling_mode):
    """Test with tensors containing both very large and very small values."""
    # Mixing extreme magnitudes in one tensor stresses per-block scale choice.
    block_size = 32
    extreme_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Only the first four columns of each row are populated; the remainder
    # of each row stays zero so the block layout is unchanged.
    rows = [
        [1e30, 1e-30, 1e20, 1e-20],
        [-1e30, -1e-30, -1e20, -1e-20],
        [torch.finfo(torch.float32).max, torch.finfo(torch.float32).tiny, 1.0, -1.0],
        [0.0, 1e-40, 1e40, -1e40],
    ]
    for row_idx, row_vals in enumerate(rows):
        extreme_vals[row_idx, :4] = torch.tensor(row_vals, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        extreme_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        extreme_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
807+
808+
809+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("block_size", (1, 2, 4, 8, 16, 32, 64))
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_edge_block_sizes(block_size, scaling_mode):
    """Test with various block sizes that might expose edge cases."""
    # Pick dims divisible by block_size so no padding path is exercised here.
    num_rows = max(64, block_size * 2)
    num_cols = max(64, block_size * 4)

    x = torch.randn(num_rows, num_cols, dtype=torch.bfloat16, device="cuda")

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), (
        f"quantized tensor should not contain NaNs with block_size={block_size}"
    )
    assert not x_s_t.isnan().any(), (
        f"scales should not contain NaNs with block_size={block_size}"
    )
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
843+
844+
845+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "shape", [(1, 32), (32, 1), (1, 1), (7, 13), (31, 17), (33, 31)]
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_odd_shapes(shape, scaling_mode):
    """Test with odd tensor shapes that might not align well with block sizes."""
    num_rows, num_cols = shape
    x = torch.randn(num_rows, num_cols, dtype=torch.bfloat16, device="cuda")
    # Shrink the block for tensors narrower than 32 columns.
    block_size = min(32, num_cols)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), (
        f"quantized tensor should not contain NaNs with shape={shape}"
    )
    assert not x_s_t.isnan().any(), f"scales should not contain NaNs with shape={shape}"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
877+
878+
879+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode):
    """Test with denormal/subnormal values that might cause precision issues."""
    # Smallest normal magnitudes of bf16 and f32 anchor the denormal range.
    bf16_tiny = torch.finfo(torch.bfloat16).tiny
    f32_tiny = torch.finfo(torch.float32).tiny
    block_size = 32

    denormal_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Each row's first four columns hold values around/below the normal
    # range; the last row uses very small literals that flush in bf16.
    rows = [
        [bf16_tiny, bf16_tiny * 0.5, bf16_tiny * 0.1, bf16_tiny * 2.0],
        [f32_tiny, f32_tiny * 0.5, f32_tiny * 0.1, f32_tiny * 2.0],
        [-bf16_tiny, -bf16_tiny * 0.5, -bf16_tiny * 0.1, -bf16_tiny * 2.0],
        [1e-40, 1e-38, 1e-36, 1e-34],
    ]
    for row_idx, row_vals in enumerate(rows):
        denormal_vals[row_idx, :4] = torch.tensor(row_vals, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        denormal_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        denormal_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
925+
926+
927+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_constant_values(scaling_mode):
    """Test with tensors of constant values to check scale calculation edge cases."""
    # fp8 limits are hoisted so the value list reads cleanly.
    f8_max = torch.finfo(torch.float8_e4m3fn).max
    f8_tiny = torch.finfo(torch.float8_e4m3fn).tiny
    test_values = [
        1.0,
        -1.0,
        0.5,
        -0.5,
        2.0,
        -2.0,
        100.0,
        -100.0,
        0.01,
        -0.01,
        f8_max,
        -f8_max,
        f8_tiny,
        -f8_tiny,
    ]

    for val in test_values:
        x = torch.full((64, 128), val, dtype=torch.bfloat16, device="cuda")

        x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
            x, block_size=32, scaling_mode=scaling_mode
        )
        x_mx_t, x_s_t = triton_to_mxfp8_dim0(
            x,
            inner_block_size=32,
            scaling_mode=scaling_mode.value.lower(),
        )

        assert not x_mx_t.isnan().any(), (
            f"quantized tensor should not contain NaNs for constant value {val}"
        )
        assert not x_s_t.isnan().any(), (
            f"scales should not contain NaNs for constant value {val}"
        )
        torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
974+
975+
976+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_alternating_signs(scaling_mode):
    """Test with alternating positive/negative patterns that might cause scaling issues."""
    num_rows, num_cols = 64, 128

    # Pattern 1: magnitude-only noise with every other row negated.
    alternating = torch.randn(
        num_rows, num_cols, dtype=torch.bfloat16, device="cuda"
    ).abs()
    alternating[::2] *= -1

    # Pattern 2: a checkerboard of signs applied to magnitude-only noise.
    signs = torch.ones(num_rows, num_cols, dtype=torch.bfloat16, device="cuda")
    signs[::2, ::2] *= -1
    signs[1::2, 1::2] *= -1
    checkerboard = (
        torch.randn(num_rows, num_cols, dtype=torch.bfloat16, device="cuda").abs()
        * signs
    )

    for x_test, name in [(alternating, "alternating_rows"), (checkerboard, "checkerboard")]:
        x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
            x_test, block_size=32, scaling_mode=scaling_mode
        )
        x_mx_t, x_s_t = triton_to_mxfp8_dim0(
            x_test,
            inner_block_size=32,
            scaling_mode=scaling_mode.value.lower(),
        )

        assert not x_mx_t.isnan().any(), (
            f"quantized tensor should not contain NaNs for {name} pattern"
        )
        assert not x_s_t.isnan().any(), (
            f"scales should not contain NaNs for {name} pattern"
        )
        torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)

0 commit comments

Comments
 (0)