Skip to content

Commit 9f01b51

Browse files
Removed the import dependency by moving the shared pass to a shared utils file
1 parent 084aeb6 commit 9f01b51

File tree

2 files changed

+115
-1
lines changed

2 files changed

+115
-1
lines changed
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright (c) Meta Platforms...
2+
# BSD 3-Clause
3+
4+
# This module hosts shared pre-grad passes used by multiple backends.
5+
# NOTE: copied from x86 to avoid backend-to-backend dependencies.
6+
7+
8+
import torch
9+
from torch.fx.node import map_arg
10+
11+
aten = torch.ops.aten
12+
prims = torch.ops.prims
13+
quantized_decomposed = torch.ops.quantized_decomposed
14+
quantized = torch.ops.quantized
15+
16+
_PER_TENSOR_QUANTIZE_OPS = [
17+
quantized_decomposed.quantize_per_tensor.default,
18+
quantized_decomposed.quantize_per_tensor.tensor,
19+
]
20+
21+
_VIEW_OPS = [
22+
aten.transpose.int,
23+
aten.permute.default,
24+
aten.view.default,
25+
]
26+
27+
28+
# ----taken from torchao/quantization/pt2e/inductor_passes/x86.py------
29+
def quant_lift_up(module_graph: torch.fx.graph.Graph):
30+
"""
31+
Lift up the quant node before view like nodes. It can benefit performance
32+
of Attention like block. For example, we have the pattern as:
33+
34+
DQ
35+
DQ LINEAR
36+
LINEAR VIEW
37+
VIEW PERMUTE
38+
PERMUTE TRANSPOSE
39+
Q Q
40+
DQ DQ
41+
Matmul
42+
DIV
43+
ADD
44+
SOFTMAX
45+
46+
We want to lift up the the quant nodes from matmul before view like nodes
47+
as the output of Linear node.
48+
49+
DQ
50+
DQ LINEAR
51+
LINEAR Q
52+
Q VIEW
53+
VIEW PERMUTE
54+
PERMUTE TRANSPOSE
55+
DQ DQ
56+
Matmul
57+
DIV
58+
ADD
59+
SOFTMAX
60+
61+
It produces a DQ->LINEAR->Q pattern which can be fused by backend.
62+
"""
63+
64+
def is_view_op(node):
65+
return node.op == "call_function" and node.target in _VIEW_OPS
66+
67+
for node in module_graph.nodes:
68+
# <TODO> Leslie: Here we verify that the quant node has exactly
69+
# one input FX node, with constant scalar value for scale and zero point.
70+
# For the case input of quant node has more than one input FX nodes,
71+
# extend the implementation to lift up all the connected nodes
72+
# before the view nodes to keep the topological order.
73+
if (
74+
node.op == "call_function"
75+
and node.target in _PER_TENSOR_QUANTIZE_OPS
76+
and len(node.all_input_nodes) == 1
77+
and is_view_op(node.all_input_nodes[0])
78+
):
79+
quant_node = node
80+
input_node_of_quant = quant_node.args[0]
81+
82+
# Check the nodes along lift up path has only 1 user node
83+
# Propagate view like node to find where to insert the new quant node
84+
could_lift_up = True
85+
current_node = quant_node
86+
input_node = current_node.args[0]
87+
while is_view_op(input_node):
88+
if len(input_node.users) != 1:
89+
could_lift_up = False
90+
break
91+
current_node = input_node
92+
input_node = current_node.args[0]
93+
94+
# Further check the input node of the first view node has only 1 user node
95+
if could_lift_up and len(input_node.users) == 1:
96+
# Replace dequant's input from quant to quant's input
97+
quant_node.replace_all_uses_with(input_node_of_quant)
98+
# Insert the new quant node
99+
with module_graph.inserting_before(current_node):
100+
new_quant_node = module_graph.node_copy(quant_node)
101+
input_node.replace_all_uses_with(new_quant_node)
102+
103+
# Update inputs of new_quant_node
104+
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
105+
if n == input_node_of_quant:
106+
return input_node
107+
else:
108+
return n
109+
110+
new_args = map_arg(new_quant_node.args, maybe_replace_node)
111+
new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
112+
new_quant_node.args = new_args # type: ignore[assignment]
113+
new_quant_node.kwargs = new_kwargs # type: ignore[assignment]
114+
module_graph.erase_node(quant_node)

torchao/quantization/pt2e/quantizer/arm_inductor_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from torchao.quantization.pt2e.inductor_passes.arm import (
2727
_register_quantization_weight_pack_pass,
2828
)
29-
from torchao.quantization.pt2e.inductor_passes.x86 import (
29+
from torchao.quantization.pt2e.inductor_passes.utils import (
3030
quant_lift_up,
3131
)
3232
from torchao.quantization.pt2e.observer import (

0 commit comments

Comments
 (0)