@@ -625,3 +625,374 @@ def test_cuda_mx_dim0_not_supported():
625625 rowwise = True ,
626626 colwise = False ,
627627 )
628+
629+
630+ # Additional comprehensive tests for triton_to_mxfp8_dim0 to debug NaN issues
631+
632+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_special_values(scaling_mode):
    """Test with special IEEE 754 values that commonly cause NaN issues."""
    block_size = 32
    # One row per family of special values; each row is exactly block_size
    # wide so the kernel sees one scaling block per row. Remaining slots
    # stay zero.
    f32_info = torch.finfo(torch.float32)
    rows = [
        [float("inf"), -float("inf"), float("nan"), 0.0],
        [1.0, -1.0, 2.0, -2.0],
        [1e10, -1e10, 1e-10, -1e-10],
        [f32_info.max, f32_info.min, f32_info.tiny, -f32_info.tiny],
    ]
    special_vals = torch.zeros(
        len(rows), block_size, dtype=torch.bfloat16, device="cuda"
    )
    for row_idx, vals in enumerate(rows):
        special_vals[row_idx, : len(vals)] = torch.tensor(vals, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        special_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        special_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )
    # Widen for comparison; scales are compared as raw uint8 bit patterns.
    x_mx_t = x_mx_t.to(torch.float32)
    x_s_t = x_s_t.to(torch.uint8)
    x_mx_ref = x_mx_ref.to(torch.float32)
    x_s_ref = x_s_ref.to(torch.uint8)

    # NaNs in the output are acceptable only if the input already had NaNs.
    input_has_nan = special_vals.isnan().any()
    if not input_has_nan:
        assert not x_mx_t.isnan().any(), (
            "quantized tensor should not contain NaNs when input has no NaNs"
        )
        assert not x_s_t.isnan().any(), (
            "scales should not contain NaNs when input has no NaNs"
        )

    # Only compare positions where both implementations are finite, so
    # inf/NaN propagation does not mask disagreements elsewhere.
    finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
    if finite_mask.any():
        torch.testing.assert_close(
            x_mx_t[finite_mask], x_mx_ref[finite_mask], rtol=0, atol=0
        )

    scale_finite_mask = torch.isfinite(x_s_ref) & torch.isfinite(x_s_t)
    if scale_finite_mask.any():
        torch.testing.assert_close(
            x_s_t[scale_finite_mask], x_s_ref[scale_finite_mask], rtol=0, atol=0
        )
700+
701+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode):
    """Test with values near overflow and underflow thresholds.

    Quantizes bf16 values just below and above the float8_e4m3fn
    representable range and checks that the triton kernel matches the
    reference implementation exactly and never produces NaNs.
    """
    # Values near float8_e4m3fn limits.
    f8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    # .tiny is the smallest positive *normal* e4m3fn value: 2**-6 = 0.015625.
    # (Previous comment said "~1.95e-06", which is wrong — even the smallest
    # e4m3fn subnormal is 2**-9 ~= 1.95e-03.)
    f8_min = torch.finfo(torch.float8_e4m3fn).tiny
    block_size = 32

    overflow_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Fill the first few elements of each row with near-limit values; the
    # rest stays zero so each row is exactly one scaling block.
    overflow_vals[0, :4] = torch.tensor(
        [f8_max * 0.9, f8_max * 1.1, f8_max * 2.0, f8_max * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[1, :4] = torch.tensor(
        [-f8_max * 0.9, -f8_max * 1.1, -f8_max * 2.0, -f8_max * 10.0],
        dtype=torch.bfloat16,
    )
    overflow_vals[2, :4] = torch.tensor(
        [f8_min * 0.1, f8_min * 0.5, f8_min * 2.0, f8_min * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[3, :4] = torch.tensor(
        [-f8_min * 0.1, -f8_min * 0.5, -f8_min * 2.0, -f8_min * 10.0],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        overflow_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        overflow_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    # Quantization is deterministic, so require bit-exact agreement.
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
748+
749+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_extreme_range(scaling_mode):
    """Test with tensors containing both very large and very small values."""
    block_size = 32
    f32_info = torch.finfo(torch.float32)
    # Mixing extremes in the same tensor stresses the per-block scale choice.
    rows = [
        [1e30, 1e-30, 1e20, 1e-20],
        [-1e30, -1e-30, -1e20, -1e-20],
        [f32_info.max, f32_info.tiny, 1.0, -1.0],
        [0.0, 1e-40, 1e40, -1e40],
    ]
    extreme_vals = torch.zeros(
        len(rows), block_size, dtype=torch.bfloat16, device="cuda"
    )
    for row_idx, vals in enumerate(rows):
        extreme_vals[row_idx, : len(vals)] = torch.tensor(vals, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        extreme_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        extreme_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    # Bit-exact comparison against the reference implementation.
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
788+
789+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("block_size", (1, 2, 4, 8, 16, 32, 64))
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_edge_block_sizes(block_size, scaling_mode):
    """Test with various block sizes that might expose edge cases."""
    # Dimensions are multiples of block_size (all candidates are powers of
    # two <= 64), so the padding path is deliberately not exercised here.
    n_rows = max(64, block_size * 2)
    n_cols = max(64, block_size * 4)

    x = torch.randn(n_rows, n_cols, dtype=torch.bfloat16, device="cuda")

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), (
        f"quantized tensor should not contain NaNs with block_size={block_size}"
    )
    assert not x_s_t.isnan().any(), (
        f"scales should not contain NaNs with block_size={block_size}"
    )
    # Bit-exact comparison against the reference implementation.
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
824+
825+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "shape", [(1, 32), (32, 1), (1, 1), (7, 13), (31, 17), (33, 31)]
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_odd_shapes(shape, scaling_mode):
    """Test with odd tensor shapes that might not align well with block sizes."""
    n_rows, n_cols = shape
    x = torch.randn(n_rows, n_cols, dtype=torch.bfloat16, device="cuda")
    # Shrink the block for tensors narrower than the default 32-wide block.
    block_size = min(32, n_cols)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        x, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        x,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), (
        f"quantized tensor should not contain NaNs with shape={shape}"
    )
    assert not x_s_t.isnan().any(), f"scales should not contain NaNs with shape={shape}"
    # Bit-exact comparison against the reference implementation.
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
858+
859+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode):
    """Test with denormal/subnormal values that might cause precision issues."""
    bf16_tiny = torch.finfo(torch.bfloat16).tiny
    f32_tiny = torch.finfo(torch.float32).tiny
    block_size = 32

    # One row per denormal family; each row is a single scaling block with
    # four probe values and zeros elsewhere.
    rows = [
        [bf16_tiny, bf16_tiny * 0.5, bf16_tiny * 0.1, bf16_tiny * 2.0],
        [f32_tiny, f32_tiny * 0.5, f32_tiny * 0.1, f32_tiny * 2.0],
        [-bf16_tiny, -bf16_tiny * 0.5, -bf16_tiny * 0.1, -bf16_tiny * 2.0],
        [1e-40, 1e-38, 1e-36, 1e-34],  # very small magnitudes
    ]
    denormal_vals = torch.zeros(
        len(rows), block_size, dtype=torch.bfloat16, device="cuda"
    )
    for row_idx, vals in enumerate(rows):
        denormal_vals[row_idx, : len(vals)] = torch.tensor(vals, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        denormal_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        denormal_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    # Bit-exact comparison against the reference implementation.
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
906+
907+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_constant_values(scaling_mode):
    """Test with tensors of constant values to check scale calculation edge cases."""
    f8_info = torch.finfo(torch.float8_e4m3fn)
    # Plain constants plus the fp8 range endpoints, each with both signs.
    test_values = [
        1.0, -1.0,
        0.5, -0.5,
        2.0, -2.0,
        100.0, -100.0,
        0.01, -0.01,
        f8_info.max, -f8_info.max,
        f8_info.tiny, -f8_info.tiny,
    ]

    for val in test_values:
        x = torch.full((64, 128), val, dtype=torch.bfloat16, device="cuda")

        x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
            x, block_size=32, scaling_mode=scaling_mode
        )
        x_mx_t, x_s_t = triton_to_mxfp8_dim0(
            x,
            inner_block_size=32,
            scaling_mode=scaling_mode.value.lower(),
        )

        assert not x_mx_t.isnan().any(), (
            f"quantized tensor should not contain NaNs for constant value {val}"
        )
        assert not x_s_t.isnan().any(), (
            f"scales should not contain NaNs for constant value {val}"
        )
        # Bit-exact comparison against the reference implementation.
        torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
955+
956+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_alternating_signs(scaling_mode):
    """Test with alternating positive/negative patterns that might cause scaling issues."""
    n_rows, n_cols = 64, 128

    # Pattern 1: every other row is negated.
    row_alternating = torch.randn(
        n_rows, n_cols, dtype=torch.bfloat16, device="cuda"
    ).abs()
    row_alternating[::2] *= -1

    # Pattern 2: checkerboard of signs.
    sign_grid = torch.ones(n_rows, n_cols, dtype=torch.bfloat16, device="cuda")
    sign_grid[::2, ::2] *= -1
    sign_grid[1::2, 1::2] *= -1
    x_checkerboard = (
        torch.randn(n_rows, n_cols, dtype=torch.bfloat16, device="cuda").abs()
        * sign_grid
    )

    cases = [(row_alternating, "alternating_rows"), (x_checkerboard, "checkerboard")]
    for x_test, name in cases:
        x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
            x_test, block_size=32, scaling_mode=scaling_mode
        )
        x_mx_t, x_s_t = triton_to_mxfp8_dim0(
            x_test,
            inner_block_size=32,
            scaling_mode=scaling_mode.value.lower(),
        )

        assert not x_mx_t.isnan().any(), (
            f"quantized tensor should not contain NaNs for {name} pattern"
        )
        assert not x_s_t.isnan().any(), (
            f"scales should not contain NaNs for {name} pattern"
        )
        # Bit-exact comparison against the reference implementation.
        torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
0 commit comments