@@ -472,9 +472,7 @@ def test_triton_mxfp8_dim1_randn(M, K, scaling_mode):
472472)
473473@pytest .mark .parametrize ("M" , (128 , 256 ))
474474@pytest .mark .parametrize ("K" , (128 , 256 ))
475- @pytest .mark .parametrize (
476- "scaling_mode" , (ScaleCalculationMode .FLOOR , ScaleCalculationMode .RCEIL )
477- )
475+ @pytest .mark .parametrize ("scaling_mode" , (ScaleCalculationMode .RCEIL ,))
478476def test_triton_mxfp8_dim0_randn (M , K , scaling_mode ):
479477 x = torch .randn (M , K , dtype = torch .bfloat16 , device = "cuda" )
480478 x_mx_ref , x_s_ref = triton_to_mxfp8_dim0_reference (
@@ -625,3 +623,215 @@ def test_cuda_mx_dim0_not_supported():
625623 rowwise = True ,
626624 colwise = False ,
627625 )
626+
627+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_special_values(scaling_mode: ScaleCalculationMode):
    """Quantize a tensor containing IEEE special values (inf, -inf, nan, 0.0)
    and float32 extremes, and check the triton dim0 kernel matches the
    reference implementation bit-exactly (NaN/inf patterns, finite values,
    and e8m0 scales).
    """
    # Create tensor with special values - make it compatible with block_size=32
    block_size = 32
    special_vals = torch.zeros(2, block_size, dtype=torch.bfloat16, device="cuda")

    # Fill first few elements of each row with special values
    special_vals[0, :4] = torch.tensor(
        [float("inf"), -float("inf"), float("nan"), 0.0], dtype=torch.bfloat16
    )
    special_vals[1, :4] = torch.tensor(
        [
            torch.finfo(torch.float32).max,
            torch.finfo(torch.float32).min,
            torch.finfo(torch.float32).tiny,
            -torch.finfo(torch.float32).tiny,
        ],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        special_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        special_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )
    # Convert once up front: fp32 for value comparisons, uint8 for raw
    # e8m0 scale-byte comparisons.
    x_mx_t = x_mx_t.to(torch.float32)
    x_s_t = x_s_t.to(torch.uint8)
    x_mx_ref = x_mx_ref.to(torch.float32)
    x_s_ref = x_s_ref.to(torch.uint8)

    # Check for NaNs in output (allow NaNs if input had NaNs, but check scales)
    input_has_nan = special_vals.isnan().any()
    if not input_has_nan:
        assert not x_mx_t.isnan().any(), (
            "quantized tensor should not contain NaNs when input has no NaNs"
        )
        assert not x_s_t.isnan().any(), (
            "scales should not contain NaNs when input has no NaNs"
        )

    # Use NaN-aware comparison to handle nan != nan case properly
    # Check NaN patterns match
    nan_ref = torch.isnan(x_mx_ref)
    nan_triton = torch.isnan(x_mx_t)
    assert torch.equal(nan_ref, nan_triton), (
        "NaN pattern mismatch between reference and triton"
    )

    # Check finite values
    finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
    if finite_mask.any():
        assert torch.equal(x_mx_ref[finite_mask], x_mx_t[finite_mask]), (
            "Finite values mismatch"
        )

    # Check infinity patterns
    inf_ref = torch.isinf(x_mx_ref)
    inf_triton = torch.isinf(x_mx_t)
    assert torch.equal(inf_ref, inf_triton), (
        "Infinity pattern mismatch between reference and triton"
    )
    if inf_ref.any():
        assert torch.equal(x_mx_ref[inf_ref], x_mx_t[inf_ref]), (
            "Infinity values mismatch"
        )

    # Check scales using exact comparison. Both tensors were already
    # converted to uint8 above; the original re-converted them a second
    # time here, which was redundant.
    assert torch.equal(x_s_t, x_s_ref), (
        "Scale values mismatch between reference and triton"
    )
708+
709+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode):
    """Test with values near overflow and underflow thresholds.

    Quantizes values just below/above the float8_e4m3fn representable
    range and checks the triton dim0 kernel matches the reference
    implementation bit-exactly.
    """
    # Values near float8_e4m3fn limits
    f8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    # NOTE: .tiny is the smallest positive *normal* value, 2^-6 = 0.015625
    # (the original comment claimed ~1.95e-06, which is incorrect).
    f8_min = torch.finfo(torch.float8_e4m3fn).tiny
    block_size = 32

    overflow_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Fill first few elements of each row with overflow/underflow values
    overflow_vals[0, :4] = torch.tensor(
        [f8_max * 0.9, f8_max * 1.1, f8_max * 2.0, f8_max * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[1, :4] = torch.tensor(
        [-f8_max * 0.9, -f8_max * 1.1, -f8_max * 2.0, -f8_max * 10.0],
        dtype=torch.bfloat16,
    )
    overflow_vals[2, :4] = torch.tensor(
        [f8_min * 0.1, f8_min * 0.5, f8_min * 2.0, f8_min * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[3, :4] = torch.tensor(
        [-f8_min * 0.1, -f8_min * 0.5, -f8_min * 2.0, -f8_min * 10.0],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        overflow_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        overflow_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
754+
755+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_extreme_range(scaling_mode):
    """Test with tensors containing both very large and very small values."""
    # Mixing magnitudes that span many orders of magnitude within one
    # tensor exercises the per-block scale selection edge cases.
    block_size = 32
    f32_info = torch.finfo(torch.float32)
    rows = [
        [1e30, 1e-30, 1e20, 1e-20],
        [-1e30, -1e-30, -1e20, -1e-20],
        [f32_info.max, f32_info.tiny, 1.0, -1.0],
        [0.0, 1e-40, 1e40, -1e40],
    ]

    # Zero-filled tensor; only the first four columns of each row carry
    # the extreme values.
    extreme_vals = torch.zeros(len(rows), block_size, dtype=torch.bfloat16, device="cuda")
    for row_idx, row in enumerate(rows):
        extreme_vals[row_idx, :4] = torch.tensor(row, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        extreme_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        extreme_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
792+
793+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode):
    """Test with denormal/subnormal values that might cause precision issues."""
    # Smallest positive normal values of the two relevant dtypes serve as
    # anchors; multiplying them by fractions pushes into subnormal range.
    bf16_tiny = torch.finfo(torch.bfloat16).tiny
    f32_tiny = torch.finfo(torch.float32).tiny
    block_size = 32
    rows = [
        [bf16_tiny, bf16_tiny * 0.5, bf16_tiny * 0.1, bf16_tiny * 2.0],
        [f32_tiny, f32_tiny * 0.5, f32_tiny * 0.1, f32_tiny * 2.0],
        [-bf16_tiny, -bf16_tiny * 0.5, -bf16_tiny * 0.1, -bf16_tiny * 2.0],
        [1e-40, 1e-38, 1e-36, 1e-34],  # very small magnitudes
    ]

    denormal_vals = torch.zeros(len(rows), block_size, dtype=torch.bfloat16, device="cuda")
    for row_idx, row in enumerate(rows):
        denormal_vals[row_idx, :4] = torch.tensor(row, dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        denormal_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        denormal_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
0 commit comments