Skip to content

Commit e87b415

Browse files
committed
Support fake-quantized linear with fp32 bias
This supports fake-quantized linear with an fp32 bias.

TICO-DCO-1.0-Signed-off-by: Hyukjin Jeong <hj1.jeong@samsung.com>
1 parent e0ef203 commit e87b415

File tree

6 files changed

+173
-2
lines changed

6 files changed

+173
-2
lines changed

test/modules/op/linear.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
import torch
1616
from torch.export import Dim
17+
from torch.nn import functional as F
18+
19+
from test.utils.tag import test_without_inference
1720

1821

1922
class SimpleLinear(torch.nn.Module):
@@ -68,3 +71,29 @@ def forward(self, arg, attn_mask):
6871

6972
def get_example_inputs(self):
7073
return (torch.randn(3, 3), None)
74+
75+
76+
@test_without_inference
class FQLinearWithFp32Bias(torch.nn.Module):
    """Linear module with fake-quantized input/weight but a plain fp32 bias.

    Input and output are fake-quantized per-tensor, the weight per-channel,
    all with the int16 range [-32768, 32767]; the bias is left in fp32.
    """

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3, 3))
        self.bias = torch.nn.Parameter(torch.ones(3))

    def forward(self, inp):
        # int16 quantization range shared by every fake-quantize call below.
        qlo, qhi = -32768, 32767
        # Per-channel parameters for the weight (channel axis 0).
        ch_scale = torch.ones(3)
        ch_zero_point = torch.zeros(3)

        q_inp = torch.fake_quantize_per_tensor_affine(inp, 1.0, 0, qlo, qhi)
        q_weight = torch.fake_quantize_per_channel_affine(
            self.weight, ch_scale, ch_zero_point, 0, qlo, qhi
        )
        # Bias intentionally stays fp32 here.
        out = F.linear(q_inp, q_weight, bias=self.bias)
        return torch.fake_quantize_per_tensor_affine(out, 1.0, 0, qlo, qhi)

    def get_example_inputs(self):
        return (torch.randn(3, 3),)

test/pt2_to_circle_test/builder.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def __init__(self, test_name: str, nnmodule: torch.nn.Module):
4141

4242
# Get tags
4343
self.test_without_pt2: bool = is_tagged(self.nnmodule, "test_without_pt2")
44+
self.test_without_inference: bool = is_tagged(
45+
self.nnmodule, "test_without_inference"
46+
)
4447

4548
# Set tolerance
4649
self.tolerance = {}
@@ -73,11 +76,15 @@ def wrapper(s):
7376
else:
7477

7578
def wrapper(s):
76-
self._run(without_pt2=self.test_without_pt2, dynamic=dynamic)
79+
self._run(
80+
without_pt2=self.test_without_pt2,
81+
dynamic=dynamic,
82+
without_inference=self.test_without_inference,
83+
)
7784

7885
return wrapper
7986

80-
def _run(self, without_pt2=False, dynamic: bool = False):
87+
def _run(self, without_pt2=False, dynamic: bool = False, without_inference=False):
8188
dynamic_shapes = None
8289
if dynamic:
8390
assert hasattr(self.nnmodule, "get_dynamic_shapes")
@@ -120,6 +127,9 @@ def _run(self, without_pt2=False, dynamic: bool = False):
120127

121128
verify_circle(circle_model_path, opt_circle_model_path)
122129

130+
if without_inference:
131+
return
132+
123133
USE_ONERT = os.environ.get("CCEX_RUNTIME") == "onert" or dynamic
124134
if self.use_onert or USE_ONERT:
125135
circle_result = infer_circle(

test/utils/tag.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def __init__(self, *args_, **kwargs_):
4848
return lambda x: x
4949

5050

51+
def test_without_inference(orig_class):
52+
setattr(orig_class, "__tag_test_without_inference", True)
53+
return orig_class
54+
55+
5156
def test_without_pt2(orig_class):
5257
setattr(orig_class, "__tag_test_without_pt2", True)
5358
return orig_class
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import TYPE_CHECKING
16+
17+
if TYPE_CHECKING:
18+
import torch.fx
19+
import copy
20+
21+
import torch
22+
from torch.export import ExportedProgram
23+
24+
from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
25+
from tico.utils import logging
26+
from tico.utils.graph import add_placeholder, get_torch_param_value, is_torch_param
27+
from tico.utils.passes import PassBase, PassResult
28+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
29+
from tico.utils.validate_args_kwargs import LinearArgs
30+
31+
32+
@trace_graph_diff_on_pass
class QuantizeBias(PassBase):
    """
    Quantize bias.

    This pass identifies fp32 biases, quantizes them using scales of input and weights.

    This pass assumes that if bias is fp32, input and weights must have been quantized.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Replace fp32 linear biases with integer-quantized placeholders.

        For each ``aten.linear`` node whose input and weight carry quantization
        parameters (``QPARAM_KEY``) but whose bias is still fp32, compute
        ``bias_scale = input_scale * weight_scale`` (per output channel),
        round/clamp the bias into the target integer dtype, and swap the bias
        argument for a new quantized placeholder.

        Returns:
            PassResult(False) — the pass is designed to run only once.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph: torch.fx.Graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == torch.ops.aten.linear.default:
                lin_args = LinearArgs(*node.args, **node.kwargs)
                inp = lin_args.input
                weights = lin_args.weight
                bias = lin_args.bias

                if bias is None:
                    continue

                # Only support bias is Parameter
                # TODO Is it possible that bias is not Parameter?
                if not is_torch_param(bias, exported_program):
                    continue

                bias_val: torch.Tensor = get_torch_param_value(bias, exported_program)
                # Already-quantized (non-fp32) biases need no work.
                if bias_val.dtype != torch.float32:
                    continue

                # Both producers must carry quantization parameters; otherwise
                # there are no scales to derive the bias scale from.
                if QPARAM_KEY not in inp.meta:
                    continue
                if QPARAM_KEY not in weights.meta:
                    continue

                # Map the activation dtype to the integer dtype used for the
                # bias. Unsupported activation dtypes are skipped.
                # FIX: the original pre-initialized quant_dtype to None and
                # asserted it non-None only AFTER torch.iinfo() had already
                # used it — a dead assert; the else-continue makes both
                # the init and the assert unnecessary.
                if inp.meta[QPARAM_KEY].dtype == "int16":
                    quant_dtype = torch.int64
                elif inp.meta[QPARAM_KEY].dtype == "uint8":
                    quant_dtype = torch.int32
                else:
                    continue

                type_info = torch.iinfo(quant_dtype)

                i_scale = inp.meta[QPARAM_KEY].scale
                w_scale = weights.meta[QPARAM_KEY].scale

                assert i_scale is not None
                assert w_scale is not None
                # Per-tensor input scale, per-channel weight scale (one scale
                # per output channel of the bias).
                assert len(i_scale) == 1
                assert len(w_scale) == bias_val.shape[0]

                # Broadcast: (1,) * (C,) -> per-channel bias scale.
                bias_scale = torch.tensor(i_scale) * torch.tensor(w_scale)
                q_bias = torch.round(bias_val / bias_scale)
                q_bias = torch.clamp(q_bias, min=type_info.min, max=type_info.max)
                q_bias = q_bias.to(quant_dtype)

                q_bias_node = add_placeholder(exported_program, q_bias, bias.name)

                # Attach quantization parameters so later passes/serialization
                # can emit the correct circle tensor.
                qparam = QuantParam()
                qparam.scale = bias_scale.tolist()
                qparam.zero_point = [0] * len(qparam.scale)
                qparam.dtype = to_qparam_dtype(quant_dtype)
                qparam.quantized_dimension = 0
                q_bias_node.meta[QPARAM_KEY] = qparam

                node.update_arg(2, q_bias_node)

                logger.debug(f"Bias ({bias.name}) is quantized to {q_bias_node.name}.")

            # TODO Support more ops.

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        # Run only once.
        return PassResult(False)

tico/experimental/quantization/passes/remove_weight_dequant_op.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
145145
if isinstance(dq_args, DequantizePerChannelArgs):
146146
scales = get_constant(exported_program, dq_args.scales)
147147
zero_ps = get_constant(exported_program, dq_args.zero_points)
148+
149+
# Sometimes users can give fp32 zero point. Let's update dtype here.
150+
zero_ps = zero_ps.to(torch.int64)
148151
quant_param.scale = scales.tolist()
149152
quant_param.zero_point = zero_ps.tolist()
150153
assert quant_param.zero_point is not None # To avoid mypy error

tico/utils/convert.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from tico.experimental.quantization.passes.propagate_qparam_forward import (
3131
PropagateQParamForward,
3232
)
33+
from tico.experimental.quantization.passes.quantize_bias import QuantizeBias
3334
from tico.experimental.quantization.passes.remove_weight_dequant_op import (
3435
RemoveWeightDequantOp,
3536
)
@@ -250,6 +251,7 @@ def convert_exported_module_to_circle(
250251
RemoveWeightDequantOp(),
251252
PropagateQParamForward(),
252253
PropagateQParamBackward(),
254+
QuantizeBias(),
253255
InsertQuantizeOnDtypeMismatch(),
254256
]
255257
)

0 commit comments

Comments
 (0)