Skip to content

Commit e7b38a3

Browse files
authored
Cortex-M backend: Support standalone clamp-type activations (#18767)
- Add support for quantized clamp-type activations in the Cortex-M pipeline by canonicalizing relu/hardtanh/clamp to quantized aten.clamp.default for standalone int8 paths.
- Extend activation fusion to cover max_pool2d.

Reviewers: @freddan80 @per @zingo @oscarandersson8218 @digantdesai @Sebastian-Larsson @AdrianLundell @psiddh

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell

Signed-off-by: Xingguo Li <xingguo.li@arm.com>
1 parent 369f5ca commit e7b38a3

12 files changed

Lines changed: 551 additions & 38 deletions

backends/cortex_m/passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa
99
from .decompose_hardswish_pass import DecomposeHardswishPass # noqa
1010
from .decompose_mean_pass import DecomposeMeanPass # noqa
11+
from .quantized_clamp_activation_pass import QuantizedClampActivationPass # noqa
1112
from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa
1213
from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa
1314
from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip

backends/cortex_m/passes/activation_fusion_pass.py

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -8,7 +8,10 @@
88

99
import executorch.backends.cortex_m.ops.operators # noqa: F401
1010
from executorch.backends.arm._passes.quant_args import QuantArgs
11-
from executorch.backends.cortex_m.passes.passes_utils import quantize_val
11+
from executorch.backends.cortex_m.passes.passes_utils import (
12+
get_activation_bounds,
13+
quantize_val,
14+
)
1215

1316
from executorch.exir.dialects._ops import ops as exir_ops
1417
from executorch.exir.pass_base import ExportPass
@@ -23,7 +26,7 @@ class ActivationFusionPass(ExportPass):
2326
"""Fuse activations into preceding Cortex-M quantized operators.
2427
2528
Supported activation patterns:
26-
q-> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq
29+
q-> [conv2d, linear, max_pool2d] -> [relu, hardtanh, hardsigmoid, clamp] -> dq
2730
2831
Fusing works by clamping the quantized output range (and zero-point when
2932
required) of the preceding Cortex-M operator, then removing the activation
@@ -37,10 +40,17 @@ class ActivationFusionPass(ExportPass):
3740
exir_ops.edge.aten.clamp.default,
3841
}
3942

43+
MAX_POOL_OPS = {
44+
exir_ops.edge.aten.max_pool2d.default,
45+
exir_ops.edge.aten.max_pool2d_with_indices.default,
46+
}
47+
4048
FUSE_OPS = {
4149
exir_ops.edge.aten.linear.default,
4250
exir_ops.edge.aten.convolution.default,
4351
exir_ops.edge.aten.add.Tensor,
52+
exir_ops.edge.aten.max_pool2d.default,
53+
exir_ops.edge.aten.max_pool2d_with_indices.default,
4454
}
4555

4656
def _get_validated_qparams(self, node, input_node):
@@ -63,30 +73,38 @@ def _get_validated_qparams(self, node, input_node):
6373
)
6474
return None
6575

66-
match node.target:
67-
case exir_ops.edge.aten.relu.default:
68-
quantized_min_val = quantize_val(0, scale, zp, qmin, qmax)
69-
quantized_max_val = qmax
70-
case exir_ops.edge.aten.hardtanh.default:
71-
quantized_min_val = quantize_val(node.args[1], scale, zp, qmin, qmax)
72-
quantized_max_val = quantize_val(node.args[2], scale, zp, qmin, qmax)
73-
case exir_ops.edge.aten.hardsigmoid.default:
74-
quantized_min_val = quantize_val(0, scale, zp, qmin, qmax)
75-
quantized_max_val = quantize_val(1, scale, zp, qmin, qmax)
76-
case exir_ops.edge.aten.clamp.default:
77-
quantized_min_val = (
78-
quantize_val(node.args[1], scale, zp, qmin, qmax)
79-
if node.args[1] is not None
80-
else qmin
81-
)
82-
# Last arg is removed if none, so check length of args here
83-
quantized_max_val = (
84-
quantize_val(node.args[2], scale, zp, qmin, qmax)
85-
if len(node.args) == 3
86-
else qmax
76+
bounds = get_activation_bounds(node)
77+
if bounds is None:
78+
logger.warning(
79+
"Cannot fuse activation %s because bounds are not compile-time scalars.",
80+
node.name,
81+
)
82+
return None
83+
min_val, max_val = bounds
84+
85+
quantized_min_val = (
86+
quantize_val(min_val, scale, zp, qmin, qmax)
87+
if min_val is not None
88+
else qmin
89+
)
90+
quantized_max_val = (
91+
quantize_val(max_val, scale, zp, qmin, qmax)
92+
if max_val is not None
93+
else qmax
94+
)
95+
96+
if input_node.target in self.MAX_POOL_OPS:
97+
if node.target == exir_ops.edge.aten.hardsigmoid.default:
98+
logger.warning(
99+
"Cannot fuse hardsigmoid %s after max_pool2d because max_pool2d requires matching input/output qparams.",
100+
node.name,
87101
)
88-
case _:
89-
raise RuntimeError(f"Unexpected target {node.target}.")
102+
return None
103+
# Max-pool keeps scale and zero-point unchanged and lowers fused
104+
# activation bounds separately, so only qmin/qmax need updating here.
105+
qparams_dict["qmin"] = int(quantized_min_val)
106+
qparams_dict["qmax"] = int(quantized_max_val)
107+
return qparams_dict
90108

91109
# If the minimal quantized value is larger than the qmin, it means that the quantized range contains
92110
# invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters.

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .convert_to_cortex_m_pass import ConvertToCortexMPass
2929
from .decompose_hardswish_pass import DecomposeHardswishPass
3030
from .decompose_mean_pass import DecomposeMeanPass
31+
from .quantized_clamp_activation_pass import QuantizedClampActivationPass
3132
from .quantized_op_fusion_pass import QuantizedOpFusionPass
3233
from .replace_quant_nodes_pass import ReplaceQuantNodesPass
3334

@@ -42,6 +43,7 @@ class CortexMPassManager(PassManager):
4243
ReplaceScalarWithTensorArgPass,
4344
ReplaceQuantNodesPass,
4445
ActivationFusionPass,
46+
QuantizedClampActivationPass,
4547
DecomposeHardswishPass,
4648
QuantizedOpFusionPass,
4749
ConvertToCortexMPass,

backends/cortex_m/passes/passes_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import math
9+
from typing import Any
910

1011
import torch
1112

@@ -21,6 +22,56 @@ def quantize_val(val, scale, zp, qmin, qmax):
2122
return float(min(max(torch.round(torch.Tensor([val / scale + zp])), qmin), qmax))
2223

2324

25+
def extract_constant_scalar(arg: Any) -> float | None:
26+
if arg is None:
27+
return None
28+
if isinstance(arg, (int, float)):
29+
return float(arg)
30+
if isinstance(arg, Node):
31+
if arg.op == "call_function" and arg.target in {
32+
exir_ops.edge.aten.full_like.default,
33+
exir_ops.edge.aten.full.default,
34+
torch.ops.aten.full_like.default,
35+
torch.ops.aten.full.default,
36+
}:
37+
fill_arg = arg.args[1] if len(arg.args) > 1 else None
38+
return extract_constant_scalar(fill_arg)
39+
val = arg.meta.get("val")
40+
if val is None:
41+
return None
42+
return extract_constant_scalar(val)
43+
return None
44+
45+
46+
def get_activation_bounds(node: Node) -> tuple[float | None, float | None] | None:
47+
bounds: tuple[float | None, float | None]
48+
match node.target:
49+
case exir_ops.edge.aten.relu.default | exir_ops.edge.aten.relu_.default:
50+
bounds = (0.0, None)
51+
case exir_ops.edge.aten.hardsigmoid.default:
52+
bounds = (0.0, 1.0)
53+
case exir_ops.edge.aten.hardtanh.default | exir_ops.edge.aten.hardtanh_.default:
54+
bounds = (
55+
extract_constant_scalar(node.args[1]),
56+
extract_constant_scalar(node.args[2]),
57+
)
58+
case exir_ops.edge.aten.clamp.default | exir_ops.edge.aten.clamp.Tensor:
59+
bounds = (
60+
extract_constant_scalar(node.args[1]) if len(node.args) > 1 else None,
61+
extract_constant_scalar(node.args[2]) if len(node.args) > 2 else None,
62+
)
63+
case _:
64+
return None
65+
66+
min_val, max_val = bounds
67+
if len(node.args) > 1 and min_val is None and node.args[1] is not None:
68+
return None
69+
if len(node.args) > 2 and max_val is None and node.args[2] is not None:
70+
return None
71+
72+
return bounds
73+
74+
2475
def dequantize_per_tensor_cmsis(
2576
qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int
2677
) -> torch.Tensor:
backends/cortex_m/passes/quantized_clamp_activation_pass.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
from typing import Any
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
11+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
12+
get_output_qparams,
13+
)
14+
from executorch.backends.cortex_m.passes.passes_utils import (
15+
get_activation_bounds,
16+
quantize_val,
17+
)
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
from executorch.exir.pass_base import ExportPass
20+
from torch.fx import GraphModule, Node
21+
from torch.fx.passes.infra.pass_manager import PassResult
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
class QuantizedClampActivationPass(ExportPass):
27+
"""Canonicalize remaining clamp-like activations on quantized tensors.
28+
29+
This pass runs after activation fusion, so any remaining relu/hardtanh/clamp
30+
still needs to execute in the quantized domain. It rewrites relu and
31+
hardtanh variants to `aten.clamp.default` and quantizes the clamp bounds so
32+
the portable kernel consumes and produces int8 tensors.
33+
"""
34+
35+
TARGETS = {
36+
exir_ops.edge.aten.relu.default,
37+
exir_ops.edge.aten.relu_.default,
38+
exir_ops.edge.aten.hardtanh.default,
39+
exir_ops.edge.aten.hardtanh_.default,
40+
exir_ops.edge.aten.clamp.default,
41+
exir_ops.edge.aten.clamp.Tensor,
42+
}
43+
44+
def _get_quantized_bounds(
45+
self, node: Node, qparams_dict: dict[str, Any]
46+
) -> tuple[int | None, int | None] | None:
47+
qmin = qparams_dict["qmin"]
48+
qmax = qparams_dict["qmax"]
49+
scale = qparams_dict["scale"]
50+
zp = qparams_dict["zp"]
51+
52+
bounds = get_activation_bounds(node)
53+
if bounds is None:
54+
logger.warning(
55+
"Cannot rewrite %s because bounds are not compile-time scalars.",
56+
node.name,
57+
)
58+
return None
59+
min_val, max_val = bounds
60+
61+
quantized_min = (
62+
int(quantize_val(min_val, scale, zp, qmin, qmax))
63+
if min_val is not None
64+
else None
65+
)
66+
quantized_max = (
67+
int(quantize_val(max_val, scale, zp, qmin, qmax))
68+
if max_val is not None
69+
else None
70+
)
71+
return quantized_min, quantized_max
72+
73+
def _is_quantized_int8_activation(self, node: Node) -> bool:
74+
input_node = node.args[0] if len(node.args) > 0 else None
75+
if not isinstance(input_node, Node):
76+
return False
77+
try:
78+
tensor = get_first_fake_tensor(input_node)
79+
except Exception:
80+
return False
81+
if tensor is None or tensor.dtype != torch.int8:
82+
return False
83+
84+
try:
85+
qparams_dict = get_output_qparams(node)[0]._asdict()
86+
except (ValueError, KeyError):
87+
logger.warning(
88+
"Cannot quantize clamp bounds for %s without output qparams.",
89+
node.name,
90+
)
91+
return False
92+
93+
scale = qparams_dict["scale"]
94+
zp = qparams_dict["zp"]
95+
if not isinstance(scale, float) or not isinstance(zp, int):
96+
logger.warning(
97+
"Cannot quantize clamp bounds for %s with non per-tensor qparams.",
98+
node.name,
99+
)
100+
return False
101+
102+
return True
103+
104+
def call(self, graph_module: GraphModule) -> PassResult:
105+
modified = False
106+
107+
for node in list(graph_module.graph.nodes):
108+
if node.op != "call_function" or node.target not in self.TARGETS:
109+
continue
110+
if not self._is_quantized_int8_activation(node):
111+
continue
112+
113+
qparams_dict = get_output_qparams(node)[0]._asdict()
114+
115+
quantized_bounds = self._get_quantized_bounds(node, qparams_dict)
116+
if quantized_bounds is None:
117+
continue
118+
119+
quantized_min, quantized_max = quantized_bounds
120+
node.target = exir_ops.edge.aten.clamp.default
121+
node.args = (node.args[0], quantized_min, quantized_max)
122+
modified = True
123+
124+
if modified:
125+
graph_module = super().call(graph_module).graph_module
126+
graph_module.graph.eliminate_dead_code()
127+
graph_module.recompile()
128+
129+
return PassResult(graph_module, modified)

backends/cortex_m/quantizer/quantization_configs.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
import operator
56
from typing import Any, Callable
67

78
import torch
@@ -86,10 +87,45 @@
8687
torch.ops.aten.max_pool2d_with_indices.default,
8788
}
8889

90+
POOL_FUSED_ACTIVATION_TARGETS = {
91+
torch.ops.aten.relu.default,
92+
torch.ops.aten.relu_.default,
93+
torch.ops.aten.hardtanh.default,
94+
torch.ops.aten.hardtanh_.default,
95+
torch.ops.aten.clamp.default,
96+
torch.ops.aten.clamp_.default,
97+
}
98+
8999

90100
class CortexMQuantizationConfig(QuantizationConfig):
91101
"""Configures quantization, while enforcing cortex-m specific constraints."""
92102

103+
@staticmethod
104+
def _get_shared_pool_input(node: Node | None) -> Node | None:
105+
if node is None or len(node.args) == 0:
106+
return None
107+
108+
input_node = node.args[0]
109+
if not isinstance(input_node, Node):
110+
return None
111+
112+
if input_node.target in POOL_SHARE_OUTPUT_TARGETS:
113+
if len(input_node.args) > 0 and isinstance(input_node.args[0], Node):
114+
return input_node.args[0]
115+
return None
116+
117+
if input_node.target == operator.getitem and len(input_node.args) > 0:
118+
pool_node = input_node.args[0]
119+
if (
120+
isinstance(pool_node, Node)
121+
and pool_node.target in POOL_SHARE_OUTPUT_TARGETS
122+
and len(pool_node.args) > 0
123+
and isinstance(pool_node.args[0], Node)
124+
):
125+
return pool_node.args[0]
126+
127+
return None
128+
93129
def get_input_act_qspec(
94130
self, node: Node | None = None, input_node: Node | None = None
95131
) -> QuantizationSpecBase | None:
@@ -117,6 +153,10 @@ def get_output_act_qspec(
117153
if isinstance(input_node, Node):
118154
return SharedQuantizationSpec((input_node, node))
119155
return super().get_output_act_qspec()
156+
if node is not None and node.target in POOL_FUSED_ACTIVATION_TARGETS:
157+
shared_pool_input = self._get_shared_pool_input(node)
158+
if shared_pool_input is not None:
159+
return SharedQuantizationSpec(shared_pool_input)
120160
return super().get_output_act_qspec()
121161

122162
def get_weight_qspec(self, node: Node | None = None) -> QuantizationSpecBase | None:

0 commit comments

Comments (0)