Skip to content
Merged
85 changes: 84 additions & 1 deletion test/kernel/test_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch
from parameterized import parameterized

from torchao.utils import is_sm_at_least_90
from torchao.utils import is_sm_at_least_90, torch_version_at_least

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -96,5 +96,88 @@ def test_int_scaled_mm(self, device, dtype):
torch.testing.assert_allclose(out32_1, out32_2)


class TestIntScaledMatmulCPUPaths(unittest.TestCase):
    """
    Tests for the CPU-specific paths inside ``_int_scaled_matmul_cpu``.

    The u8s8 VNNI branch is gated on runtime CPU-feature detection, so CI
    machines are unlikely to exercise every branch naturally.  Each test
    monkeypatches the two feature-detection helpers to force a specific
    branch.  Both branches are functionally equivalent on every platform
    (``torch._int_mm`` on CPU dispatches to oneDNN, which handles either
    layout); the ISA-based selection is for performance only, so every
    forced path must match the plain ``safe_int_mm`` reference.
    """

    def _make_inputs(self, m=64, k=32, n=16, dtype=torch.bfloat16):
        # Random int8 operands and per-row scales of the requested dtype.
        a = torch.randint(-128, 127, (m, k), dtype=torch.int8)
        b = torch.randint(-128, 127, (k, n), dtype=torch.int8)
        scales = torch.randn(m, 1, dtype=dtype)
        return a, b, scales

    def _reference(self, a, b, scales):
        # Ground truth: plain integer matmul followed by row-wise scaling.
        from torchao.kernel.intmm import safe_int_mm

        return safe_int_mm(a, b).to(scales.dtype) * scales

    def _check_forced_path(self, amx_supported, vnni_supported):
        """Monkeypatch the CPU-feature helpers to the given flags, run
        ``_int_scaled_matmul_cpu``, and verify against the reference.

        The original detection functions are always restored, even when the
        assertion or the kernel call fails.
        """
        import torchao.kernel.intmm as intmm_mod

        a, b, scales = self._make_inputs()
        expected = self._reference(a, b, scales)

        orig_amx = intmm_mod._cpu_is_amx_tile_supported
        orig_vnni = intmm_mod._cpu_is_vnni_supported
        try:
            intmm_mod._cpu_is_amx_tile_supported = lambda: amx_supported
            intmm_mod._cpu_is_vnni_supported = lambda: vnni_supported
            result = intmm_mod._int_scaled_matmul_cpu(a, b, scales)
        finally:
            intmm_mod._cpu_is_amx_tile_supported = orig_amx
            intmm_mod._cpu_is_vnni_supported = orig_vnni

        torch.testing.assert_close(result, expected)

    @unittest.skipIf(not torch_version_at_least("2.12.0.dev"), "Need torch 2.12+")
    def test_vnni_path_via_monkeypatch(self):
        """Force the u8s8 VNNI branch (no AMX, VNNI present) and verify
        against the reference result."""
        self._check_forced_path(amx_supported=False, vnni_supported=True)

    @unittest.skipIf(not torch_version_at_least("2.12.0.dev"), "Need torch 2.12+")
    def test_amx_path_via_monkeypatch(self):
        """Force the s8s8 branch (AMX present, no compensation) and verify
        against the reference result."""
        self._check_forced_path(amx_supported=True, vnni_supported=False)

    @unittest.skipIf(not torch_version_at_least("2.12.0.dev"), "Need torch 2.12+")
    def test_no_simd_path_via_monkeypatch(self):
        """Force the s8s8 fallback branch (neither AMX nor VNNI) and verify
        against the reference result."""
        self._check_forced_path(amx_supported=False, vnni_supported=False)


# Allow running this test module directly (e.g. `python test_autotuner.py`).
if __name__ == "__main__":
    unittest.main()
71 changes: 65 additions & 6 deletions torchao/kernel/intmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,65 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return safe_int_mm(a, b)


def _cpu_is_amx_tile_supported() -> bool:
"""
Safely query AMX tile support, guarding against private API absence.
torch.cpu._is_amx_tile_supported / torch._C._cpu._is_amx_tile_supported are
private and may be missing in certain PyTorch builds or versions.
"""
if hasattr(torch._C._cpu, "_is_amx_tile_supported"):
return torch._C._cpu._is_amx_tile_supported()
elif hasattr(torch.cpu, "_is_amx_tile_supported"):
return torch.cpu._is_amx_tile_supported()
return False


def _cpu_is_vnni_supported() -> bool:
"""
Safely query AVX512_VNNI support, guarding against private API absence.
torch.cpu._is_vnni_supported / torch._C._cpu._is_vnni_supported are
private and may be missing in certain PyTorch builds or versions.
"""
if hasattr(torch._C._cpu, "_is_vnni_supported"):
return torch._C._cpu._is_vnni_supported()
elif hasattr(torch.cpu, "_is_vnni_supported"):
return torch.cpu._is_vnni_supported()
return False


def _int_scaled_matmul_cpu(
    a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
) -> torch.Tensor:
    """
    CPU-optimized path for scaled integer matrix multiplication.

    CPU prefers the decomposed version to leverage the fusion capability of
    Inductor.  Depending on the hardware's ISA support it takes either the
    u8s8 or the s8s8 path; the selection is for performance only and both
    paths produce the same result regardless of ISA support.

    Args:
        a (torch.Tensor): The first matrix to multiply (int8).
        b (torch.Tensor): The second matrix to multiply (int8).
        scales1 (torch.Tensor): The scaling factors, typically shape (M, 1).
            A scalar-like shape (1, 1) is also supported and will broadcast
            across all rows.

    Returns:
        torch.Tensor: The result of the scaled matrix multiplication.
    """
    use_u8s8 = _cpu_is_vnni_supported() and not _cpu_is_amx_tile_supported()
    if not use_u8s8:
        # s8s8: computation done with AMX, or as the generic fallback.
        return torch._int_mm(a, b).to(scales1.dtype) * scales1

    # u8s8: shift `a` into uint8 so AVX512_VNNI instructions can be used on
    # platforms with VNNI support but without AMX.  Since
    # (a + 128) @ b == a @ b + 128 * colsum(b),
    # the compensation term is subtracted afterwards.
    shifted_a = (a.to(torch.int32) + 128).to(torch.uint8)
    acc = torch._int_mm(shifted_a, b)
    compensation = b.sum(dim=0, keepdim=True, dtype=torch.int32) * 128
    acc.sub_(compensation)
    return acc.to(scales1.dtype) * scales1


def int_scaled_matmul(
a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
) -> torch.Tensor:
Expand All @@ -115,6 +174,8 @@ def int_scaled_matmul(
a (torch.Tensor): The first matrix to multiply.
b (torch.Tensor): The second matrix to multiply.
scales1 (torch.Tensor): The scaling factors for the rows of the result.
Expected shape is (M, 1). A scalar-like shape (1, 1) is also
supported and will broadcast across all rows.

Returns:
torch.Tensor: The result of the scaled matrix multiplication.
Expand All @@ -124,17 +185,15 @@ def int_scaled_matmul(
"""
M, K = a.shape
K, N = b.shape
assert scales1.dim() == 2
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you sure you want to check this before expand? does the original op work with a 1d scale?

Copy link
Copy Markdown
Collaborator

@Xia-Weiwen Xia-Weiwen Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you sure you want to check this before expand? does the original op work with a 1d scale?

It does not work with 1d scale. the assert 1 == scales1.size(1) below is not added by us and it assumes scales1 is 2d but previously it was only checked after that. So, we think it is better to check that at the beginning.

assert M == scales1.size(0) or scales1.numel() == 1
assert 1 == scales1.size(1)
assert scales1.is_contiguous()
scales1 = scales1.expand((M, N))
assert scales1.dim() == 2

if check_cpu_version(scales1.device):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_cpu_version seems too vague I feel, maybe just something like is_device_type(scales1.device, "cpu")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_cpu_version seems too vague I feel, maybe just something like is_device_type(scales1.device, "cpu")

Thanks for the suggestion. However, this utility is not added by this PR. It is defined here:

def check_cpu_version(device, version="2.6.0"):
How about fixing it in another PR?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it should be fixed in a separate PR

# CPU prefers decomposed version of int_scaled_matmul
# to leverage the fusion capability of Inductor
c = torch._int_mm(a, b)
return c.to(scales1.dtype) * scales1
return _int_scaled_matmul_cpu(a, b, scales1)

scales1 = scales1.expand((M, N))

if intmm_triton is not None and AUTOTUNER_ENABLE:
return torch.ops.torchao.int_scaled_matmul(a, b, scales1)
Expand Down
Loading