[X86] intmm: Use u8s8 when only support avx512-vnni #4103
Xia-Weiwen merged 14 commits into pytorch:main
Conversation
Convert from int8 to uint8 and compute with u8s8 when platforms only support avx512-vnni. Signed-off-by: Cui, Lily <lily.cui@intel.com>
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4103
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (8 Unrelated Failures) As of commit 78170f6 with merge base 1087d59. BROKEN TRUNK: the following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
For platforms with AVX512_VNNI support but without AMX, we convert to u8s8 to use AVX512_VNNI instructions for better performance. For other platforms, s8s8 computation is done with AMX or a reference implementation. Signed-off-by: Cui, Lily <lily.cui@intel.com>
Pull request overview
Adds a CPU-specific implementation for int_scaled_matmul that switches to a u8*s8 (uint8 x int8) compute path when AVX512-VNNI is available but AMX tile is not, aiming to improve performance on those CPUs.
Changes:
- Introduce an `_int_scaled_matmul_cpu` helper that conditionally uses a u8*s8 path with compensation for `a`'s zero-point shift.
- Route the CPU code path in `int_scaled_matmul` through the new helper (and avoid expanding `scales1` on CPU).
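The zero-point compensation mentioned above can be sketched as follows. This is a minimal, hardware-independent illustration: the function and variable names are hypothetical (not the actual torchao helper), and a plain int32 matmul stands in for `torch._int_mm`. Shifting int8 `a` by +128 puts it in the uint8 range, and the shift is undone with the column sums of `b`, since `(a + 128) @ b = a @ b + 128 * colsum(b)`.

```python
import torch

def u8s8_scaled_matmul_sketch(a: torch.Tensor, b: torch.Tensor,
                              scales1: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the u8s8 path with zero-point compensation."""
    assert a.dtype == torch.int8 and b.dtype == torch.int8
    # Shift a from int8 [-128, 127] into the uint8 range [0, 255].
    a_u8 = a.to(torch.int32) + 128
    # u8 @ s8 accumulated in int32 (the real path would use torch._int_mm
    # so oneDNN can pick a VNNI u8s8 kernel; int32 matmul keeps this runnable
    # on any machine).
    c = a_u8 @ b.to(torch.int32)
    # Compensation: subtract 128 * colsum(b) to undo the +128 shift.
    comp = 128 * b.to(torch.int32).sum(dim=0, keepdim=True)
    return (c - comp).to(torch.float32) * scales1
```

On an AVX512-VNNI machine the real implementation would feed a uint8 LHS to `torch._int_mm`; the sketch only demonstrates that the compensated result equals the direct s8s8 product.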
@claude review
@claude I have updated the code. Please review again.
@claude The code is updated per your (and Copilot's) comments. Please review again.
Don't check zero point when SYMMETRIC because zero point may not be None, but rather a tensor with values of 0. Signed-off-by: Cui, Lily <lily.cui@intel.com>
Move to another PR. Signed-off-by: Cui, Lily <lily.cui@intel.com>
torchao/kernel/intmm.py (Outdated)

    assert 1 == scales1.size(1)
    assert scales1.is_contiguous()
    scales1 = scales1.expand((M, N))
    assert scales1.dim() == 2
Should the assert be moved as well? Previously this was after the expand.
The CPU path also needs to check that scales1.dim() == 2, otherwise calculation with a 3-D scales1 will also be incorrect. And if scales1 is 1-D, `assert 1 == scales1.size(1)` will raise an error, so I think it is better to move the check for scales1.dim() == 2 to the beginning.
    Tests for the CPU-specific paths inside _int_scaled_matmul_cpu.
    Because the u8s8 VNNI branch is gated on runtime CPU feature detection,
    CI machines are unlikely to exercise it naturally. We monkeypatch the
    two helper functions so each branch can be tested on any machine.
This seems a bit confusing. You mean you have to monkeypatch `_cpu_is_amx_tile_supported` and `_cpu_is_vnni_supported` to run the test? I thought these have to reflect what the hardware is doing. What are the flags of the CI machines before monkeypatching? And did you only change the flag from True to False to test the reference path?
Thanks for reviewing. This is to ensure both paths are tested, since the CI hardware most likely does not have these ISAs. Both paths work on all platforms; performance is simply best with the matching ISA support.
So if the hardware doesn't have instruction support:
_cpu_is_amx_tile_supported = False and _cpu_is_vnni_supported = False.
And if you set one of these to True, e.g. _cpu_is_amx_tile_supported = True, what happens?
Can you provide more details on these in each of the tests?
Setting these flags in different ways will lead to s8s8 or u8s8 paths (according to the rules we defined in torchao/kernel/intmm.py). Both paths work on all platforms because torch._int_mm for CPU calls oneDNN under the hood and oneDNN takes care of everything. This is for test only to ensure both paths work as expected in terms of functionality.
Otherwise only one path will be tested in CI. The purpose is to test both paths in CI to ensure they work.
@jerryzh168 Could you review again?
    scales1 = scales1.expand((M, N))
    assert scales1.dim() == 2

    if check_cpu_version(scales1.device):
check_cpu_version seems too vague, I feel; maybe just something like is_device_type(scales1.device, "cpu")
> check_cpu_version seems too vague I feel, maybe just something like is_device_type(scales1.device, "cpu")
Thanks for the suggestion. However, this utility is not added by this PR. It is defined here:
Line 1242 in 9051a2f
Yeah, it should be fixed in a separate PR.
    """
    M, K = a.shape
    K, N = b.shape
    assert scales1.dim() == 2
Are you sure to check this before expand? Does the original op work with a 1-D scale?
> are you sure to check this before expand? does original op work with 1d scale
It does not work with a 1-D scale. The `assert 1 == scales1.size(1)` below is not added by us; it assumes scales1 is 2-D, but previously this was only checked after the expand. So we think it is better to check it at the beginning.
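A minimal sketch of the assert ordering being discussed (illustrative only, not the exact torchao code): checking `dim() == 2` first turns a confusing `IndexError` on 1-D input into a clear assertion failure before `expand` is attempted.

```python
import torch

def check_and_expand_scales(scales1: torch.Tensor, M: int, N: int) -> torch.Tensor:
    # Check rank first: scales1.size(1) would raise IndexError on 1-D input.
    assert scales1.dim() == 2, "scales1 must be 2-D"
    assert scales1.size(1) == 1
    assert scales1.is_contiguous()
    # Broadcast the per-row scales across all N output columns.
    return scales1.expand((M, N))
```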
CI failures are unrelated.
Use u8s8 matmul for intmm on X86 CPUs that only support avx512-vnni (without AMX), for better performance.