Skip to content

Commit 76ad054

Browse files
committed
Add GPTQ support for block quantization
1 parent 464d000 commit 76ad054

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

src/llmcompressor/modifiers/gptq/gptq_quantize.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,16 @@ def quantize_weight(
226226
altered_qargs,
227227
global_scale=global_scale,
228228
)
229+
elif strategy == QuantizationStrategy.BLOCK:
230+
block_width = quant_args.block_structure[1]
231+
block_column_idx = (i1 + i) // block_width
232+
q = fake_quantize(
233+
q.unsqueeze(1),
234+
scale[:, block_column_idx : block_column_idx + 1],
235+
zero_point[:, block_column_idx : block_column_idx + 1],
236+
quant_args,
237+
global_scale=global_scale,
238+
).squeeze(1)
229239
else:
230240
raise ValueError(
231241
f"Quantization strategy is not supported for GPTQ: {strategy}"
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
3+
4+
from llmcompressor.modifiers.gptq.gptq_quantize import (
5+
make_empty_hessian,
6+
quantize_weight,
7+
)
8+
9+
10+
@torch.no_grad()
def test_quantize_weight_supports_block_strategy():
    """GPTQ's quantize_weight should run end to end with the block strategy."""
    # Small layer whose (5, 7) weight tiles into 2x4 blocks: 3 block rows, 2 block cols.
    linear = torch.nn.Linear(7, 5, bias=False)
    block_args = QuantizationArgs(
        num_bits=8,
        symmetric=True,
        strategy="block",
        block_structure=[2, 4],
    )
    linear.quantization_scheme = QuantizationScheme(
        targets=["Linear"], weights=block_args
    )

    # Shift the empty Hessian by the identity so the GPTQ solve is well conditioned.
    hess = make_empty_hessian(linear)
    hess = hess + torch.eye(hess.shape[0], dtype=hess.dtype, device=hess.device)

    err, qparams = quantize_weight(
        module=linear,
        quant_args=block_args,
        hessian=hess,
        blocksize=3,
    )

    # The quantized weight keeps its shape; scale/zero-point are one entry per block,
    # i.e. (ceil(5/2), ceil(7/4)) == (3, 2).
    assert err >= 0
    assert qparams["weight"].shape == linear.weight.shape
    assert qparams["weight_scale"].shape == (3, 2)
    assert qparams["weight_zero_point"].shape == (3, 2)
    # Block strategy performs no column reordering, so no g_idx is emitted.
    assert "weight_g_idx" not in qparams

0 commit comments

Comments
 (0)