Merge Composite Relay Operators

The MergeComposite pass wraps sub-graphs that match a registered pattern into composite functions, so that downstream consumers (for example external codegens) can treat each matched sub-graph as a single fused operator.

import tvm
from tvm import relay, tir
from tvm.relay.dataflow_pattern import TuplePattern, TupleGetItemPattern, is_op, wildcard
from tvm.relay.testing import run_opt_pass

def check_result(pattern_table, graph, expected_graph, import_prelude=False):
    """检查合并复合结果的实用函数。"""
    result = run_opt_pass(
        graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude
    )
    assert not relay.analysis.free_vars(result), "Found free vars in the result graph: {0}".format(
        str(result)
    )
    expected = run_opt_pass(expected_graph, relay.transform.InferType())
    assert tvm.ir.structural_equal(
        result, expected, map_free_vars=True
    ), "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected))
from pattern import *
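
The helper make_add_sub_mul_pattern comes from the local pattern module, which is not reproduced here. A minimal sketch of an equivalent definition, consistent with the Composite="add_sub_mul" output shown below, would be:

def make_add_sub_mul_pattern():
    """Match multiply(add(x, y), subtract(x, y)) with shared inputs x and y."""
    x = wildcard()
    y = wildcard()
    return is_op("multiply")(is_op("add")(x, y), is_op("subtract")(x, y))

This is exactly the sub-graph that appears twice in before() below.
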
def before():
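    r"""Branching graph containing two add-sub-mul sub-graphs,
    each of which should be merged into a composite call:

       a  b    a  b
        \ /     \ /
        add     sub
          \     /
           \   /
            mul
        c  /   \  c
         \/     \/
         add    sub
           \    /
            \  /
             mul
              |
             relu
    """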
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    c = relay.var("c", shape=(10, 10))
    add_node = relay.add(a, b)
    sub_node = relay.subtract(a, b)
    mul_node = relay.multiply(add_node, sub_node)
    add_node_2 = relay.add(c, mul_node)
    sub_node_2 = relay.subtract(c, mul_node)
    mul_node_2 = relay.multiply(add_node_2, sub_node_2)
    r = relay.nn.relu(mul_node_2)
    return relay.Function([a, b, c], r)
pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]
mod = tvm.IRModule.from_expr(before())
run_mod = relay.transform.MergeComposite(pattern_table)(mod)
mod.show()
run_mod.show()
run_mod.show() prints the transformed module. Both add-sub-mul sub-graphs have been partitioned into composite functions tagged Composite="add_sub_mul":
def @main(%a: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %b: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %c: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %4 = fn (%FunctionVar_1_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_1_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="add_sub_mul") -> Tensor[(10, 10), float32] {
    %2 = add(%FunctionVar_1_0, %FunctionVar_1_1) /* ty=Tensor[(10, 10), float32] */;
    %3 = subtract(%FunctionVar_1_0, %FunctionVar_1_1) /* ty=Tensor[(10, 10), float32] */;
    multiply(%2, %3) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %5 = %4(%a, %b) /* ty=Tensor[(10, 10), float32] */;
  %6 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="add_sub_mul") -> Tensor[(10, 10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    %1 = subtract(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    multiply(%0, %1) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %7 = %6(%c, %5) /* ty=Tensor[(10, 10), float32] */;
  nn.relu(%7) /* ty=Tensor[(10, 10), float32] */
}
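
To verify the transformation programmatically, the check_result utility defined above can be paired with a hand-built expected graph. A sketch, with illustrative variable names (in_1, add_sub_mul, ...); structural equality is checked up to free-variable renaming:

def expected():
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    c = relay.var("c", shape=(10, 10))

    # First composite function: multiply(add(x, y), subtract(x, y)).
    in_1 = relay.var("in_1", shape=(10, 10))
    in_2 = relay.var("in_2", shape=(10, 10))
    add_sub_mul = relay.Function(
        [in_1, in_2],
        relay.multiply(relay.add(in_1, in_2), relay.subtract(in_1, in_2)),
    )
    add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul")
    add_sub_mul = add_sub_mul.with_attr("PartitionedFromPattern", "add_subtract_multiply_")

    # Second composite function: same body, a fresh Function node.
    in_3 = relay.var("in_3", shape=(10, 10))
    in_4 = relay.var("in_4", shape=(10, 10))
    add_sub_mul_1 = relay.Function(
        [in_3, in_4],
        relay.multiply(relay.add(in_3, in_4), relay.subtract(in_3, in_4)),
    )
    add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul")
    add_sub_mul_1 = add_sub_mul_1.with_attr("PartitionedFromPattern", "add_subtract_multiply_")

    # Call the composites in the same order as the original graph.
    call_1 = relay.Call(add_sub_mul, [a, b])
    call_2 = relay.Call(add_sub_mul_1, [c, call_1])
    return relay.Function([a, b, c], relay.nn.relu(call_2))

check_result(pattern_table, before(), expected())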