[X86] intmm: Use u8s8 when only support avx512-vnni#4103

Merged
Xia-Weiwen merged 14 commits into pytorch:main from cyxlily:u8s8
Mar 30, 2026
Conversation

@cyxlily
Contributor

@cyxlily cyxlily commented Mar 18, 2026

Use u8s8 matmul for intmm on X86 CPU when only avx512-vnni is supported, for better performance.

cyxlily added 2 commits March 17, 2026 18:44
Convert from int8 to uint8 and compute with u8s8
when the platform only supports avx512-vnni.

Signed-off-by: Cui, Lily <lily.cui@intel.com>
Signed-off-by: Cui, Lily <lily.cui@intel.com>
@pytorch-bot

pytorch-bot bot commented Mar 18, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4103

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (8 Unrelated Failures)

As of commit 78170f6 with merge base 1087d59:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 18, 2026
@Xia-Weiwen Xia-Weiwen added the module: not user facing Use this tag if you don't want this PR to show up in release notes label Mar 18, 2026
cyxlily added 2 commits March 17, 2026 23:51
Signed-off-by: Cui, Lily <lily.cui@intel.com>
For platforms with AVX512_VNNI support but without AMX,
we convert to u8s8 to use AVX512_VNNI instructions for better performance.
For other platforms,
s8s8 computation is done with AMX or the reference implementation.

Signed-off-by: Cui, Lily <lily.cui@intel.com>
Contributor

Copilot AI left a comment


Pull request overview

Adds a CPU-specific implementation for int_scaled_matmul that switches to a u8*s8 (uint8 x int8) compute path when AVX512-VNNI is available but AMX tile is not, aiming to improve performance on those CPUs.

Changes:

  • Introduce _int_scaled_matmul_cpu helper that conditionally uses a u8*s8 path with compensation for a’s zero-point shift.
  • Route the CPU code path in int_scaled_matmul through the new helper (and avoid expanding scales1 on CPU).
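The zero-point compensation behind the u8*s8 path can be illustrated with a minimal sketch. This is not the PR's actual kernel (the function name is hypothetical, and the u8*s8 accumulation is emulated in int32 rather than issued as a VNNI GEMM), but it shows the identity the compensation relies on: shifting `a` by +128 maps int8 onto uint8, and `(a + 128) @ b == a @ b + 128 * b.sum(dim=0)`.

```python
import torch

def u8s8_matmul_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Illustrative only: compute a(int8) @ b(int8) via a u8*s8 product.

    a: int8 of shape (M, K); b: int8 of shape (K, N).
    """
    # Zero-point shift: int8 [-128, 127] -> uint8 [0, 255].
    a_u8 = (a.to(torch.int16) + 128).to(torch.uint8)
    # Compensation term: 128 * column sums of b, shape (N,).
    comp = 128 * b.to(torch.int32).sum(dim=0)
    # Emulate the u8*s8 accumulation in int32 for portability; a real
    # kernel would invoke an AVX512-VNNI u8*s8 GEMM here instead.
    acc = (a_u8.to(torch.int32).unsqueeze(-1) * b.to(torch.int32).unsqueeze(0)).sum(dim=1)
    # Undo the shift to recover the s8*s8 result.
    return acc - comp
```

Subtracting `comp` exactly cancels the +128 shift, so the result matches a plain s8*s8 matmul in int32.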


@Xia-Weiwen
Collaborator

@claude review

@claude

This comment was marked as resolved.

cyxlily added 2 commits March 20, 2026 00:20
Signed-off-by: Cui, Lily <lily.cui@intel.com>
Signed-off-by: Cui, Lily <lily.cui@intel.com>
@cyxlily
Contributor Author

cyxlily commented Mar 20, 2026

@claude I have updated the code. Please review again.

@Xia-Weiwen
Collaborator

@claude The code has been updated per your (and Copilot's) comments. Please review again.

@claude

This comment was marked as resolved.

@Xia-Weiwen Xia-Weiwen changed the title Use u8s8 when only support avx512-vnni [X86] intmm: Use u8s8 when only support avx512-vnni Mar 20, 2026
cyxlily added 3 commits March 22, 2026 23:52
Signed-off-by: Cui, Lily <lily.cui@intel.com>
Don't check zero point when SYMMETRIC because zero point
may not be None, but rather a tensor with values of 0.

Signed-off-by: Cui, Lily <lily.cui@intel.com>
Move to another PR.

Signed-off-by: Cui, Lily <lily.cui@intel.com>
assert 1 == scales1.size(1)
assert scales1.is_contiguous()
scales1 = scales1.expand((M, N))
assert scales1.dim() == 2
Contributor

@jerryzh168 jerryzh168 Mar 25, 2026


Should the assert be moved as well? Previously this was after the expand.

Contributor Author


The CPU path also needs to check that scales1.dim() == 2; otherwise, calculation with a 3-D scales1 would also be incorrect. And if scales1 is 1-D, assert 1 == scales1.size(1) raises an error, so I think it's better to move the scales1.dim() == 2 check to the beginning.
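The ordering issue can be seen in a small sketch (check_scales is a hypothetical helper, not the PR's actual code): on a 1-D tensor, scales1.size(1) raises an IndexError rather than failing an assert, so the dimensionality check has to come first.

```python
import torch

def check_scales(scales1: torch.Tensor, M: int, N: int) -> torch.Tensor:
    # Check dimensionality first: on a 1-D tensor, scales1.size(1)
    # would raise IndexError before any assert could fire cleanly.
    assert scales1.dim() == 2
    assert scales1.size(1) == 1
    assert scales1.is_contiguous()
    # Broadcast the per-row scales across the N output columns.
    return scales1.expand((M, N))
```

With the dim check first, a 1-D scales1 fails with a clear AssertionError instead of an IndexError from size(1).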

Tests for the CPU-specific paths inside _int_scaled_matmul_cpu.
Because the u8s8 VNNI branch is gated on runtime CPU feature detection,
CI machines are unlikely to exercise it naturally. We monkeypatch the
two helper functions so each branch can be tested on any machine.
Contributor


This seems a bit confusing. You mean you have to monkeypatch _cpu_is_amx_tile_supported and _cpu_is_vnni_supported to run the test? I thought these have to reflect what the hardware actually supports. What are the flags on the CI machines before the monkeypatch? And did you only change the flags from True to False to test the reference path?

Collaborator


Thanks for reviewing. This is to ensure both paths are tested, since the CI hardware most likely does not have these ISAs. Both paths work on all platforms, but performance is best with the corresponding ISA support.

Contributor


So if the hardware doesn't have instruction support:
_cpu_is_amx_tile_supported = False and _cpu_is_vnni_supported = False

And if you set one of these to True, e.g. _cpu_is_amx_tile_supported = True, what happens?

Can you provide more details on these in each of the tests?

Collaborator


Setting these flags in different ways leads to the s8s8 or u8s8 path (according to the rules defined in torchao/kernel/intmm.py). Both paths work on all platforms because torch._int_mm for CPU calls oneDNN under the hood, and oneDNN takes care of everything. This is for testing only, to ensure both paths work as expected in terms of functionality.

Collaborator


Otherwise only one path will be tested in CI. The purpose is to test both paths in CI to ensure they work.
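The flag-flipping scheme described in this thread can be sketched as follows. Here cpu_features and choose_int_mm_path are hypothetical stand-ins for torchao's runtime ISA checks and the dispatch rule in torchao/kernel/intmm.py, not the actual code; the point is that patching the checks lets both branches run on any CI machine.

```python
from unittest import mock

class cpu_features:
    """Stand-in for the runtime ISA checks (hypothetical names)."""
    @staticmethod
    def is_amx_tile_supported() -> bool:
        return False
    @staticmethod
    def is_vnni_supported() -> bool:
        return False

def choose_int_mm_path() -> str:
    # u8s8 only when VNNI is available but AMX tile is not;
    # otherwise s8s8 (AMX or the reference implementation).
    if cpu_features.is_vnni_supported() and not cpu_features.is_amx_tile_supported():
        return "u8s8"
    return "s8s8"

# Force the VNNI-only branch regardless of what the host CPU reports:
with mock.patch.object(cpu_features, "is_vnni_supported", return_value=True):
    assert choose_int_mm_path() == "u8s8"

# Outside the patch, the host's (here: no-ISA) flags apply again.
assert choose_int_mm_path() == "s8s8"
```

Patching only changes which branch is selected; both branches produce correct results everywhere because the underlying torch._int_mm works on all CPUs.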

Signed-off-by: Cui, Lily <lily.cui@intel.com>
@cyxlily cyxlily requested a review from jerryzh168 March 25, 2026 05:17
@cyxlily
Contributor Author

cyxlily commented Mar 25, 2026

@jerryzh168 Could you review again?

scales1 = scales1.expand((M, N))
assert scales1.dim() == 2

if check_cpu_version(scales1.device):
Contributor


check_cpu_version seems too vague, I feel; maybe just something like is_device_type(scales1.device, "cpu")

Collaborator


check_cpu_version seems too vague, I feel; maybe just something like is_device_type(scales1.device, "cpu")

Thanks for the suggestion. However, this utility is not added by this PR. It is defined here:

def check_cpu_version(device, version="2.6.0"):

How about fixing it in another PR?

Contributor


yeah it should be fixed in a separate PR
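For reference, the suggested helper could look like this minimal sketch (is_device_type is hypothetical here, the name proposed in review rather than an existing torchao utility):

```python
import torch

def is_device_type(device, device_type: str) -> bool:
    # Accepts a torch.device or a device string such as "cuda:0" and
    # compares only the device type, ignoring any index.
    return torch.device(device).type == device_type
```

Unlike check_cpu_version, the name states exactly what is checked, and the version comparison could live in a separately named helper.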

"""
M, K = a.shape
K, N = b.shape
assert scales1.dim() == 2
Contributor


Are you sure about checking this before the expand? Does the original op work with a 1-D scale?

Collaborator

@Xia-Weiwen Xia-Weiwen Mar 27, 2026


Are you sure about checking this before the expand? Does the original op work with a 1-D scale?

It does not work with a 1-D scale. The assert 1 == scales1.size(1) below was not added by us; it assumes scales1 is 2-D, but previously that was only checked after the expand. So we think it is better to check it at the beginning.

cyxlily added 2 commits March 26, 2026 19:05
Signed-off-by: Cui, Lily <lily.cui@intel.com>
Signed-off-by: Cui, Lily <lily.cui@intel.com>
@Xia-Weiwen
Collaborator

CI failures are unrelated.

@Xia-Weiwen Xia-Weiwen merged commit d5814ae into pytorch:main Mar 30, 2026
11 of 19 checks passed
@cyxlily cyxlily deleted the u8s8 branch March 30, 2026 06:14