Quantization with QPartitionExpr

The expression \(f(x, y) = (x + y)(x - y)\) is used as the running example below.

import set_env  # local helper that sets up the TVM environment for these docs
import tvm
from tvm import relay

x = relay.var("x", dtype="float32", shape=(10,))
y = relay.var("y", dtype="float32", shape=(10,))
z1 = x + y
z2 = x - y
z3 = z1 * z2
z4 = relay.exp(z3)
mod = tvm.IRModule.from_expr(z4)
mod.show()
def @main(%x: Tensor[(10), float32], %y: Tensor[(10), float32]) {
  %0 = add(%x, %y);
  %1 = subtract(%x, %y);
  %2 = multiply(%0, %1);
  exp(%2)
}

Custom partitioning

from tvm.relay.quantize._partition import QPartitionExpr
from tvm.relay.dataflow_pattern import is_op, is_var
from tvm.relay.function import FunctionWithFields
@tvm.relay.transform.function_pass(opt_level=1)
class MergeGraphTransform:
    """Wrap every function body in QPartitionExpr and realize it,
    which appends annotation.cast_hint/annotation.stop_fusion."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.nodes = []

    def transform_function(self, func, mod, ctx):
        obj = self

        class Replace(tvm.relay.ExprMutator):
            def visit_function(self, fn):
                new_params = [self.visit(x) for x in fn.params]
                new_body = self.visit(fn.body)
                # Realizing the QPartitionExpr inserts the quantization
                # partition annotations around the function body.
                new_body = QPartitionExpr(new_body).realize()
                if new_params == list(fn.params) and new_body == fn.body:
                    new_fn = fn
                else:
                    new_fn = FunctionWithFields(fn, list(new_params), new_body)
                obj.nodes.append(new_fn)
                return new_fn

        return Replace().visit(func)
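
To see what realize() produces in isolation, here is a minimal sketch (assuming x from the first cell is still in scope); per the outputs later in this section, realize() appends annotation.cast_hint followed by annotation.stop_fusion:

# Sketch: realize a QPartitionExpr around a single expression.
expr = QPartitionExpr(relay.exp(x)).realize()
print(tvm.IRModule.from_expr(expr))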

def make_add_subtract_multiply_pattern():
    """Match the pattern (x + y) * (x - y)."""
    x = is_var()
    y = is_var()
    node1 = is_op("add")(x, y)
    node2 = is_op("subtract")(x, y)
    node = is_op("multiply")(node1, node2)
    return node
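
As a quick sanity check (a sketch, assuming z3 and z4 from the first cell are still live), the pattern should match the multiply node but not the trailing exp:

pattern = make_add_subtract_multiply_pattern()
assert pattern.match(z3)      # multiply(add(x, y), subtract(x, y)) matches
assert not pattern.match(z4)  # the root exp(...) does not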
compiler_name = "ccompiler"
pattern_table = [
    (f"{compiler_name}.add_subtract_multiply", make_add_subtract_multiply_pattern()),
]
merge_passes = tvm.transform.Sequential([
    relay.transform.MergeComposite(pattern_table),
    # relay.transform.AnnotateTarget([compiler_name]),
    relay.transform.PartitionGraph(),
    # relay.transform.ToANormalForm()
])
run_mod = merge_passes(mod)
run_mod.show()
def @main(%x: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %y: Tensor[(10), float32] /* ty=Tensor[(10), float32] */) -> Tensor[(10), float32] {
  %2 = fn (%FunctionVar_0_0: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %FunctionVar_0_1: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="ccompiler.add_subtract_multiply") -> Tensor[(10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10), float32] */;
    %1 = subtract(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10), float32] */;
    multiply(%0, %1) /* ty=Tensor[(10), float32] */
  } /* ty=fn (Tensor[(10), float32], Tensor[(10), float32]) -> Tensor[(10), float32] */;
  %3 = %2(%x, %y) /* ty=Tensor[(10), float32] */;
  exp(%3) /* ty=Tensor[(10), float32] */
}
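
The Composite attribute on the fused function records which pattern produced it. A small sketch to confirm, assuming the structure printed above (the body of @main is exp applied to a call whose op is the composite function):

composite_fn = run_mod["main"].body.args[0].op
print(composite_fn.attrs["Composite"])  # ccompiler.add_subtract_multiply
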
transform = MergeGraphTransform()
run_mod = transform(run_mod)
run_mod.show()
def @main(%x: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %y: Tensor[(10), float32] /* ty=Tensor[(10), float32] */) -> Tensor[(10), float32] {
  %4 = fn (%FunctionVar_0_0: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %FunctionVar_0_1: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="ccompiler.add_subtract_multiply") -> Tensor[(10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10), float32] */;
    %1 = subtract(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10), float32] */;
    %2 = multiply(%0, %1) /* ty=Tensor[(10), float32] */;
    %3 = annotation.cast_hint(%2, dtype="int8") /* ty=Tensor[(10), float32] */;
    annotation.stop_fusion(%3) /* ty=Tensor[(10), float32] */
  } /* ty=fn (Tensor[(10), float32], Tensor[(10), float32]) -> Tensor[(10), float32] */;
  %5 = %4(%x, %y) /* ty=Tensor[(10), float32] */;
  %6 = exp(%5) /* ty=Tensor[(10), float32] */;
  %7 = annotation.cast_hint(%6, dtype="int8") /* ty=Tensor[(10), float32] */;
  annotation.stop_fusion(%7) /* ty=Tensor[(10), float32] */
}

Note that realize() appended annotation.cast_hint(dtype="int8") and annotation.stop_fusion to the body of the composite function as well as to the body of @main. Mathematically, the expression above simplifies to \(f(x, y) = x^2 - y^2\), so the matched subgraph can be rewritten directly with a dataflow-pattern callback.

from tvm.relay.dataflow_pattern import DFPatternCallback

class MergeGraphCallback(DFPatternCallback):
    # Rewrite the matched (x + y) * (x - y) subgraph into x*x - y*y.
    def __init__(self, require_type=False):
        super().__init__(require_type)
        self.pattern = make_add_subtract_multiply_pattern()

    def callback(self, pre, post, node_map):
        # post is the matched multiply node; its first argument is add(x, y),
        # so args[0].args[0] is x and args[0].args[1] is y.
        x = post.args[0].args[0]
        y = post.args[0].args[1]
        return x * x - y * y
from tvm.relay.dataflow_pattern import rewrite

# DefuseOps inlines the composite function so the pattern can match again.
rewrite(MergeGraphCallback(), relay.transform.DefuseOps()(run_mod)["main"])
fn (%x: Tensor[(10), float32] /* ty=Tensor[(10), float32] */, %y: Tensor[(10), float32] /* ty=Tensor[(10), float32] */) -> Tensor[(10), float32] {
  %0 = multiply(%x, %x);
  %1 = multiply(%y, %y);
  %2 = subtract(%0, %1);
  %3 = annotation.cast_hint(%2, dtype="int8") /* ty=Tensor[(10), float32] */;
  %4 = annotation.stop_fusion(%3) /* ty=Tensor[(10), float32] */;
  %5 = exp(%4) /* ty=Tensor[(10), float32] */;
  %6 = annotation.cast_hint(%5, dtype="int8") /* ty=Tensor[(10), float32] */;
  annotation.stop_fusion(%6) /* ty=Tensor[(10), float32] */
} /* ty=fn (Tensor[(10), float32], Tensor[(10), float32]) -> Tensor[(10), float32] */
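
As a final sanity check (a sketch evaluating the original, unannotated module from the first cell), the simplified form agrees numerically with \(\exp((x + y)(x - y))\):

import numpy as np

x_np = np.random.rand(10).astype("float32")
y_np = np.random.rand(10).astype("float32")

# Evaluate exp((x + y)(x - y)) and compare with the hand-simplified form.
result = relay.create_executor(mod=mod).evaluate()(x_np, y_np).numpy()
expected = np.exp(x_np**2 - y_np**2)
np.testing.assert_allclose(result, expected, rtol=1e-5)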