2121
2222from tico .serialize .circle_mapping import extract_shape
2323from tico .utils import logging
24- from tico .utils .graph import add_placeholder
24+ from tico .utils .graph import add_placeholder , create_node
2525from tico .utils .passes import PassBase , PassResult
2626from tico .utils .trace_decorators import trace_graph_diff_on_pass
2727from tico .utils .utils import is_target_node , set_new_meta_val
@@ -78,7 +78,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
7878
7979 with graph .inserting_before (node ):
8080 # out = beta * input + alpha * (mat1 @ mat2)
81- matmul = graph .call_function (torch .ops .aten .mm .default , (mat1 , mat2 ))
81+ matmul = create_node (
82+ graph , torch .ops .aten .mm .default , (mat1 , mat2 ), origin = node
83+ )
8284 set_new_meta_val (matmul )
8385
8486 if beta == 1 :
@@ -90,7 +92,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
9092 f"{ node .name } _beta_zeros" ,
9193 )
9294 else :
93- bias = graph .call_function (torch .ops .aten .mul .Tensor , (input , beta ))
95+ bias = create_node (
96+ graph , torch .ops .aten .mul .Tensor , (input , beta ), origin = node
97+ )
9498
9599 if alpha == 1 :
96100 scaled_matmul : torch .fx .Node | torch .Tensor = matmul
@@ -101,12 +105,12 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
101105 f"{ node .name } _alpha_zeros" ,
102106 )
103107 else :
104- scaled_matmul = graph . call_function (
105- torch .ops .aten .mul .Tensor , (matmul , alpha )
108+ scaled_matmul = create_node (
109+ graph , torch .ops .aten .mul .Tensor , (matmul , alpha ), origin = node
106110 )
107111
108- result = graph . call_function (
109- torch .ops .aten .add .Tensor , (bias , scaled_matmul )
112+ result = create_node (
113+ graph , torch .ops .aten .add .Tensor , (bias , scaled_matmul )
110114 )
111115
112116 node .replace_all_uses_with (result , propagate_meta = True )
0 commit comments