Skip to content

Commit 969310d

Browse files
committed
[mx] Add missing parameter in mxfp8 kernel
When triton_to_mxfp8_dim1() was called from _to_mxfp8_dim1_kernel_wrapper(), the scale_calculation_mode parameter was not passed, resulting in incorrect default value. Signed-off-by: Ula Golowicz <urszula.golowicz@intel.com>
1 parent 2a8714f commit 969310d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchao/prototype/mx_formats/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ def _to_mxfp8_dim1_kernel_wrapper(
164164
ScaleCalculationMode.FLOOR,
165165
ScaleCalculationMode.RCEIL,
166166
)
167-
a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
167+
a_data, a_scale = triton_to_mxfp8_dim1(
168+
a, block_size, scale_calculation_mode.value
169+
)
168170
elif cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
169171
assert scale_calculation_mode in (
170172
ScaleCalculationMode.FLOOR,

0 commit comments

Comments
 (0)