Skip to content

Commit 2d995bc

Browse files
committed
Arm backend: Fix quantized constant-folding for aten.cat lists (#18971)
FuseConstantArgsPass resolved input_qparams by flattened input-node index, while FoldAndAnnotateQParamsPass stores them by top-level argument index. For aten.cat with a list-valued tensor argument, this caused only the first tensor to be dequantized before folding, which corrupted the fused constant. Resolve qparams by top-level argument index and propagate that qparam through nested list and tuple arguments. Add a regression test for quantized aten.cat constant folding with list-valued tensor inputs. Signed-off-by: Per Held <per.held@arm.com> Change-Id: I6e1a012d82a5dbeecb403c440a2944953dd5cba7
1 parent e7b38a3 commit 2d995bc

2 files changed

Lines changed: 107 additions & 8 deletions

File tree

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,21 +83,24 @@ def _fuse_nodes(self, node) -> bool:
8383
input_nodes = list(node.all_input_nodes)
8484
qparams = node.meta.get("input_qparams", None)
8585

86-
def resolve_arg(arg):
86+
def resolve_arg(arg, arg_index=None):
87+
qparam = (
88+
qparams.get(arg_index) if qparams and arg_index is not None else None
89+
)
8790
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
88-
idx = input_nodes.index(arg)
8991
t = get_param_tensor(self.exported_program, arg)
90-
# Check if qparams exist for this arg
91-
if qparams and idx in qparams.keys():
92-
t = qparams[idx].dequantize_value(t)
92+
if qparam is not None:
93+
t = qparam.dequantize_value(t)
9394
return t
9495
if isinstance(arg, tuple):
95-
return tuple(resolve_arg(x) for x in arg)
96+
return tuple(resolve_arg(x, arg_index) for x in arg)
9697
if isinstance(arg, list):
97-
return [resolve_arg(x) for x in arg]
98+
return [resolve_arg(x, arg_index) for x in arg]
9899
return arg
99100

100-
new_args = tuple(resolve_arg(a) for a in node.args)
101+
new_args = tuple(
102+
resolve_arg(arg, arg_index) for arg_index, arg in enumerate(node.args)
103+
)
101104
new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()}
102105

103106
data = node.target(*new_args, **new_kwargs)

backends/arm/test/passes/test_fuse_constant_ops_pass.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
ComputeConstantOpsAOTPass,
1212
FuseConstantArgsPass,
1313
)
14+
from executorch.backends.arm._passes.quant_args import QuantArgs
1415
from executorch.backends.arm.test import common
16+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1517
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
18+
from executorch.backends.arm.tosa import TosaSpecification
19+
from executorch.backends.test.harness.stages import StageType
1620

1721
input_t = Tuple[torch.Tensor] # Input x
1822
input_t2 = Tuple[torch.Tensor, torch.Tensor]
@@ -116,6 +120,52 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
116120
return torch.cat((a, b), dim=0)
117121

118122

123+
class QuantizedCatConstantBuffers(torch.nn.Module):
    """Module whose forward concatenates two constant int8 buffers.

    Both buffers have shape (1, 1, 4, 5): ``horizontal_ramp`` varies
    left-to-right and is identical in every row, while ``vertical_ramp``
    varies top-to-bottom and is constant within each row. ``forward``
    takes no inputs, so after quantization the whole graph should fold
    into a single fused constant.
    """

    def __init__(self) -> None:
        super().__init__()
        # One row of the horizontal ramp; all four rows are identical.
        ramp_row = [-95, -32, 32, 95, 0]
        self.register_buffer(
            "horizontal_ramp",
            torch.tensor([[[ramp_row] * 4]], dtype=torch.int8),
        )
        # Vertical ramp: each row holds a single repeated ramp value.
        self.register_buffer(
            "vertical_ramp",
            torch.tensor(
                [[[[value] * 5 for value in (-95, -32, 32, 95)]]],
                dtype=torch.int8,
            ),
        )

    def forward(self) -> torch.Tensor:
        """Return both constant buffers concatenated along dim 1."""
        return torch.cat(
            (
                cast(torch.Tensor, self.horizontal_ramp),
                cast(torch.Tensor, self.vertical_ramp),
            ),
            dim=1,
        )
119169
modules: Dict[str, ModuleWithFuseAttrs] = {
120170
"fuse_parameter": cast(ModuleWithFuseAttrs, FuseParameter()),
121171
"fuse_buffer": cast(ModuleWithFuseAttrs, FuseBuffer()),
@@ -174,3 +224,49 @@ def test_fuse_constant_args_tosa_INT_cat(module: ModuleWithFuseAttrs) -> None:
174224
],
175225
)
176226
pipeline.run()
227+
228+
229+
def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None:
    """Regression test: FuseConstantArgsPass must look up input_qparams by
    top-level argument index so that every tensor inside a list-valued
    aten.cat argument is dequantized before the constant is folded.
    """
    quant_args = QuantArgs(
        scale=1.0 / 127.0,
        zp=0,
        qmin=-127,
        qmax=127,
        dtype=torch.int8,
    )
    module = QuantizedCatConstantBuffers()
    compile_spec = common.get_tosa_compile_spec(
        TosaSpecification.create_from_string("TOSA-1.0+FP")
    )
    tester = ArmTester(module, example_inputs=(), compile_spec=compile_spec)
    tester.export().to_edge()
    exported_program = tester.get_artifact(StageType.TO_EDGE).exported_program()

    # The only call_function node in this graph is the aten.cat.
    graph_nodes = exported_program.graph_module.graph.nodes
    cat_node = next(n for n in graph_nodes if n.op == "call_function")
    # Mimic FoldAndAnnotateQParamsPass: qparams keyed by top-level arg index.
    cat_node.meta["input_qparams"] = {0: quant_args}
    cat_node.meta["output_qparams"] = {0: quant_args}

    pass_result = FuseConstantArgsPass(exported_program).call(
        exported_program.graph_module
    )

    # The whole cat should have been folded into one fused constant.
    assert list(exported_program.state_dict) == ["aten_cat_default_fused_const"]
    expected = torch.cat(
        (
            cast(torch.Tensor, module.horizontal_ramp),
            cast(torch.Tensor, module.vertical_ramp),
        ),
        dim=1,
    )
    torch.testing.assert_close(
        exported_program.state_dict["aten_cat_default_fused_const"], expected
    )
    placeholder_names = [
        node.name
        for node in pass_result.graph_module.graph.nodes
        if node.op == "placeholder"
    ]
    assert placeholder_names == ["aten_cat_default_fused_const"]

0 commit comments

Comments
 (0)