|
| 1 | +# Copyright (c) Meta Platforms... |
| 2 | +# BSD 3-Clause |
| 3 | + |
| 4 | +# This module hosts shared pre-grad passes used by multiple backends. |
| 5 | +# NOTE: copied from x86 to avoid backend-to-backend dependencies. |
| 6 | + |
| 7 | + |
| 8 | +import torch |
| 9 | +from torch.fx.node import map_arg |
| 10 | + |
# Shorthand handles for the op namespaces used by the passes below.
aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed
quantized = torch.ops.quantized

# Per-tensor quantize ops that quant_lift_up() is allowed to hoist.
# Both the scalar-attribute (.default) and tensor-scale (.tensor)
# overloads are matched; per-channel quantize is deliberately excluded.
_PER_TENSOR_QUANTIZE_OPS = [
    quantized_decomposed.quantize_per_tensor.default,
    quantized_decomposed.quantize_per_tensor.tensor,
]

# View-like (layout-only) ops that a per-tensor quantize node may be
# moved across: they reorder/reshape elements without changing values,
# so quantizing before or after them is numerically equivalent.
_VIEW_OPS = [
    aten.transpose.int,
    aten.permute.default,
    aten.view.default,
]
| 26 | + |
| 27 | + |
# ----taken from torchao/quantization/pt2e/inductor_passes/x86.py------
def quant_lift_up(module_graph: torch.fx.graph.Graph):
    """
    Lift per-tensor quantize nodes above chains of view-like nodes.

    Mutates ``module_graph`` in place.  It can benefit performance of
    Attention-like blocks.  For example, we have the pattern as:

             DQ
        DQ   LINEAR
        LINEAR   VIEW
        VIEW     PERMUTE
        PERMUTE  TRANSPOSE
        Q        Q
        DQ       DQ
             Matmul
              DIV
              ADD
            SOFTMAX

    We want to lift up the quant nodes from matmul before view like nodes
    as the output of Linear node:

             DQ
        DQ   LINEAR
        LINEAR   Q
        Q        VIEW
        VIEW     PERMUTE
        PERMUTE  TRANSPOSE
        DQ       DQ
             Matmul
              DIV
              ADD
            SOFTMAX

    It produces a DQ->LINEAR->Q pattern which can be fused by backend.

    This is numerics-preserving only because the skipped nodes are pure
    layout ops (see ``_VIEW_OPS``) and the quantize is per-tensor, so the
    same scale/zero-point applies regardless of element order.
    """

    def is_view_op(node):
        # True for layout-only ops the quantize node may be hoisted across.
        return node.op == "call_function" and node.target in _VIEW_OPS

    for node in module_graph.nodes:
        # <TODO> Leslie: Here we verify that the quant node has exactly
        # one input FX node, with constant scalar value for scale and zero point.
        # For the case input of quant node has more than one input FX nodes,
        # extend the implementation to lift up all the connected nodes
        # before the view nodes to keep the topological order.
        if (
            node.op == "call_function"
            and node.target in _PER_TENSOR_QUANTIZE_OPS
            # Exactly one input node => scale/zero-point are constants,
            # not graph nodes, so only the data input needs rewiring.
            and len(node.all_input_nodes) == 1
            and is_view_op(node.all_input_nodes[0])
        ):
            quant_node = node
            input_node_of_quant = quant_node.args[0]

            # Check the nodes along lift up path has only 1 user node
            # Propagate view like node to find where to insert the new quant node
            could_lift_up = True
            current_node = quant_node
            input_node = current_node.args[0]
            while is_view_op(input_node):
                if len(input_node.users) != 1:
                    # A view node with multiple consumers cannot be
                    # re-quantized without affecting the other users.
                    could_lift_up = False
                    break
                current_node = input_node
                input_node = current_node.args[0]
            # After the walk, current_node is the first view node in the
            # chain and input_node is its (non-view) producer, e.g. Linear.

            # Further check the input node of the first view node has only 1 user node
            if could_lift_up and len(input_node.users) == 1:
                # Replace dequant's input from quant to quant's input
                # (i.e. downstream consumers now read the last view node
                # directly; the old quant node becomes dead).
                quant_node.replace_all_uses_with(input_node_of_quant)
                # Insert the new quant node right before the first view
                # node so the quantize happens on the producer's output.
                with module_graph.inserting_before(current_node):
                    new_quant_node = module_graph.node_copy(quant_node)
                    # Route the producer's users (the first view node)
                    # through the new quant node.
                    input_node.replace_all_uses_with(new_quant_node)

                    # Update inputs of new_quant_node: the copy still
                    # points at the old view output; retarget it to the
                    # producer (undoing the replace_all_uses_with above
                    # for the new quant node itself).
                    def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
                        if n == input_node_of_quant:
                            return input_node
                        else:
                            return n

                    new_args = map_arg(new_quant_node.args, maybe_replace_node)
                    new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
                    new_quant_node.args = new_args  # type: ignore[assignment]
                    new_quant_node.kwargs = new_kwargs  # type: ignore[assignment]
                    # Original quant node is now unused; remove it.
                    # NOTE(review): erasing during iteration of
                    # module_graph.nodes relies on FX's linked-list
                    # iterator tolerating removal of the current node.
                    module_graph.erase_node(quant_node)
0 commit comments