
Commit 9c56093

pytorchbot, digantdesai, and Gasoonjia authored
Add Triton INT4 dense kernels with dequant prefill path for Qwen3.5 MoE (#19227)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #19188 by @digantdesai
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/digantdesai/51/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/digantdesai/51/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/digantdesai/50/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/digantdesai/51/orig
@diff-train-skip-merge

---------

Co-authored-by: Digant Desai <digantdesai@meta.com>
Co-authored-by: Gasoonjia <gasoonjia@icloud.com>
1 parent e089ba4 commit 9c56093

6 files changed

Lines changed: 1029 additions & 0 deletions


.ci/scripts/export_model_artifact.sh

Lines changed: 1 addition & 0 deletions
@@ -419,6 +419,7 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
   python -m executorch.examples.models.qwen3_5_moe.export \
     --prequantized "$LOCAL_MODEL_DIR" \
     --output-dir "${OUTPUT_DIR}" \
+    --dense-prefill dequant \
     --moe-activation-dtype int8
   echo "::endgroup::"
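Note: the new --dense-prefill dequant flag selects a dequantize-then-dense-GEMM path for prefill, while decode keeps the fused INT4 kernel; the test file below cross-checks that both paths agree. A minimal sketch of the dispatch this implies, using the kernel signatures from the tests; the wrapper dense_linear_w4 and the M > 1 cutoff are hypothetical, not the export script's actual logic:

# Hypothetical dispatch sketch (not the actual export/runtime code):
# prefill (M > 1) dequantizes INT4 weights to bf16 and runs a dense GEMM;
# decode (M == 1) calls the fused W4A16 kernel directly.
import torch.nn.functional as F

from executorch.backends.cuda.triton.kernels.int4_matmul import (
    dequant_w4_to_bf16,
    int4_matmul,
)


def dense_linear_w4(x, w_packed, w_scale, group_size, dense_prefill="dequant"):
    # x: [M, K] bf16; w_packed: [N, K//2] int8; w_scale: [N, K//group_size] bf16
    if dense_prefill == "dequant" and x.shape[0] > 1:
        w_bf16 = dequant_w4_to_bf16(w_packed, w_scale, group_size)  # [N, K] bf16
        return F.linear(x, w_bf16)
    return int4_matmul(x, w_packed, w_scale, group_size)  # fused W4A16 path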

backends/cuda/tests/test_int4_matmul.py

Lines changed: 273 additions & 0 deletions

@@ -0,0 +1,273 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Functional correctness tests for INT4 matmul and dequant Triton kernels.

Tests both int4_matmul (fused W4A16 GEMM) and dequant_w4_to_bf16 (weight
dequantization) against eager PyTorch references. Uses 0.01 absolute
tolerance to account for INT4 quantization noise and bf16 rounding.

Usage:
    python -m pytest backends/cuda/tests/test_int4_matmul.py -v
"""

import unittest

import torch

from executorch.backends.cuda.triton.kernels.int4_matmul import (
    dequant_w4_to_bf16,
    int4_matmul,
    int4_matvec,
)

ATOL = 0.01
DEVICE = "cuda"


def _quantize_simple(w_bf16, group_size):
    """Quantize [N, K] bf16 weight to simple packed INT4 + per-group scales.

    Returns:
        w_packed: [N, K//2] int8 — two INT4 values per byte
        w_scale: [N, K//group_size] bf16 — symmetric scales
        w_ref: [N, K] bf16 — dequantized reference matching kernel's computation
    """
    N, K = w_bf16.shape
    w = w_bf16.float()
    w_grouped = w.reshape(N, K // group_size, group_size)
    scale = w_grouped.abs().amax(dim=-1, keepdim=True) / 7.0
    scale = scale.clamp(min=1e-10)
    int_data = (w_grouped / scale).round().clamp(-8, 7).to(torch.int8)
    # Kernel dequant: (uint4 - 8) * scale = int_data * scale
    scale_bf16 = scale.to(torch.bfloat16)
    w_ref = ((int_data.float()) * scale_bf16.float()).reshape(N, K).to(torch.bfloat16)
    scale_bf16 = scale_bf16.reshape(N, K // group_size)
    int_data = int_data.reshape(N, K)
    uint4 = (int_data + 8).to(torch.int16)
    packed = (uint4[:, 0::2] | (uint4[:, 1::2] << 4)).to(torch.int8)
    return packed.to(DEVICE), scale_bf16.to(DEVICE), w_ref.to(DEVICE)


def _eager_int4_matmul(x, w_ref):
    """Reference matmul: x @ w_ref.T in float32, cast to bf16."""
    return (x.float() @ w_ref.float().T).to(torch.bfloat16)


class TestDequantW4ToBf16(unittest.TestCase):
    """Tests for dequant_w4_to_bf16 Triton kernel."""

    def _run_dequant(self, N, K, group_size):
        torch.manual_seed(42)
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, w_ref = _quantize_simple(w, group_size)

        out = dequant_w4_to_bf16(packed, scale, group_size)

        self.assertEqual(out.shape, (N, K))
        self.assertEqual(out.dtype, torch.bfloat16)
        max_err = (out.float() - w_ref.float()).abs().max().item()
        self.assertLess(
            max_err, ATOL, f"dequant [{N}x{K}] gs={group_size}: max_err={max_err}"
        )

    def test_square(self):
        self._run_dequant(256, 256, 32)

    def test_tall(self):
        self._run_dequant(2048, 256, 32)

    def test_wide(self):
        self._run_dequant(256, 2048, 128)

    def test_production_qkv(self):
        self._run_dequant(2048, 2048, 128)

    def test_production_shared_expert(self):
        self._run_dequant(1024, 2048, 128)

    def test_group_size_32(self):
        self._run_dequant(512, 512, 32)

    def test_group_size_128(self):
        self._run_dequant(512, 2048, 128)

    def test_non_power_of_two_N(self):
        self._run_dequant(12352, 2048, 128)

    def test_small(self):
        self._run_dequant(16, 64, 32)


class TestInt4Matmul(unittest.TestCase):
    """Tests for int4_matmul Triton kernel (fused W4A16 GEMM)."""

    def _run_matmul(self, M, N, K, group_size):
        torch.manual_seed(42)
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, w_ref = _quantize_simple(w, group_size)
        x = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)

        out = int4_matmul(x, packed, scale, group_size)
        ref = _eager_int4_matmul(x, w_ref)

        self.assertEqual(out.shape, (M, N))
        self.assertEqual(out.dtype, torch.bfloat16)
        self.assertTrue(
            torch.allclose(out.float(), ref.float(), atol=ATOL, rtol=0.01),
            f"int4_matmul M={M} [{N}x{K}] gs={group_size}: "
            f"max_abs_err={(out.float() - ref.float()).abs().max().item():.4f}, "
            f"max_rel_err={((out.float() - ref.float()).abs() / ref.float().abs().clamp(min=1e-6)).max().item():.4f}",
        )

    # --- Decode (M=1) ---
    def test_decode_square(self):
        self._run_matmul(1, 256, 256, 32)

    def test_decode_qkv(self):
        self._run_matmul(1, 2048, 2048, 128)

    def test_decode_kv_proj(self):
        self._run_matmul(1, 256, 2048, 128)

    def test_decode_shared_expert(self):
        self._run_matmul(1, 1024, 2048, 128)

    def test_decode_large_N(self):
        self._run_matmul(1, 12352, 2048, 128)

    # --- Small prefill ---
    def test_prefill_4(self):
        self._run_matmul(4, 2048, 2048, 128)

    def test_prefill_16(self):
        self._run_matmul(16, 2048, 2048, 128)

    def test_prefill_64(self):
        self._run_matmul(64, 2048, 2048, 128)

    # --- Large prefill ---
    def test_prefill_256(self):
        self._run_matmul(256, 2048, 2048, 128)

    def test_prefill_1024(self):
        self._run_matmul(1024, 2048, 2048, 128)

    def test_prefill_4095(self):
        self._run_matmul(4095, 2048, 2048, 128)

    # --- Edge cases ---
    def test_group_size_32(self):
        self._run_matmul(4, 512, 512, 32)

    def test_non_power_of_two_M(self):
        self._run_matmul(7, 256, 256, 32)

    def test_non_power_of_two_N(self):
        self._run_matmul(4, 12352, 2048, 128)

    def test_small(self):
        self._run_matmul(1, 16, 64, 32)


class TestInt4Matvec(unittest.TestCase):
    """Tests for int4_matvec Triton kernel (M=1 decode)."""

    def _run_matvec(self, N, K, group_size):
        torch.manual_seed(42)
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, w_ref = _quantize_simple(w, group_size)
        x = torch.randn(K, dtype=torch.bfloat16, device=DEVICE)

        out = int4_matvec(x.unsqueeze(0), packed, scale, group_size)
        ref = int4_matmul(x.unsqueeze(0), packed, scale, group_size)

        self.assertEqual(out.shape, (1, N))
        self.assertEqual(out.dtype, torch.bfloat16)
        # atol=1.0 for large accumulation across K, rtol=0.01 for relative
        self.assertTrue(
            torch.allclose(out.float(), ref.float(), atol=1.0, rtol=0.01),
            f"int4_matvec [{N}x{K}] gs={group_size}: "
            f"max_err={(out.float() - ref.float()).abs().max().item():.4f}, "
            f"max_rel={((out.float() - ref.float()).abs() / (ref.float().abs().clamp(min=0.1))).max().item():.4f}",
        )

    def test_qkv_proj(self):
        self._run_matvec(2048, 2048, 128)

    def test_kv_proj(self):
        self._run_matvec(256, 2048, 128)

    def test_shared_expert(self):
        self._run_matvec(1024, 2048, 128)

    def test_large_N(self):
        self._run_matvec(12352, 2048, 128)

    def test_group_size_32(self):
        self._run_matvec(512, 512, 32)

    def test_small(self):
        self._run_matvec(16, 64, 32)

    def test_matches_int4_matmul(self):
        """Matvec output matches int4_matmul at M=1."""
        torch.manual_seed(42)
        N, K, gs = 2048, 2048, 128
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, _ = _quantize_simple(w, gs)
        x = torch.randn(1, K, dtype=torch.bfloat16, device=DEVICE)

        out_mv = int4_matvec(x, packed, scale, gs)
        out_mm = int4_matmul(x, packed, scale, gs)

        self.assertTrue(
            torch.allclose(out_mv.float(), out_mm.float(), atol=1.0, rtol=0.01),
            f"matvec vs matmul: max_err={(out_mv.float() - out_mm.float()).abs().max().item():.4f}",
        )


class TestDequantThenMatmul(unittest.TestCase):
    """Tests that dequant + F.linear matches int4_matmul (both paths should agree)."""

    def _run(self, M, N, K, group_size):
        torch.manual_seed(42)
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, w_ref = _quantize_simple(w, group_size)
        x = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)

        # Path A: fused int4_matmul
        out_fused = int4_matmul(x, packed, scale, group_size)

        # Path B: dequant + F.linear
        w_bf16 = dequant_w4_to_bf16(packed, scale, group_size)
        out_dequant = torch.nn.functional.linear(x, w_bf16)

        self.assertTrue(
            torch.allclose(
                out_fused.float(), out_dequant.float(), atol=ATOL, rtol=0.01
            ),
            f"fused vs dequant M={M} [{N}x{K}]: "
            f"max_abs_err={(out_fused.float() - out_dequant.float()).abs().max().item():.4f}",
        )

    def test_decode(self):
        self._run(1, 2048, 2048, 128)

    def test_prefill_short(self):
        self._run(64, 2048, 2048, 128)

    def test_prefill_long(self):
        self._run(1024, 2048, 2048, 128)

    def test_large_N(self):
        self._run(4, 12352, 2048, 128)


if __name__ == "__main__":
    unittest.main()
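For reference, the packing convention _quantize_simple produces (and which dequant_w4_to_bf16 is expected to invert) stores two offset-binary INT4 values per int8 byte: even K indices in the low nibble, odd K indices in the high nibble, with dequantization computing (uint4 - 8) * scale. A small eager round-trip sketch of that layout; the helper _unpack_w4_eager is illustrative, not part of the kernel API:

# Illustrative eager unpacking of the [N, K//2] int8 layout built by
# _quantize_simple; not part of the kernel API.
import torch


def _unpack_w4_eager(packed, scale, group_size):
    b = packed.view(torch.uint8)        # reinterpret bits so >> 4 is a logical shift
    lo = (b & 0xF).to(torch.int16) - 8  # even K columns, back to [-8, 7]
    hi = (b >> 4).to(torch.int16) - 8   # odd K columns
    N, half_k = packed.shape
    int_data = torch.empty(N, half_k * 2, dtype=torch.int16, device=packed.device)
    int_data[:, 0::2] = lo
    int_data[:, 1::2] = hi
    # One scale per group of `group_size` consecutive K columns.
    s = scale.float().repeat_interleave(group_size, dim=1)
    return (int_data.float() * s).to(torch.bfloat16)

Round-tripping the output of _quantize_simple through this helper should reproduce w_ref exactly, since both compute (uint4 - 8) * scale before the final bf16 cast.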

backends/cuda/triton/kernels/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -10,13 +10,20 @@
     fused_moe_batched_gemm,
     moe_align_block_size,
 )
+
+from executorch.backends.cuda.triton.kernels.int4_matmul import (
+    dequant_w4_to_bf16,
+    int4_matvec,
+)
 from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
 from executorch.backends.cuda.triton.kernels.topk import topk
 
 __all__ = [
+    "dequant_w4_to_bf16",
     "fused_moe",
     "fused_moe_batched",
     "fused_moe_batched_gemm",
+    "int4_matvec",
     "moe_align_block_size",
     "sdpa",
     "sdpa_decode_splitk",
