File tree Expand file tree Collapse file tree 2 files changed +47
-0
lines changed
src/llmcompressor/modifiers/gptq
tests/llmcompressor/modifiers/gptq Expand file tree Collapse file tree 2 files changed +47
-0
lines changed Original file line number Diff line number Diff line change @@ -226,6 +226,16 @@ def quantize_weight(
226226 altered_qargs ,
227227 global_scale = global_scale ,
228228 )
229+ elif strategy == QuantizationStrategy .BLOCK :
230+ block_width = quant_args .block_structure [1 ]
231+ block_column_idx = (i1 + i ) // block_width
232+ q = fake_quantize (
233+ q .unsqueeze (1 ),
234+ scale [:, block_column_idx : block_column_idx + 1 ],
235+ zero_point [:, block_column_idx : block_column_idx + 1 ],
236+ quant_args ,
237+ global_scale = global_scale ,
238+ ).squeeze (1 )
229239 else :
230240 raise ValueError (
231241 f"Quantization strategy is not supported for GPTQ: { strategy } "
Original file line number Diff line number Diff line change 1+ import torch
2+ from compressed_tensors .quantization import QuantizationArgs , QuantizationScheme
3+
4+ from llmcompressor .modifiers .gptq .gptq_quantize import (
5+ make_empty_hessian ,
6+ quantize_weight ,
7+ )
8+
9+
@torch.no_grad()
def test_quantize_weight_supports_block_strategy():
    """GPTQ's quantize_weight should accept the block-wise quantization strategy.

    Uses a 5x7 linear weight with [2, 4] blocks, so scales/zero-points come out
    as a (3, 2) grid: ceil(5/2)=3 row blocks by ceil(7/4)=2 column blocks.
    """
    layer = torch.nn.Linear(7, 5, bias=False)
    args = QuantizationArgs(
        num_bits=8,
        symmetric=True,
        strategy="block",
        block_structure=[2, 4],
    )
    layer.quantization_scheme = QuantizationScheme(
        targets=["Linear"], weights=args
    )

    # Add the identity so the Hessian is positive definite and the
    # GPTQ solve is well-conditioned.
    hessian = make_empty_hessian(layer)
    hessian += torch.eye(hessian.shape[0], dtype=hessian.dtype, device=hessian.device)

    loss, quantized_params = quantize_weight(
        module=layer,
        quant_args=args,
        hessian=hessian,
        blocksize=3,
    )

    assert loss >= 0
    assert quantized_params["weight"].shape == layer.weight.shape
    assert quantized_params["weight_scale"].shape == (3, 2)
    assert quantized_params["weight_zero_point"].shape == (3, 2)
    # Block strategy does not produce a group-index tensor.
    assert "weight_g_idx" not in quantized_params
You can’t perform that action at this time.
0 commit comments