@@ -625,3 +625,393 @@ def test_cuda_mx_dim0_not_supported():
625625 rowwise = True ,
626626 colwise = False ,
627627 )
628+
629+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
def test_triton_mxfp8_dim0_special_values():
    """Test with special IEEE 754 values that commonly cause NaN issues.

    Compares the triton mxfp8 dim0 kernel against the reference
    implementation on inf/-inf/nan/zero and extreme-magnitude inputs,
    checking NaN patterns, infinity patterns, finite values, and scales.
    """
    # Test only RCEIL mode to match canonical PyTorch behavior
    scaling_mode = ScaleCalculationMode.RCEIL

    # Create tensor with special values - make it compatible with block_size=32
    block_size = 32
    special_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Fill first few elements of each row with special values
    special_vals[0, :4] = torch.tensor(
        [float("inf"), -float("inf"), float("nan"), 0.0], dtype=torch.bfloat16
    )
    special_vals[1, :4] = torch.tensor([1.0, -1.0, 2.0, -2.0], dtype=torch.bfloat16)
    special_vals[2, :4] = torch.tensor(
        [1e10, -1e10, 1e-10, -1e-10], dtype=torch.bfloat16
    )
    # NOTE: float32 max/min overflow to +/-inf when cast to bfloat16, which is
    # intentional here - it adds more non-finite inputs to the test.
    special_vals[3, :4] = torch.tensor(
        [
            torch.finfo(torch.float32).max,
            torch.finfo(torch.float32).min,
            torch.finfo(torch.float32).tiny,
            -torch.finfo(torch.float32).tiny,
        ],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        special_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        special_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )
    # Upcast quantized data for comparison; reinterpret scales as raw uint8
    # bytes so they can be compared exactly below.
    x_mx_t = x_mx_t.to(torch.float32)
    x_s_t = x_s_t.to(torch.uint8)
    x_mx_ref = x_mx_ref.to(torch.float32)
    x_s_ref = x_s_ref.to(torch.uint8)

    # Check for NaNs in output (allow NaNs if input had NaNs, but check scales).
    # NOTE(review): x_s_t is already uint8 at this point, so its isnan() check
    # can never fire; the exact scale comparison at the end is the real guard.
    input_has_nan = special_vals.isnan().any()
    if not input_has_nan:
        assert not x_mx_t.isnan().any(), (
            "quantized tensor should not contain NaNs when input has no NaNs"
        )
        assert not x_s_t.isnan().any(), (
            "scales should not contain NaNs when input has no NaNs"
        )

    # Use NaN-aware comparison to handle nan != nan case properly

    # Check NaN patterns match
    nan_ref = torch.isnan(x_mx_ref)
    nan_triton = torch.isnan(x_mx_t)
    assert torch.equal(nan_ref, nan_triton), (
        "NaN pattern mismatch between reference and triton"
    )

    # Check finite values
    finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
    if finite_mask.any():
        assert torch.equal(x_mx_ref[finite_mask], x_mx_t[finite_mask]), (
            "Finite values mismatch"
        )

    # Check infinity patterns
    inf_ref = torch.isinf(x_mx_ref)
    inf_triton = torch.isinf(x_mx_t)
    assert torch.equal(inf_ref, inf_triton), (
        "Infinity pattern mismatch between reference and triton"
    )

    if inf_ref.any():
        assert torch.equal(x_mx_ref[inf_ref], x_mx_t[inf_ref]), (
            "Infinity values mismatch"
        )

    # Check scales using exact comparison. (The tensors are already uint8;
    # the previous redundant second .to(torch.uint8) casts were removed.)
    assert torch.equal(x_s_t, x_s_ref), (
        "Scale values mismatch between reference and triton"
    )
719+
720+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode):
    """Test with values near overflow and underflow thresholds."""
    # float8_e4m3fn limits: max is 448; tiny (smallest positive normal)
    # is 2**-6 = 0.015625.
    f8_max = torch.finfo(torch.float8_e4m3fn).max
    f8_min = torch.finfo(torch.float8_e4m3fn).tiny
    block_size = 32

    overflow_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Rows 0/1: +/- multiples straddling the fp8 max (overflow side);
    # rows 2/3: +/- multiples straddling the fp8 smallest normal (underflow side).
    row_patterns = (
        [f8_max * m for m in (0.9, 1.1, 2.0, 10.0)],
        [-f8_max * m for m in (0.9, 1.1, 2.0, 10.0)],
        [f8_min * m for m in (0.1, 0.5, 2.0, 10.0)],
        [-f8_min * m for m in (0.1, 0.5, 2.0, 10.0)],
    )
    for row, vals in enumerate(row_patterns):
        overflow_vals[row, :4] = torch.tensor(vals, dtype=torch.bfloat16)

    ref_data, ref_scales = triton_to_mxfp8_dim0_reference(
        overflow_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    tri_data, tri_scales = triton_to_mxfp8_dim0(
        overflow_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not tri_data.isnan().any(), "quantized tensor should not contain NaNs"
    assert not tri_scales.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(tri_data, ref_data, rtol=0, atol=0)
    torch.testing.assert_close(tri_scales, ref_scales, rtol=0, atol=0)
767+
768+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_extreme_range(scaling_mode):
    """Test with tensors containing both very large and very small values."""
    # Mixing huge and tiny magnitudes in the same tensor stresses the
    # per-block scale computation.
    block_size = 32
    extreme_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    row_values = (
        [1e30, 1e-30, 1e20, 1e-20],
        [-1e30, -1e-30, -1e20, -1e-20],
        [torch.finfo(torch.float32).max, torch.finfo(torch.float32).tiny, 1.0, -1.0],
        [0.0, 1e-40, 1e40, -1e40],
    )
    for row, vals in enumerate(row_values):
        extreme_vals[row, :4] = torch.tensor(vals, dtype=torch.bfloat16)

    ref_data, ref_scales = triton_to_mxfp8_dim0_reference(
        extreme_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    tri_data, tri_scales = triton_to_mxfp8_dim0(
        extreme_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not tri_data.isnan().any(), "quantized tensor should not contain NaNs"
    assert not tri_scales.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(tri_data, ref_data, rtol=0, atol=0)
    torch.testing.assert_close(tri_scales, ref_scales, rtol=0, atol=0)
807+
808+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("block_size", (1, 2, 4, 8, 16, 32, 64))
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_edge_block_sizes(block_size, scaling_mode):
    """Test with various block sizes that might expose edge cases."""
    # Pick dims divisible by block_size so padding edge cases are excluded here.
    rows = max(64, block_size * 2)
    cols = max(64, block_size * 4)
    sample = torch.randn(rows, cols, dtype=torch.bfloat16, device="cuda")

    ref_data, ref_scales = triton_to_mxfp8_dim0_reference(
        sample, block_size=block_size, scaling_mode=scaling_mode
    )
    tri_data, tri_scales = triton_to_mxfp8_dim0(
        sample,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not tri_data.isnan().any(), (
        f"quantized tensor should not contain NaNs with block_size={block_size}"
    )
    assert not tri_scales.isnan().any(), (
        f"scales should not contain NaNs with block_size={block_size}"
    )
    torch.testing.assert_close(tri_data, ref_data, rtol=0, atol=0)
    torch.testing.assert_close(tri_scales, ref_scales, rtol=0, atol=0)
843+
844+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "shape", [(1, 32), (32, 1), (1, 1), (7, 13), (31, 17), (33, 31)]
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_odd_shapes(shape, scaling_mode):
    """Test with odd tensor shapes that might not align well with block sizes."""
    rows, cols = shape
    data = torch.randn(rows, cols, dtype=torch.bfloat16, device="cuda")
    # Shrink the block for tensors narrower than the default 32-wide block.
    block_size = min(32, cols)

    ref_data, ref_scales = triton_to_mxfp8_dim0_reference(
        data, block_size=block_size, scaling_mode=scaling_mode
    )
    tri_data, tri_scales = triton_to_mxfp8_dim0(
        data,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not tri_data.isnan().any(), (
        f"quantized tensor should not contain NaNs with shape={shape}"
    )
    assert not tri_scales.isnan().any(), f"scales should not contain NaNs with shape={shape}"
    torch.testing.assert_close(tri_data, ref_data, rtol=0, atol=0)
    torch.testing.assert_close(tri_scales, ref_scales, rtol=0, atol=0)
877+
878+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode):
    """Test with denormal/subnormal values that might cause precision issues."""
    # Smallest positive normals of the two source dtypes.
    bf16_tiny = torch.finfo(torch.bfloat16).tiny
    f32_tiny = torch.finfo(torch.float32).tiny
    block_size = 32

    denormal_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Rows: scaled bf16 tiny, scaled f32 tiny, negated bf16 tiny,
    # and literal very-small constants.
    row_values = (
        [bf16_tiny, bf16_tiny * 0.5, bf16_tiny * 0.1, bf16_tiny * 2.0],
        [f32_tiny, f32_tiny * 0.5, f32_tiny * 0.1, f32_tiny * 2.0],
        [-bf16_tiny, -bf16_tiny * 0.5, -bf16_tiny * 0.1, -bf16_tiny * 2.0],
        [1e-40, 1e-38, 1e-36, 1e-34],
    )
    for row, vals in enumerate(row_values):
        denormal_vals[row, :4] = torch.tensor(vals, dtype=torch.bfloat16)

    ref_data, ref_scales = triton_to_mxfp8_dim0_reference(
        denormal_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    tri_data, tri_scales = triton_to_mxfp8_dim0(
        denormal_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not tri_data.isnan().any(), "quantized tensor should not contain NaNs"
    assert not tri_scales.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(tri_data, ref_data, rtol=0, atol=0)
    torch.testing.assert_close(tri_scales, ref_scales, rtol=0, atol=0)
925+
926+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_constant_values(scaling_mode):
    """Test with tensors of constant values to check scale calculation edge cases."""
    f8_info = torch.finfo(torch.float8_e4m3fn)
    # Simple +/- pairs, plus the fp8 format's own max and smallest normal.
    test_values = [
        1.0, -1.0,
        0.5, -0.5,
        2.0, -2.0,
        100.0, -100.0,
        0.01, -0.01,
        f8_info.max, -f8_info.max,
        f8_info.tiny, -f8_info.tiny,
    ]

    for val in test_values:
        x = torch.full((64, 128), val, dtype=torch.bfloat16, device="cuda")

        ref_data, ref_scales = triton_to_mxfp8_dim0_reference(
            x, block_size=32, scaling_mode=scaling_mode
        )
        tri_data, tri_scales = triton_to_mxfp8_dim0(
            x,
            inner_block_size=32,
            scaling_mode=scaling_mode.value.lower(),
        )

        assert not tri_data.isnan().any(), (
            f"quantized tensor should not contain NaNs for constant value {val}"
        )
        assert not tri_scales.isnan().any(), (
            f"scales should not contain NaNs for constant value {val}"
        )
        torch.testing.assert_close(tri_data, ref_data, rtol=0, atol=0)
        torch.testing.assert_close(tri_scales, ref_scales, rtol=0, atol=0)
974+
975+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_alternating_signs(scaling_mode):
    """Test with alternating positive/negative patterns that might cause scaling issues."""
    M, K = 64, 128

    # Pattern 1: random magnitudes with every other row flipped negative.
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").abs()
    x[::2] *= -1

    # Pattern 2: random magnitudes under a checkerboard sign mask.
    checkerboard = torch.ones(M, K, dtype=torch.bfloat16, device="cuda")
    checkerboard[::2, ::2] *= -1
    checkerboard[1::2, 1::2] *= -1
    x_checkerboard = (
        torch.randn(M, K, dtype=torch.bfloat16, device="cuda").abs() * checkerboard
    )

    for x_test, name in ((x, "alternating_rows"), (x_checkerboard, "checkerboard")):
        ref_data, ref_scales = triton_to_mxfp8_dim0_reference(
            x_test, block_size=32, scaling_mode=scaling_mode
        )
        tri_data, tri_scales = triton_to_mxfp8_dim0(
            x_test,
            inner_block_size=32,
            scaling_mode=scaling_mode.value.lower(),
        )

        assert not tri_data.isnan().any(), (
            f"quantized tensor should not contain NaNs for {name} pattern"
        )
        assert not tri_scales.isnan().any(), (
            f"scales should not contain NaNs for {name} pattern"
        )
        torch.testing.assert_close(tri_data, ref_data, rtol=0, atol=0)
        torch.testing.assert_close(tri_scales, ref_scales, rtol=0, atol=0)