Skip to content

Commit e0ef203

Browse files
authored
Introduce create_node helper (#147)
This commit introduces create_node helper for keeping some meta information. TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent 67d5ec1 commit e0ef203

24 files changed

+314
-144
lines changed

tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tico.serialize.quant_param import QPARAM_KEY, QuantParam
2525
from tico.utils import logging
2626
from tico.utils.errors import NotYetSupportedError
27+
from tico.utils.graph import create_node
2728
from tico.utils.passes import PassBase, PassResult
2829
from tico.utils.trace_decorators import trace_graph_diff_on_pass
2930
from tico.utils.utils import quant_min_max, set_new_meta_val
@@ -145,9 +146,11 @@ def _insert_quantize_op_before(node, inp):
145146

146147
with graph.inserting_before(node):
147148
q_args = (inp, scale, zerop, min_, max_, dtype)
148-
quantize = graph.call_function(
149+
quantize = create_node(
150+
graph,
149151
torch.ops.quantized_decomposed.quantize_per_tensor.default,
150152
args=q_args,
153+
origin=node,
151154
)
152155
quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
153156
set_new_meta_val(quantize)
@@ -166,7 +169,8 @@ def _insert_quantize_op_after(node):
166169
dtype = getattr(torch, qparam.dtype)
167170
with graph.inserting_after(node):
168171
q_args = (node, scale, zerop, min_, max_, dtype)
169-
quantize = graph.call_function(
172+
quantize = create_node(
173+
graph,
170174
torch.ops.quantized_decomposed.quantize_per_tensor.default,
171175
args=q_args,
172176
)

tico/passes/cast_aten_where_arg_type.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from tico.serialize.circle_mapping import extract_torch_dtype
2323
from tico.utils import logging
24+
from tico.utils.graph import create_node
2425
from tico.utils.passes import PassBase, PassResult
2526
from tico.utils.trace_decorators import (
2627
trace_const_diff_on_pass,
@@ -158,10 +159,12 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
158159
f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
159160
)
160161
with graph_module.graph.inserting_after(to_cast):
161-
cast = graph_module.graph.call_function(
162+
cast = create_node(
163+
graph,
162164
torch.ops.aten._to_copy.default,
163165
args=(to_cast,),
164166
kwargs={"dtype": dtype_to_cast},
167+
origin=to_cast,
165168
)
166169
# set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
167170
set_new_meta_val(cast)

tico/passes/cast_mixed_type_args.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from tico.serialize.circle_mapping import extract_torch_dtype
2828
from tico.utils import logging
29+
from tico.utils.graph import create_node
2930
from tico.utils.passes import PassBase, PassResult
3031
from tico.utils.trace_decorators import trace_graph_diff_on_pass
3132
from tico.utils.utils import is_target_node, set_new_meta_val
@@ -126,10 +127,12 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
126127

127128
if isinstance(arg_to_promote, torch.fx.Node):
128129
with graph.inserting_after(arg_to_promote):
129-
to_copy = graph.call_function(
130+
to_copy = create_node(
131+
graph,
130132
torch.ops.aten._to_copy.default,
131133
(arg_to_promote,),
132134
{"dtype": type_to_promote},
135+
origin=arg_to_promote,
133136
)
134137
# set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
135138
set_new_meta_val(to_copy)

tico/passes/convert_conv1d_to_conv2d.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tico.serialize.circle_graph import extract_shape
2323
from tico.utils import logging
2424
from tico.utils.errors import NotYetSupportedError
25+
from tico.utils.graph import create_node
2526
from tico.utils.passes import PassBase, PassResult
2627
from tico.utils.trace_decorators import trace_graph_diff_on_pass
2728
from tico.utils.utils import is_target_node
@@ -89,15 +90,19 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo
8990
)
9091

9192
with graph.inserting_after(input):
92-
input_unsqueeze = graph_module.graph.call_function(
93+
input_unsqueeze = create_node(
94+
graph,
9395
torch.ops.aten.unsqueeze.default,
9496
args=(input, 3),
97+
origin=input,
9598
)
9699

97100
with graph.inserting_after(weight):
98-
weight_unsqueeze = graph_module.graph.call_function(
101+
weight_unsqueeze = create_node(
102+
graph,
99103
torch.ops.aten.unsqueeze.default,
100104
args=(weight, 3),
105+
origin=weight,
101106
)
102107

103108
with graph.inserting_before(node):
@@ -106,7 +111,8 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo
106111
elif isinstance(padding, str):
107112
conv2d_op = torch.ops.aten.conv2d.padding
108113

109-
conv2d = graph_module.graph.call_function(
114+
conv2d = create_node(
115+
graph,
110116
conv2d_op,
111117
args=(
112118
input_unsqueeze,
@@ -118,9 +124,11 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo
118124
groups,
119125
),
120126
kwargs=node.kwargs,
127+
origin=node,
121128
)
122129

123-
conv_out_squeeze = graph_module.graph.call_function(
130+
conv_out_squeeze = create_node(
131+
graph,
124132
torch.ops.aten.squeeze.dims,
125133
args=(conv2d, [3]),
126134
)

tico/passes/convert_layout_op_to_reshape.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tico.passes import ops
2323
from tico.serialize.circle_mapping import extract_shape
2424
from tico.utils import logging
25+
from tico.utils.graph import create_node
2526
from tico.utils.passes import PassBase, PassResult
2627
from tico.utils.trace_decorators import trace_graph_diff_on_pass
2728
from tico.utils.validate_args_kwargs import SqueezeArgs, UnSqueezeArgs, ViewArgs
@@ -48,11 +49,11 @@ def convert(node, input):
4849
out_shape = list(extract_shape(node))
4950

5051
with graph.inserting_after(node):
51-
reshape_node = graph.call_function(
52+
reshape_node = create_node(
53+
graph,
5254
torch.ops.aten.reshape.default,
5355
args=(input, out_shape),
5456
)
55-
5657
node.replace_all_uses_with(reshape_node, propagate_meta=True)
5758

5859
logger.debug(f"{node.name} is replaced with {reshape_node.name}")

tico/passes/convert_repeat_to_expand_copy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torch.export import ExportedProgram
2121

2222
from tico.utils import logging
23+
from tico.utils.graph import create_node
2324
from tico.utils.passes import PassBase, PassResult
2425
from tico.utils.trace_decorators import trace_graph_diff_on_pass
2526
from tico.utils.utils import is_target_node
@@ -71,8 +72,10 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
7172
expand_copy_args = (tensor, size)
7273

7374
with graph.inserting_after(node):
74-
expand_copy_node = graph.call_function(
75-
torch.ops.aten.expand_copy.default, args=expand_copy_args
75+
expand_copy_node = create_node(
76+
graph,
77+
torch.ops.aten.expand_copy.default,
78+
args=expand_copy_args,
7679
)
7780
node.replace_all_uses_with(expand_copy_node, propagate_meta=True)
7881

tico/passes/convert_to_relu6.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torch.export import ExportedProgram
2121

2222
from tico.utils import logging
23+
from tico.utils.graph import create_node
2324
from tico.utils.passes import PassBase, PassResult
2425
from tico.utils.trace_decorators import trace_graph_diff_on_pass
2526
from tico.utils.validate_args_kwargs import ClampArgs, HardTanhArgs
@@ -58,7 +59,7 @@ def convert(self, exported_program, node):
5859
input = args.input
5960

6061
with graph.inserting_after(node):
61-
relu_node = graph.call_function(torch.ops.aten.relu6.default, args=(input,))
62+
relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
6263
node.replace_all_uses_with(relu_node, propagate_meta=True)
6364

6465

@@ -84,7 +85,7 @@ def convert(self, exported_program, node):
8485
input = args.input
8586

8687
with graph.inserting_after(node):
87-
relu_node = graph.call_function(torch.ops.aten.relu6.default, args=(input,))
88+
relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
8889
node.replace_all_uses_with(relu_node, propagate_meta=True)
8990

9091

@@ -140,7 +141,7 @@ def convert(self, exported_program, node):
140141
input = prev_args.input
141142

142143
with graph.inserting_after(node):
143-
relu_node = graph.call_function(torch.ops.aten.relu6.default, args=(input,))
144+
relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
144145
node.replace_all_uses_with(relu_node, propagate_meta=True)
145146

146147

tico/passes/decompose_addmm.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from tico.serialize.circle_mapping import extract_shape
2323
from tico.utils import logging
24-
from tico.utils.graph import add_placeholder
24+
from tico.utils.graph import add_placeholder, create_node
2525
from tico.utils.passes import PassBase, PassResult
2626
from tico.utils.trace_decorators import trace_graph_diff_on_pass
2727
from tico.utils.utils import is_target_node, set_new_meta_val
@@ -78,7 +78,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
7878

7979
with graph.inserting_before(node):
8080
# out = beta * input + alpha * (mat1 @ mat2)
81-
matmul = graph.call_function(torch.ops.aten.mm.default, (mat1, mat2))
81+
matmul = create_node(
82+
graph, torch.ops.aten.mm.default, (mat1, mat2), origin=node
83+
)
8284
set_new_meta_val(matmul)
8385

8486
if beta == 1:
@@ -90,7 +92,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
9092
f"{node.name}_beta_zeros",
9193
)
9294
else:
93-
bias = graph.call_function(torch.ops.aten.mul.Tensor, (input, beta))
95+
bias = create_node(
96+
graph, torch.ops.aten.mul.Tensor, (input, beta), origin=node
97+
)
9498

9599
if alpha == 1:
96100
scaled_matmul: torch.fx.Node | torch.Tensor = matmul
@@ -101,12 +105,12 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
101105
f"{node.name}_alpha_zeros",
102106
)
103107
else:
104-
scaled_matmul = graph.call_function(
105-
torch.ops.aten.mul.Tensor, (matmul, alpha)
108+
scaled_matmul = create_node(
109+
graph, torch.ops.aten.mul.Tensor, (matmul, alpha), origin=node
106110
)
107111

108-
result = graph.call_function(
109-
torch.ops.aten.add.Tensor, (bias, scaled_matmul)
112+
result = create_node(
113+
graph, torch.ops.aten.add.Tensor, (bias, scaled_matmul)
110114
)
111115

112116
node.replace_all_uses_with(result, propagate_meta=True)

tico/passes/decompose_batch_norm.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tico.utils.errors import NotYetSupportedError
2525
from tico.utils.graph import (
2626
add_placeholder,
27+
create_node,
2728
get_first_user_input,
2829
get_torch_buffer_value,
2930
get_torch_param_value,
@@ -32,16 +33,10 @@
3233
)
3334
from tico.utils.passes import PassBase, PassResult
3435
from tico.utils.trace_decorators import trace_graph_diff_on_pass
35-
from tico.utils.utils import fill_meta_val, is_target_node
36+
from tico.utils.utils import is_target_node
3637
from tico.utils.validate_args_kwargs import NativeBatchNormLegitNoTrainingArgs
3738

3839

39-
def insert_node(graph: torch.fx.Graph, operation, args):
40-
new_node = graph.call_function(operation, args)
41-
42-
return new_node
43-
44-
4540
@trace_graph_diff_on_pass
4641
class DecomposeBatchNorm(PassBase):
4742
"""
@@ -173,19 +168,20 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
173168
)
174169

175170
with gm.graph.inserting_before(node):
176-
mul = graph.call_function(
171+
mul = create_node(
172+
graph,
177173
torch.ops.aten.mul.Tensor,
178174
args=(input_, mul_const_node),
175+
origin=node,
179176
)
180-
add = graph.call_function(
177+
add = create_node(
178+
graph,
181179
torch.ops.aten.add.Tensor,
182180
args=(mul, add_const_node),
183181
)
184-
# Not set meta for propagating replacing get_item's meta.
185182
get_item, *_ = node.users.keys()
186183
get_item.replace_all_uses_with(add, propagate_meta=True)
187184

188-
fill_meta_val(exported_program)
189185
logger.debug(f"{node.name} is decomposed to {mul.name} and {add.name}")
190186
modified = True
191187

tico/passes/decompose_fake_quantize.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from torch.export import ExportedProgram
2424

2525
from tico.utils import logging
26+
from tico.utils.graph import create_node
2627
from tico.utils.passes import PassBase, PassResult
2728
from tico.utils.trace_decorators import trace_graph_diff_on_pass
2829
from tico.utils.validate_args_kwargs import FakeQuantizePerChannelArgs
@@ -69,6 +70,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
6970
modified = False
7071

7172
gm = exported_program.graph_module
73+
g = gm.graph
7274
qd = torch.ops.quantized_decomposed # type: ignore[return]
7375
for node in gm.graph.nodes:
7476
if node.op != "call_function":
@@ -83,17 +85,19 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
8385
**{"dtype": get_quant_type(quant_min, quant_max)},
8486
}
8587
with gm.graph.inserting_before(node):
86-
quant = gm.graph.call_function(
88+
quant = create_node(
89+
g,
8790
qd.quantize_per_tensor.default,
8891
args=node.args,
8992
kwargs=quant_kwargs,
93+
origin=node,
9094
)
91-
dequnt = gm.graph.call_function(
95+
dequnt = create_node(
96+
g,
9297
qd.dequantize_per_tensor.default,
9398
args=(quant, *quant.args[1:]),
9499
kwargs=quant.kwargs,
95100
)
96-
# Not set meta for propagating replacing node's meta.
97101
node.replace_all_uses_with(dequnt, propagate_meta=True)
98102
modified = True
99103

@@ -107,17 +111,19 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
107111
**{"dtype": get_quant_type(quant_min, quant_max)},
108112
}
109113
with gm.graph.inserting_before(node):
110-
quant = gm.graph.call_function(
114+
quant = create_node(
115+
g,
111116
qd.quantize_per_channel.default,
112117
args=node.args,
113118
kwargs=quant_kwargs,
119+
origin=node,
114120
)
115-
dequnt = gm.graph.call_function(
121+
dequnt = create_node(
122+
g,
116123
qd.dequantize_per_channel.default,
117124
args=(quant, *quant.args[1:]),
118125
kwargs=quant.kwargs,
119126
)
120-
# Not set meta for propagating replacing node's meta.
121127
node.replace_all_uses_with(dequnt, propagate_meta=True)
122128
modified = True
123129

0 commit comments

Comments
 (0)