1- # Copyright 2025 Arm Limited and/or its affiliates.
1+ # Copyright 2025-2026 Arm Limited and/or its affiliates.
22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
88
99import executorch .backends .cortex_m .ops .operators # noqa: F401
1010from executorch .backends .arm ._passes .quant_args import QuantArgs
11- from executorch .backends .cortex_m .passes .passes_utils import quantize_val
11+ from executorch .backends .cortex_m .passes .passes_utils import (
12+ get_activation_bounds ,
13+ quantize_val ,
14+ )
1215
1316from executorch .exir .dialects ._ops import ops as exir_ops
1417from executorch .exir .pass_base import ExportPass
@@ -23,7 +26,7 @@ class ActivationFusionPass(ExportPass):
2326 """Fuse activations into preceding Cortex-M quantized operators.
2427
2528 Supported activation patterns:
26- q-> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq
29+ q-> [conv2d, linear, add, max_pool2d] -> [relu, hardtanh, hardsigmoid, clamp] -> dq
2730
2831 Fusing works by clamping the quantized output range (and zero-point when
2932 required) of the preceding Cortex-M operator, then removing the activation
@@ -37,10 +40,17 @@ class ActivationFusionPass(ExportPass):
3740 exir_ops .edge .aten .clamp .default ,
3841 }
3942
43+ MAX_POOL_OPS = {
44+ exir_ops .edge .aten .max_pool2d .default ,
45+ exir_ops .edge .aten .max_pool2d_with_indices .default ,
46+ }
47+
4048 FUSE_OPS = {
4149 exir_ops .edge .aten .linear .default ,
4250 exir_ops .edge .aten .convolution .default ,
4351 exir_ops .edge .aten .add .Tensor ,
52+ exir_ops .edge .aten .max_pool2d .default ,
53+ exir_ops .edge .aten .max_pool2d_with_indices .default ,
4454 }
4555
4656 def _get_validated_qparams (self , node , input_node ):
@@ -63,30 +73,38 @@ def _get_validated_qparams(self, node, input_node):
6373 )
6474 return None
6575
66- match node .target :
67- case exir_ops .edge .aten .relu .default :
68- quantized_min_val = quantize_val (0 , scale , zp , qmin , qmax )
69- quantized_max_val = qmax
70- case exir_ops .edge .aten .hardtanh .default :
71- quantized_min_val = quantize_val (node .args [1 ], scale , zp , qmin , qmax )
72- quantized_max_val = quantize_val (node .args [2 ], scale , zp , qmin , qmax )
73- case exir_ops .edge .aten .hardsigmoid .default :
74- quantized_min_val = quantize_val (0 , scale , zp , qmin , qmax )
75- quantized_max_val = quantize_val (1 , scale , zp , qmin , qmax )
76- case exir_ops .edge .aten .clamp .default :
77- quantized_min_val = (
78- quantize_val (node .args [1 ], scale , zp , qmin , qmax )
79- if node .args [1 ] is not None
80- else qmin
81- )
82- # Last arg is removed if none, so check length of args here
83- quantized_max_val = (
84- quantize_val (node .args [2 ], scale , zp , qmin , qmax )
85- if len (node .args ) == 3
86- else qmax
76+ bounds = get_activation_bounds (node )
77+ if bounds is None :
78+ logger .warning (
79+ "Cannot fuse activation %s because bounds are not compile-time scalars." ,
80+ node .name ,
81+ )
82+ return None
83+ min_val , max_val = bounds
84+
85+ quantized_min_val = (
86+ quantize_val (min_val , scale , zp , qmin , qmax )
87+ if min_val is not None
88+ else qmin
89+ )
90+ quantized_max_val = (
91+ quantize_val (max_val , scale , zp , qmin , qmax )
92+ if max_val is not None
93+ else qmax
94+ )
95+
96+ if input_node .target in self .MAX_POOL_OPS :
97+ if node .target == exir_ops .edge .aten .hardsigmoid .default :
98+ logger .warning (
99+ "Cannot fuse hardsigmoid %s after max_pool2d because max_pool2d requires matching input/output qparams." ,
100+ node .name ,
87101 )
88- case _:
89- raise RuntimeError (f"Unexpected target { node .target } ." )
102+ return None
103+ # Max-pool keeps scale and zero-point unchanged and lowers fused
104+ # activation bounds separately, so only qmin/qmax need updating here.
105+ qparams_dict ["qmin" ] = int (quantized_min_val )
106+ qparams_dict ["qmax" ] = int (quantized_max_val )
107+ return qparams_dict
90108
91109 # If the minimal quantized value is larger than the qmin, it means that the quantized range contains
92110 # invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters.
0 commit comments