[X86] intmm: Use u8s8 when only support avx512-vnni #4103
Changes from all commits: 956bde8, 2c7682e, a875414, feeb805, 889dd92, 9709e79, 31f41d1, 50c30ce, d185af5, 744ed09, 8219fe6, 6462a41, aa61c38, 78170f6
@@ -105,6 +105,65 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:

```python
    return safe_int_mm(a, b)


def _cpu_is_amx_tile_supported() -> bool:
    """
    Safely query AMX tile support, guarding against private API absence.
    torch.cpu._is_amx_tile_supported / torch._C._cpu._is_amx_tile_supported are
    private and may be missing in certain PyTorch builds or versions.
    """
    if hasattr(torch._C._cpu, "_is_amx_tile_supported"):
        return torch._C._cpu._is_amx_tile_supported()
    elif hasattr(torch.cpu, "_is_amx_tile_supported"):
        return torch.cpu._is_amx_tile_supported()
    return False


def _cpu_is_vnni_supported() -> bool:
    """
    Safely query AVX512_VNNI support, guarding against private API absence.
    torch.cpu._is_vnni_supported / torch._C._cpu._is_vnni_supported are
    private and may be missing in certain PyTorch builds or versions.
    """
    if hasattr(torch._C._cpu, "_is_vnni_supported"):
        return torch._C._cpu._is_vnni_supported()
    elif hasattr(torch.cpu, "_is_vnni_supported"):
        return torch.cpu._is_vnni_supported()
    return False
```
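The hasattr-based fallbacks above amount to a generic "probe a possibly-missing private API, default to False" pattern. A minimal pure-Python sketch of that pattern (the `is_feature_supported` helper and the stand-in namespaces are hypothetical, not part of torchao):

```python
from types import SimpleNamespace


def is_feature_supported(ns, name: str) -> bool:
    # Probe a possibly-missing private query function; default to False.
    fn = getattr(ns, name, None)
    return bool(fn()) if callable(fn) else False


# Stand-ins for PyTorch builds with and without the private query.
new_build = SimpleNamespace(_is_vnni_supported=lambda: True)
old_build = SimpleNamespace()  # the query function is absent in this build

print(is_feature_supported(new_build, "_is_vnni_supported"))  # True
print(is_feature_supported(old_build, "_is_vnni_supported"))  # False
```

Defaulting to False is the safe choice here: if the build cannot even report the ISA feature, the dispatch falls back to the path that works everywhere.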
```python
def _int_scaled_matmul_cpu(
    a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
) -> torch.Tensor:
    """
    CPU-optimized path for scaled integer matrix multiplication.
    CPU prefers the decomposed version to leverage the fusion capability of
    Inductor. It takes the u8s8 or s8s8 path based on the hardware's ISA
    support. The selection is for performance only; both paths work
    regardless of ISA support.

    Args:
        a (torch.Tensor): The first matrix to multiply (int8).
        b (torch.Tensor): The second matrix to multiply (int8).
        scales1 (torch.Tensor): The scaling factors, typically shape (M, 1).
            A scalar-like shape (1, 1) is also supported and will broadcast
            across all rows.

    Returns:
        torch.Tensor: The result of the scaled matrix multiplication.
    """
    if (
        not _cpu_is_amx_tile_supported() and _cpu_is_vnni_supported()
    ):
        # u8s8: Convert a to uint8 to use AVX512_VNNI instructions for better
        # performance on platforms with AVX512_VNNI support but without AMX.
        a = (a.to(torch.int32) + 128).to(torch.uint8)
        c = torch._int_mm(a, b)
        comp = b.sum(dim=0, keepdim=True, dtype=torch.int32) * 128
        c.sub_(comp)
        return c.to(scales1.dtype) * scales1
    else:  # s8s8: Computation done with AMX or as the fallback.
        c = torch._int_mm(a, b)
        return c.to(scales1.dtype) * scales1
```
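The u8s8 branch relies on an exact integer identity: shifting the int8 operand `a` by +128 into uint8 range adds `128 * column_sum(b)` to every row of the product, and the `comp` term subtracts that back out. A pure-Python check of the identity (the `matmul` helper is illustrative, not torchao code):

```python
def matmul(a, b):
    # Plain list-of-lists integer matrix multiply.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


a = [[-128, 7], [42, -3]]  # int8-range values
b = [[5, -2], [-9, 4]]     # int8-range values

shifted = [[x + 128 for x in row] for row in a]  # uint8-range operand
comp = [128 * sum(col) for col in zip(*b)]       # 128 * column sums of b
u8s8 = [[v - comp[j] for j, v in enumerate(row)] for row in matmul(shifted, b)]

print(u8s8 == matmul(a, b))  # True: compensation recovers the s8s8 result
```

Since `(a + 128) @ b = a @ b + 128 * column_sum(b)` holds exactly in integer arithmetic, the two paths are numerically identical; only the instruction selection differs.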
```python
def int_scaled_matmul(
    a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
) -> torch.Tensor:
```
@@ -115,6 +174,8 @@ def int_scaled_matmul(

```python
    Args:
        a (torch.Tensor): The first matrix to multiply.
        b (torch.Tensor): The second matrix to multiply.
        scales1 (torch.Tensor): The scaling factors for the rows of the result.
            Expected shape is (M, 1). A scalar-like shape (1, 1) is also
            supported and will broadcast across all rows.

    Returns:
        torch.Tensor: The result of the scaled matrix multiplication.
```
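The (M, 1) and (1, 1) scale conventions can be illustrated with a tiny pure-Python reference (a sketch with a hypothetical `scaled_matmul_ref`, not the torchao kernel):

```python
def scaled_matmul_ref(a, b, scales1):
    # a: M x K, b: K x N (lists of lists); scales1: (M, 1), or (1, 1) broadcast.
    M, N = len(a), len(b[0])
    assert len(scales1) in (M, 1) and len(scales1[0]) == 1
    c = [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(N)]
        for i in range(M)
    ]
    return [
        [c[i][j] * scales1[i if len(scales1) == M else 0][0] for j in range(N)]
        for i in range(M)
    ]


a = [[1, 2], [3, 4]]
b = [[1, 0], [0, 1]]
print(scaled_matmul_ref(a, b, [[2.0], [0.5]]))  # per-row: [[2.0, 4.0], [1.5, 2.0]]
print(scaled_matmul_ref(a, b, [[0.5]]))         # (1, 1): [[0.5, 1.0], [1.5, 2.0]]
```

The scale stays two-dimensional in both cases; only the row count varies, which is why the shape assertions below accept `scales1.numel() == 1` as well as `scales1.size(0) == M`.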
@@ -124,17 +185,15 @@ def int_scaled_matmul(

```diff
     """
     M, K = a.shape
     K, N = b.shape
     assert scales1.dim() == 2
```

> **Contributor:** Are you sure this should be checked before `expand`? Does the original op work with a 1-d scale?
>
> **Collaborator:** It does not work with a 1-d scale.
```diff
     assert M == scales1.size(0) or scales1.numel() == 1
     assert 1 == scales1.size(1)
     assert scales1.is_contiguous()
-    scales1 = scales1.expand((M, N))
-    assert scales1.dim() == 2
 
     if check_cpu_version(scales1.device):
```
> **Collaborator:** Thanks for the suggestion. However, this utility is not added by this PR; it is defined at line 1242 in 9051a2f.
>
> **Contributor:** Yeah, it should be fixed in a separate PR.
```diff
-        # CPU prefers decomposed version of int_scaled_matmul
-        # to leverage the fusion capability of Inductor
-        c = torch._int_mm(a, b)
-        return c.to(scales1.dtype) * scales1
+        return _int_scaled_matmul_cpu(a, b, scales1)
 
+    scales1 = scales1.expand((M, N))
 
     if intmm_triton is not None and AUTOTUNER_ENABLE:
         return torch.ops.torchao.int_scaled_matmul(a, b, scales1)
```
> **Contributor:** This seems a bit confusing. You mean you have to monkeypatch `_cpu_is_amx_tile_supported` and `_cpu_is_vnni_supported` to run the test? I thought these have to reflect what the hardware is doing. What are the flags on the CI machines before the monkeypatch? And did you only change the flag from True to False to test the reference?
>
> **Collaborator:** Thanks for reviewing. This is to ensure both paths are tested, since the CI hardware most likely does not have these ISA features. These paths work on all platforms, while performance is best with certain ISA support.
>
> **Contributor:** So if the hardware doesn't have instruction support, `_cpu_is_amx_tile_supported = False` and `_cpu_is_vnni_supported = False`. If you set one of these to True, e.g. `_cpu_is_amx_tile_supported = True`, what happens? Can you provide more details on these in each of the tests?
>
> **Collaborator:** Setting these flags in different ways leads to the s8s8 or u8s8 path (according to the rules defined in torchao/kernel/intmm.py). Both paths work on all platforms because `torch._int_mm` for CPU calls oneDNN under the hood, and oneDNN takes care of everything. This is for testing only, to ensure both paths work as expected in terms of functionality. Otherwise only one path would be tested in CI; the purpose is to test both paths in CI to ensure they work.
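To make the discussion above concrete: tests can monkeypatch the two ISA-query helpers so CI exercises both dispatch branches regardless of the hardware. A self-contained sketch using a stand-in module (not the real torchao/kernel/intmm.py):

```python
import types
from unittest import mock

# Stand-in for the dispatch flags in torchao/kernel/intmm.py (illustration only).
intmm = types.ModuleType("intmm_stub")
intmm._cpu_is_amx_tile_supported = lambda: False
intmm._cpu_is_vnni_supported = lambda: False


def chosen_path():
    # Mirror of the dispatch rule: u8s8 only when VNNI is present without AMX.
    if not intmm._cpu_is_amx_tile_supported() and intmm._cpu_is_vnni_supported():
        return "u8s8"
    return "s8s8"


print(chosen_path())  # s8s8: neither ISA feature reported

# Force the u8s8 branch in a test, independent of the CI machine:
with mock.patch.object(intmm, "_cpu_is_vnni_supported", lambda: True):
    print(chosen_path())  # u8s8

# AMX present: s8s8 is chosen even if VNNI is also available.
with mock.patch.object(intmm, "_cpu_is_amx_tile_supported", lambda: True), \
        mock.patch.object(intmm, "_cpu_is_vnni_supported", lambda: True):
    print(chosen_path())  # s8s8
```

`mock.patch.object` restores the original attributes when each `with` block exits, so the stub's real flags are untouched between test cases.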