 # LICENSE file in the root directory of this source tree.

 import logging
-from typing import Set, Type
+from collections.abc import Mapping
+from typing import Sequence, Set, Type

 import torch._export.utils
 import torch.fx
@@ -18,6 +19,7 @@
 from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
     FuseEqualPlaceholdersPass,
 )
+from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
 from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
 from executorch.backends.transforms.utils import (
     create_constant_placeholder,
@@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.exported_program = exported_program

+    @staticmethod
+    def _is_tosa_dialect_op(target) -> bool:
+        target_str = str(target)
+        return (
+            "executorch.exir.dialects.backend._ops.tosa." in target_str
+            or "<EdgeOpOverload: tosa." in target_str
+        )
+
+    @staticmethod
+    def _arg_contains_symbolic_shape(arg) -> bool:
+        if isinstance(arg, torch.fx.Node):
+            if meta_has_shape_mark(arg.meta):
+                return True
+            return FuseConstantArgsPass._arg_contains_symbolic_shape(
+                arg.meta.get("val")
+            )
+        if isinstance(arg, torch.SymInt):
+            return True
+        if isinstance(arg, Mapping):
+            return any(
+                FuseConstantArgsPass._arg_contains_symbolic_shape(k)
+                or FuseConstantArgsPass._arg_contains_symbolic_shape(v)
+                for k, v in arg.items()
+            )
+        if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
+            return any(
+                FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg
+            )
+        return False
+
     def _propagate_special_dtype(self, from_nodes, to_node, data):
         """Propagate special dtype meta if it exists."""
         special_dtypes = set()
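The symbolic-shape check added above is a plain recursive containment test over fx nodes and nested containers. Below is a minimal standalone sketch of the same traversal pattern; `FakeSymInt` and `contains_symbolic` are hypothetical names for illustration only, with a sentinel class standing in for `torch.SymInt`:

```python
from collections.abc import Mapping, Sequence


class FakeSymInt:
    """Stand-in for torch.SymInt in this illustration only."""


def contains_symbolic(arg) -> bool:
    if isinstance(arg, FakeSymInt):
        return True
    if isinstance(arg, Mapping):
        # Check both keys and values of dict-like containers.
        return any(
            contains_symbolic(k) or contains_symbolic(v) for k, v in arg.items()
        )
    # Guard against str/bytes: iterating a one-character string yields the
    # same string again, so recursing into it would never terminate.
    if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
        return any(contains_symbolic(v) for v in arg)
    return False


assert contains_symbolic([1, (2, {"dim": FakeSymInt()})])
assert not contains_symbolic(["static", (3, 4)])
```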
@@ -83,21 +115,24 @@ def _fuse_nodes(self, node) -> bool:
         input_nodes = list(node.all_input_nodes)
         qparams = node.meta.get("input_qparams", None)

-        def resolve_arg(arg):
+        def resolve_arg(arg, arg_index=None):
+            qparam = (
+                qparams.get(arg_index) if qparams and arg_index is not None else None
+            )
             if isinstance(arg, torch.fx.Node) and arg in input_nodes:
-                idx = input_nodes.index(arg)
                 t = get_param_tensor(self.exported_program, arg)
-                # Check if qparams exist for this arg
-                if qparams and idx in qparams.keys():
-                    t = qparams[idx].dequantize_value(t)
+                if qparam is not None:
+                    t = qparam.dequantize_value(t)
                 return t
             if isinstance(arg, tuple):
-                return tuple(resolve_arg(x) for x in arg)
+                return tuple(resolve_arg(x, arg_index) for x in arg)
             if isinstance(arg, list):
-                return [resolve_arg(x) for x in arg]
+                return [resolve_arg(x, arg_index) for x in arg]
             return arg

-        new_args = tuple(resolve_arg(a) for a in node.args)
+        new_args = tuple(
+            resolve_arg(arg, arg_index) for arg_index, arg in enumerate(node.args)
+        )
         new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()}

         data = node.target(*new_args, **new_kwargs)
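The refactor threads `arg_index` through `resolve_arg` instead of recomputing `input_nodes.index(arg)`, because `input_qparams` is keyed by top-level positional index: every element of a nested list/tuple argument inherits the index of the argument it belongs to. A rough illustration of that indexing behavior, using a hypothetical `QParam` mock (not the real `input_qparams` entry type):

```python
class QParam:
    """Hypothetical mock of an input_qparams entry, for illustration only."""

    def __init__(self, scale: float):
        self.scale = scale

    def dequantize_value(self, t):
        return t * self.scale


qparams = {0: QParam(scale=0.5)}  # keyed by positional argument index


def resolve(arg, arg_index=None):
    # Same shape as resolve_arg: nested tuple/list elements inherit the
    # index of the top-level argument they belong to.
    qparam = qparams.get(arg_index) if qparams and arg_index is not None else None
    if isinstance(arg, (tuple, list)):
        return type(arg)(resolve(x, arg_index) for x in arg)
    if isinstance(arg, int) and qparam is not None:  # stands in for a const node
        return qparam.dequantize_value(arg)
    return arg


args = ([8, 4], 3)
print(tuple(resolve(a, i) for i, a in enumerate(args)))  # ([4.0, 2.0], 3)
```

Note that `new_kwargs` still calls `resolve_arg` without an index, so keyword arguments never pick up a qparam; the qparam lookup is purely positional, matching the previous behavior.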
@@ -139,13 +174,13 @@ def call(self, graph_module):
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
-            if node.target in [
-                exir_ops.backend.tosa.MATMUL.default,
-                exir_ops.backend.tosa.RESCALE.default,
-                exir_ops.backend.tosa.RESIZE.default,
-                exir_ops.backend.tosa.TABLE.default,
-                exir_ops.backend.tosa.TRANSPOSE.default,
-            ]:
+            # Don't fuse TOSA dialect ops as they do not have eager forward functions.
+            # Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
+            if (
+                self._is_tosa_dialect_op(node.target)
+                or self._arg_contains_symbolic_shape(node.args)
+                or self._arg_contains_symbolic_shape(node.kwargs)
+            ):
                 continue

             input_nodes = node.all_input_nodes
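This hunk replaces the hard-coded allowlist (MATMUL, RESCALE, RESIZE, TABLE, TRANSPOSE) with a repr-based match, so every TOSA dialect op is skipped, including ones added after this pass was written. A standalone sketch of the check; the repr strings in the asserts are assumed shapes for illustration, not captured output:

```python
def is_tosa_dialect_op(target) -> bool:
    # String match on str(target), mirroring the pass: covers both the
    # backend-dialect module path and the EdgeOpOverload repr form.
    target_str = str(target)
    return (
        "executorch.exir.dialects.backend._ops.tosa." in target_str
        or "<EdgeOpOverload: tosa." in target_str
    )


assert is_tosa_dialect_op("<EdgeOpOverload: tosa.RESCALE.default>")
assert not is_tosa_dialect_op("aten.add.Tensor")
```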
@@ -161,7 +196,6 @@ def call(self, graph_module):
             )
             if not all(input_nodes_constant):
                 continue
-
             try:
                 did_fuse = self._fuse_nodes(node)
                 if did_fuse: