
Matmul kernel preference support for Int8Tensor#3558

Open
namgyu-youn wants to merge 22 commits into pytorch:main from namgyu-youn:int8-triton

Conversation

@namgyu-youn
Contributor

@namgyu-youn namgyu-youn commented Dec 31, 2025

Summary:
Add kernel routing support (`kernel_preference`) for `Int8Tensor`: "auto", "pytorch", and "triton".

Motivation:
torch._int_mm (PyTorch's internal INT8 matmul kernel) requires M (batch size) > 16, which breaks CUDA graph capture in serving frameworks like vLLM — https://gist.github.com/vkuzo/5bf389079442bb9851ef315cdcb797b4.

For better vLLM integration and performance, I would like to support an alternative INT8 matmul kernel, the Triton-based scaled_int8_mm. This kernel was implemented by @gau-nernst.

Example:

# Auto routing (default)
config = Int8DynamicActivationInt8WeightConfig(
    kernel_preference="auto"  # default, selected based on setup
)

# Use the PyTorch kernel (torch._int_mm)
config = Int8DynamicActivationInt8WeightConfig(
    kernel_preference="pytorch"
)

# Use custom triton kernel
config = Int8DynamicActivationInt8WeightConfig(
    granularity=PerRow(),  # Note: support PerRow only
    kernel_preference="triton"  # Use Triton kernel
)

# Common quantization flow
quantize_(model, config)

Test plan:

pytest -sv test/quantization/quantize_/workflows/int8/test_int8_tensor.py

@pytorch-bot

pytorch-bot bot commented Dec 31, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3558

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 31, 2025
@namgyu-youn namgyu-youn changed the title feat: add INT8 scaled matmul Triton kernel feat: INT8 scaled matmul Triton kernel Dec 31, 2025
@namgyu-youn
Contributor Author

@pytorchbot label "topic: new feature"

@pytorch-bot pytorch-bot bot added the topic: new feature Use this tag if this PR adds a new feature label Dec 31, 2025
@namgyu-youn namgyu-youn changed the title feat: INT8 scaled matmul Triton kernel [Triton] INT8 scaled matmul kernel Jan 4, 2026
@namgyu-youn
Contributor Author

@pytorchbot label "topic: improvement"

@pytorch-bot pytorch-bot bot added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Jan 4, 2026
@vkuzo
Contributor

vkuzo commented Jan 6, 2026

check out https://github.com/pytorch/ao/blob/main/torchao/prototype/quantized_training/int8_mm.py, maybe we can improve the existing one instead?

@namgyu-youn
Contributor Author

namgyu-youn commented Jan 6, 2026

check out https://github.com/pytorch/ao/blob/main/torchao/prototype/quantized_training/int8_mm.py, maybe we can improve the existing one instead?

Thanks, I didn't know about it before. I updated the PR entirely to use that Triton kernel, since promoting the existing one looks better than adding a new one. Could you look into it?

@namgyu-youn namgyu-youn changed the title [Triton] INT8 scaled matmul kernel [Triton] Promote INT8 scaled matmul kernel Jan 7, 2026
@namgyu-youn namgyu-youn marked this pull request as draft January 18, 2026 07:33
@namgyu-youn namgyu-youn marked this pull request as ready for review January 18, 2026 09:10
@namgyu-youn namgyu-youn changed the title [Triton] Promote INT8 scaled matmul kernel Add support for custom matmul kernel routing in Int8Tensor Jan 18, 2026
@namgyu-youn
Contributor Author

@vkuzo could you please take another look at this updated PR? Also cc @jerryzh168

Non-Tensor Attributes:
granularity: the granularity for quantization (e.g., PerRow(), PerTensor())
act_quant_kwargs: flags for dynamic activation quantization
mm_config: Matmul kernel to use - "pytorch" (default) or "triton"
Contributor

it would be better to follow this design:

kernel_preference (KernelPreference): the preference for quantize, mm etc. kernel to use,

Contributor Author

Actually, I didn't follow the Float8Tensor pattern. The differences can be summarized as:

  • Float8Tensor: (1) default is None, (2) uses mm_config and kernel_preference for kernel selection
  • Int8Tensor: (1) default is 'pytorch' (preserves existing behavior), (2) uses mm_config only

In my opinion, (1) having a default helps with simpler logic, and (2) only one config is needed for kernel selection. Could you please check again?

Contributor Author

Personally, I don't understand why Float8Tensor needs 3 configs to route the kernel.

It seems to work like this: (1) kernel_preference is set by the user, (2) kernel_choice is determined at runtime, and (3) mm_config is derived after that? Is it possible to make this simpler?

Contributor

for what you are adding here, kernel_preference is the existing abstraction we have, so I'm recommending to use that to stay consistent across the codebase. mm_config shouldn't be related to this, and kernel_choice should not matter because it's not in the BC surface.

Contributor Author

@namgyu-youn namgyu-youn Jan 20, 2026

Understood, I will update to use https://github.com/pytorch/ao/blob/main/torchao/quantization/quantize_/common/kernel_preference.py then. Thanks for pointing it out.

Contributor

@vkuzo vkuzo left a comment

let's be consistent with Float8Tensor for this logic

@namgyu-youn namgyu-youn requested a review from vkuzo January 20, 2026 17:08
tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
).to(output_dtype)
y = y_dot_scaled * w_scales.flatten()
else:
Contributor

AUTO should always be supported
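The PyTorch fallback path in the snippet above (int8 matmul, then activation and weight scales applied separately) can be sketched as a plain-Python reference, roughly like the following. The function name and list-based layout are illustrative, not the actual torchao implementation:

```python
def scaled_int8_mm_reference(x_int8, w_int8_t, x_scales, w_scales):
    """Reference for the fallback path: accumulate the int8 matmul in
    plain Python ints (no int8 overflow), then apply the row-wise
    activation scales and per-output-channel weight scales."""
    m, k = len(x_int8), len(x_int8[0])
    n = len(w_int8_t[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            # int32-style accumulation of the raw int8 dot product
            acc = sum(x_int8[i][t] * w_int8_t[t][j] for t in range(k))
            out[i][j] = acc * x_scales[i] * w_scales[j]
    return out
```

This mirrors the `y_dot_scaled * w_scales` structure in the reviewed diff: the activation scale is per row of the input, the weight scale per output column.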

set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
for better performance with this quantization scheme.
version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Int8Tensor
kernel_preference (KernelPreference): Kernel preference for matmul operations. TORCH uses int_scaled_matmul,
Contributor

put AUTO first

granularity: Granularity = PerRow()
set_inductor_config: bool = True
version: int = 1
kernel_preference: KernelPreference = KernelPreference.TORCH
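The reviewer's suggestion — default to AUTO rather than TORCH — amounts to something like the sketch below. The enum member names come from this PR's discussion; `Int8Config` is a hypothetical stand-in for `Int8DynamicActivationInt8WeightConfig`, not the actual class:

```python
from dataclasses import dataclass
from enum import Enum

class KernelPreference(Enum):
    AUTO = "auto"      # let torchao pick based on device/granularity
    TORCH = "torch"    # torch._int_mm path
    TRITON = "triton"  # Triton scaled_int8_mm path

@dataclass
class Int8Config:
    set_inductor_config: bool = True
    version: int = 2
    # AUTO as the default: users get the best available kernel without
    # opting in, and can still pin TORCH/TRITON explicitly.
    kernel_preference: KernelPreference = KernelPreference.AUTO
```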
Contributor

default to AUTO

@namgyu-youn namgyu-youn requested a review from vkuzo January 30, 2026 04:21
@namgyu-youn namgyu-youn changed the title Add support for custom matmul kernel routing in Int8Tensor Matmul kernel routing support for Int8Tensor Jan 30, 2026
@namgyu-youn
Contributor Author

namgyu-youn commented Feb 3, 2026

It seems the CI failure was caused by Triton not being installed. Should I (1) add Triton as a dependency, or (2) make the unit test skip if Triton is not installed?

@namgyu-youn
Contributor Author

Just updated the unit test to skip when Triton is not installed. This should be safer for CI.
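The skip-when-missing pattern can be done entirely with the standard library; the test and class names below are illustrative (the actual tests live in test_int8_tensor.py and may use pytest markers instead):

```python
import importlib.util
import unittest

# Detect Triton at import time without importing it; find_spec returns
# None when the package is not installed.
HAS_TRITON = importlib.util.find_spec("triton") is not None

@unittest.skipIf(not HAS_TRITON, "triton is not installed")
class TestTritonInt8Matmul(unittest.TestCase):
    def test_kernel_preference_triton(self):
        # A real test would quantize a model with kernel_preference="triton"
        # and compare against the torch path.
        pass
```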

@namgyu-youn
Contributor Author

@vkuzo could you please take another look at this?

@namgyu-youn
Contributor Author

namgyu-youn commented Feb 14, 2026

Updated to conditionally import Triton in torchao/kernel/__init__.py; CI should be green now, I believe.

@namgyu-youn
Contributor Author

@vkuzo could you please take another look at this?

@namgyu-youn namgyu-youn changed the title Matmul kernel routing support for Int8Tensor Matmul kernel preference support for Int8Tensor Mar 2, 2026
Comment on lines +318 to +321
# Verify correctness by comparing with reference
output_ref = torch.nn.functional.linear(input_tensor, weight_ref)
sqnr = compute_error(output_ref, output)
self.assertGreater(sqnr, 20, f"SQNR is too low: {sqnr} dB (expected > 20 dB)")
Contributor

I think we should also check numerical consistency between kernel preferences; can you add a new test similar to

def test_kernel_preference_numerical_equivalence(self, granularity, sizes):

Contributor Author

sure, added test_kernel_preference_numerical_equivalence to check numerical consistency.
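The SQNR metric used in these tests (compute_error in the torchao test utilities) is, to my understanding, the ratio of the reference signal's norm to the error's norm in dB. A minimal sketch, assuming that formulation:

```python
import math

def compute_sqnr_db(ref, out):
    """Signal-to-quantization-noise ratio in dB:
    20 * log10(||ref|| / ||ref - out||). Higher is better; the tests
    in this PR assert > 20 dB against the float reference."""
    signal = math.sqrt(sum(r * r for r in ref))
    noise = math.sqrt(sum((r - o) ** 2 for r, o in zip(ref, out)))
    return 20 * math.log10(signal / noise)
```

A numerical-equivalence test between kernel preferences would compute the same quantized matmul with each preference and require a very high SQNR (or exact match) between the two outputs, not just against the float reference.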

kernel_choice = (
"triton" if tmp.device.type == "cuda" and is_rowwise else "torch"
)
elif weight_tensor.kernel_preference == KernelPreference.TRITON:
Contributor

if triton only supports rowwise we should do a check here and error out when user is not using rowwise here I think
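The check suggested here — error out when the user requests the Triton kernel without row-wise granularity — could look roughly like the following. The function name and the string stand-ins for the enum members are assumptions for illustration:

```python
class KernelPreference:  # string stand-ins for torchao's enum members
    TORCH = "torch"
    TRITON = "triton"

def check_triton_requirements(kernel_preference, is_rowwise, device_type):
    """Fail fast when the Triton kernel is requested in an unsupported
    configuration, instead of silently falling back or miscomputing."""
    if kernel_preference != KernelPreference.TRITON:
        return
    if device_type != "cuda":
        raise ValueError("kernel_preference='triton' requires a CUDA device")
    if not is_rowwise:
        raise ValueError(
            "kernel_preference='triton' supports PerRow granularity only"
        )
```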

Contributor

I'm not sure this op has the same numerics as the other int8 mm ops, actually; we should check.

Contributor Author

@namgyu-youn namgyu-youn Mar 24, 2026

Do you mean "numerics", not hardware behavior? Can we check this under test_kernel_preference_numerical_equivalence?

Contributor

yes, we can check numerics with test_kernel_preference_numerical_equivalence

@namgyu-youn namgyu-youn requested a review from jerryzh168 March 24, 2026 22:40
@namgyu-youn
Contributor Author

@pytorchbot label "module: inference"

@pytorch-bot pytorch-bot bot added the module: inference quantize_ api inference flow label Mar 25, 2026
@namgyu-youn
Contributor Author

@jerryzh168 could you please review this PR again?
