Compiling External Libraries#

import set_env
/media/pc/data/lxw/ai/tvm

Load the libraries:

import numpy as np
import tvm
from tvm import relay
from tvm.relay import ExprMutator, transform
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.backend.runtime import Runtime
from tvm.relay.backend import te_compiler
from tvm.contrib.utils import tempdir

def update_lib(lib, source_dir="/media/pc/data/lxw/ai/tvm"):
    """Export ``lib`` to a shared library, compiling the C source emitted by the
    external codegen against the headers in the TVM source tree, then load it back
    as a runnable runtime module."""
    kwargs = {
        "options": [
            "-O2", "-std=c++17",
            f"-I{source_dir}/src/runtime/contrib",
            f"-I{source_dir}/include",
            f"-I{source_dir}/3rdparty/dlpack/include",
            f"-I{source_dir}/3rdparty/dmlc-core/include",
        ]
    }
    tmp_path = tempdir()
    lib_name = "lib.so"
    lib_path = tmp_path.relpath(lib_name)
    lib.export_library(lib_path, fcompile=False, **kwargs)
    lib = tvm.runtime.load_module(lib_path)
    return lib

def check_result(
    mod,
    map_inputs,
    out_shape,
    result,
    tol=1e-5,
    target="llvm",
    device=tvm.cpu(),
    params=None,
    runtime=Runtime("cpp"),
):
    """Compile ``mod`` with the Relay VM, relink the external code via ``update_lib``,
    run it on ``map_inputs``, and compare the outputs against ``result``."""
    def check_vm_result():
        te_compiler.get().clear()
        with tvm.transform.PassContext(opt_level=3):
            exe = relay.vm.compile(mod, target=target, params=params)
        code, lib = exe.save()
        lib = update_lib(lib)
        exe = tvm.runtime.vm.Executable.load_exec(code, lib)
        vm = tvm.runtime.vm.VirtualMachine(exe, device)
        outs = vm.run(**map_inputs)
        outs = outs if isinstance(outs, tvm.runtime.container.ADT) else [outs]
        results = result if isinstance(result, list) else [result]
        for out, ref in zip(outs, results):
            np.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol)

    check_vm_result()
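
check_result verifies through the Relay VM. A graph-executor based variant can be written along the same lines; the sketch below is an assumption-laden alternative, reusing the update_lib helper above and assuming relay.build returns a graph-executor factory module (as in recent TVM releases):

from tvm.contrib import graph_executor

def check_graph_executor_result(
    mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()
):
    """Sketch: build for the graph executor, relink the generated external code via
    update_lib, then compare against the reference result."""
    te_compiler.get().clear()
    with tvm.transform.PassContext(opt_level=3):
        factory = relay.build(mod, target=target)
    lib = update_lib(factory.get_lib())
    rt_mod = graph_executor.create(factory.get_graph_json(), lib, device)
    for name, data in map_inputs.items():
        rt_mod.set_input(name, data)
    rt_mod.run()
    out = rt_mod.get_output(0, tvm.nd.empty(out_shape, device=device))
    np.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol)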

Take z = x + y as an example:

x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
z = x + y
f = relay.Function([x, y], z)
mod = tvm.IRModule()
mod["main"] = f
mod.show()
def @main(%x: Tensor[(8, 8), float32], %y: Tensor[(8, 8), float32]) {
  add(%x, %y)
}

Write a simple annotation pass:

@relay.transform.function_pass(opt_level=0)
class MyAnnotator:
    def transform_function(self, func, mod, dev):
        class Annotator(ExprMutator):
            def visit_call(self, call):
                new_args = []
                for arg in call.args:
                    ann = compiler_begin(self.visit(arg), "ccompiler")
                    new_args.append(ann)
                new_call = relay.Call(call.op, new_args)
                return compiler_end(new_call, "ccompiler")

        return Annotator().visit(func)

Annotate the inputs and outputs of the + call:

mod = MyAnnotator()(mod)
mod.show()
def @main(%x: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */) -> Tensor[(8, 8), float32] {
  %0 = annotation.compiler_begin(%x, compiler="ccompiler") /* ty=Tensor[(8, 8), float32] */;
  %1 = annotation.compiler_begin(%y, compiler="ccompiler") /* ty=Tensor[(8, 8), float32] */;
  %2 = add(%0, %1) /* ty=Tensor[(8, 8), float32] */;
  annotation.compiler_end(%2, compiler="ccompiler") /* ty=Tensor[(8, 8), float32] */
}

Use PartitionGraph to partition the computation graph:

mod = relay.transform.PartitionGraph()(mod)
mod.show()
def @main(%x: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */) -> Tensor[(8, 8), float32] {
  @tvmgen_default_ccompiler_main_0(%x, %y) /* ty=Tensor[(8, 8), float32] */
}

def @tvmgen_default_ccompiler_main_0(%ccompiler_0_i0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %ccompiler_0_i1: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Compiler="ccompiler", Primitive=1, Inline=1, global_symbol="tvmgen_default_ccompiler_main_0") -> Tensor[(8, 8), float32] {
  add(%ccompiler_0_i0, %ccompiler_0_i1) /* ty=Tensor[(8, 8), float32] */
}
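
The external function is only lowered when the module is built: the registered "ccompiler" codegen turns it into a C source module that gets attached to the host module and compiled at export time. As a rough way to peek at what the codegen produced (a sketch; the exact module layout depends on the TVM build), the imported modules can be inspected after relay.build:

with tvm.transform.PassContext(opt_level=3):
    factory = relay.build(mod, target="llvm")
# Look for the CSourceModule produced by the "ccompiler" codegen.
for m in factory.get_lib().imported_modules:
    print(m.type_key)
    if m.type_key == "c":
        # Beginning of the generated C code for tvmgen_default_ccompiler_main_0.
        print(m.get_source()[:500])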

Verify that the results are consistent:

x_data = np.random.rand(8, 8).astype("float32")
y_data = np.random.rand(8, 8).astype("float32")
check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)
[15:56:02] /media/pc/data/lxw/ai/tvm/src/relay/backend/vm/compiler.cc:1199: All lowered functions have been build by BYOC -- generating an empty TVM module

Annotation Allow List#

# Write a simple allow-list annotator using the pass manager
@relay.transform.function_pass(opt_level=0)
class AllowedListAnnotator:
    def __init__(self, op_list, compiler):
        assert isinstance(op_list, (list, tuple, set))
        self.op_list = op_list
        self.compiler = compiler

    def transform_function(self, func, mod, dev):

        annotator = self

        class Annotator(tvm.relay.ExprMutator):
            def visit_call(self, call):
                op_name = call.op.name
                if op_name in annotator.op_list:
                    new_args = []
                    for arg in call.args:
                        ann = compiler_begin(super().visit(arg), annotator.compiler)
                        new_args.append(ann)
                    new_call = relay.Call(call.op, new_args, call.attrs, call.type_args)
                    return compiler_end(new_call, annotator.compiler)
                else:
                    return super().visit_call(call)

        return Annotator().visit(func)

Build a test graph that mixes supported (add) and unsupported (log, exp, concatenate) operators:

x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
add = x + y
log = relay.log(add)
exp = relay.exp(add)
concat = relay.concatenate([log, exp], axis=0)
f = relay.Function([x, y], concat)
mod = tvm.IRModule()
mod["main"] = f
mod.show()
def @main(%x: Tensor[(8, 8), float32], %y: Tensor[(8, 8), float32]) {
  %0 = add(%x, %y);
  %1 = log(%0);
  %2 = exp(%0);
  %3 = (%1, %2);
  concatenate(%3)
}

Define set_func_attr to mark a function for an external compiler, and construct the expected module after partitioning:

def set_func_attr(func, compile_name, symbol_name):
    func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
    func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
    func = func.with_attr("Compiler", compile_name)
    func = func.with_attr("global_symbol", symbol_name)
    return func
def expected():
    mod = tvm.IRModule()
    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    x0 = relay.var("x0", shape=(8, 8))
    y0 = relay.var("y0", shape=(8, 8))
    add = x0 + y0
    # Function that uses C compiler
    func = relay.Function([x0, y0], add)
    func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0")
    glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0")
    mod[glb_0] = func
    add_call = relay.Call(glb_0, [x, y])
    # Function that uses default compiler. Ops are fused in this function.
    p0 = relay.var("p0", shape=(8, 8))
    log = relay.log(p0)
    exp = relay.exp(p0)
    concat = relay.concatenate([log, exp], axis=0)
    fused_func = relay.Function([p0], concat)
    fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
    fused_call = relay.Call(fused_func, [add_call])
    main = relay.Function([x, y], fused_call)
    mod["main"] = main
    mod = relay.transform.InferType()(mod)
    return mod

x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
add = x + y
log = relay.log(add)
exp = relay.exp(add)
concat = relay.concatenate([log, exp], axis=0)
f = relay.Function([x, y], concat)
mod = tvm.IRModule()
mod["main"] = f
mod = AllowedListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
mod = relay.transform.PartitionGraph()(mod)
fused_mod = relay.transform.FuseOps(2)(mod)
expected_mod = expected()
assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)

x_data = np.random.rand(8, 8).astype("float32")
y_data = np.random.rand(8, 8).astype("float32")
np_add = x_data + y_data
res = np.concatenate([np.log(np_add), np.exp(np_add)])
check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res)
expected_mod.show()
def @main(%x: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */) -> Tensor[(16, 8), float32] {
  %3 = @tvmgen_default_ccompiler_main_0(%x, %y) /* ty=Tensor[(8, 8), float32] */;
  %4 = fn (%p0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Primitive=1) -> Tensor[(16, 8), float32] {
    %0 = log(%p0) /* ty=Tensor[(8, 8), float32] */;
    %1 = exp(%p0) /* ty=Tensor[(8, 8), float32] */;
    %2 = (%0, %1) /* ty=(Tensor[(8, 8), float32], Tensor[(8, 8), float32]) */;
    concatenate(%2) /* ty=Tensor[(16, 8), float32] */
  } /* ty=fn (Tensor[(8, 8), float32]) -> Tensor[(16, 8), float32] */;
  %4(%3) /* ty=Tensor[(16, 8), float32] */
}

def @tvmgen_default_ccompiler_main_0(%x0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, %y0: Tensor[(8, 8), float32] /* ty=Tensor[(8, 8), float32] */, Primitive=1, Inline=1, Compiler="ccompiler", global_symbol="tvmgen_default_ccompiler_main_0") -> Tensor[(8, 8), float32] {
  add(%x0, %y0) /* ty=Tensor[(8, 8), float32] */
}
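
The annotators above are written by hand to show the mechanism. TVM also ships relay.transform.AnnotateTarget, which annotates every operator carrying a registered target.<compiler> attribute, so the allow list lives next to the codegen instead of in an ad-hoc pass. A minimal sketch follows; the register_op_attr calls are illustrative (a real integration registers its own predicates, and re-registering an attribute that already exists raises an error):

# Declare which ops "ccompiler" claims, then let the stock passes annotate,
# merge adjacent regions, and partition the graph.
def _register_ccompiler_op(op_name):
    @tvm.ir.register_op_attr(op_name, "target.ccompiler")
    def _supported(expr):  # claim the op unconditionally
        return True

for op in ["add", "subtract", "multiply"]:
    _register_ccompiler_op(op)

x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
mod = tvm.IRModule.from_expr(relay.Function([x, y], x + y))
seq = tvm.transform.Sequential(
    [
        relay.transform.AnnotateTarget("ccompiler"),
        relay.transform.MergeCompilerRegions(),
        relay.transform.PartitionGraph(),
    ]
)
mod = seq(mod)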

Support for Other External Compilers#

def test_extern_compiler_sanitized_ops():
    def expected():
        mod = tvm.IRModule()
        x = relay.var("x", shape=(8, 8))
        y = relay.var("y", shape=(8, 8))
        x0 = relay.var("x0", shape=(8, 8))
        y0 = relay.var("y0", shape=(8, 8))
        add = x0 + y0
        # Function that uses C compiler
        func = relay.Function([x0, y0], add)
        func = set_func_attr(func, "unsanitary-name++", "tvmgen_default_unsanitary_name___main_0")
        glb_0 = relay.GlobalVar("tvmgen_default_unsanitary_name___main_0")
        mod[glb_0] = func
        add_call = relay.Call(glb_0, [x, y])
        # Function that uses default compiler. Ops are fused in this function.
        p0 = relay.var("p0", shape=(8, 8))
        log = relay.log(p0)
        exp = relay.exp(p0)
        concat = relay.concatenate([log, exp], axis=0)
        fused_func = relay.Function([p0], concat)
        fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        fused_call = relay.Call(fused_func, [add_call])
        main = relay.Function([x, y], fused_call)
        mod["main"] = main
        mod = transform.InferType()(mod)
        return mod

    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    add = x + y
    log = relay.log(add)
    exp = relay.exp(add)
    concat = relay.concatenate([log, exp], axis=0)
    f = relay.Function([x, y], concat)
    mod = tvm.IRModule()
    mod["main"] = f
    mod = AllowedListAnnotator(["add", "subtract", "multiply"], "unsanitary-name++")(mod)
    mod = transform.PartitionGraph()(mod)
    fused_mod = transform.FuseOps(2)(mod)
    expected_mod = expected()
    assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)


def test_extern_ccompiler_multiple_functions():
    def expected():
        mod = tvm.IRModule()
        x = relay.var("x", shape=(8, 8))
        y = relay.var("y", shape=(8, 8))
        x0 = relay.var("x0", shape=(8, 8))
        y0 = relay.var("y0", shape=(8, 8))
        add = x0 + y0
        # Function that uses C compiler
        func = relay.Function([x0, y0], add)
        func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0")
        glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0")
        mod[glb_0] = func
        add_call = relay.Call(glb_0, [x, y])
        # Function that uses default compiler. Ops are fused in this function.
        p0 = relay.var("p0", shape=(8, 8))
        log = relay.log(p0)
        exp = relay.exp(p0)
        concat = relay.concatenate([log, exp], axis=0)
        fused_func = relay.Function([p0], concat)
        fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        fused_call = relay.Call(fused_func, [add_call])
        main = relay.Function([x, y], fused_call)
        mod["main"] = main
        # define the second one
        a = relay.var("a", shape=(16, 16))
        b = relay.var("b", shape=(16, 16))
        a0 = relay.var("a0", shape=(16, 16))
        b0 = relay.var("b0", shape=(16, 16))
        add = a0 + b0
        # Function that uses C compiler
        func = relay.Function([a0, b0], add)
        func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_subfunction_0")
        glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_subfunction_0")
        mod[glb_0] = func
        add_call = relay.Call(glb_0, [a, b])
        # Function that uses default compiler. Ops are fused in this function.
        p0 = relay.var("p0", shape=(16, 16))
        log = relay.log(p0)
        exp = relay.exp(p0)
        concat = relay.concatenate([log, exp], axis=0)
        fused_func = relay.Function([p0], concat)
        fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        fused_call = relay.Call(fused_func, [add_call])
        subfunction = relay.Function([a, b], fused_call)
        mod["subfunction"] = subfunction
        mod = transform.InferType()(mod)
        return mod

    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    add = x + y
    log = relay.log(add)
    exp = relay.exp(add)
    concat = relay.concatenate([log, exp], axis=0)
    f = relay.Function([x, y], concat)
    mod = tvm.IRModule()
    mod["main"] = f
    # define second function
    a = relay.var("a", shape=(16, 16))
    b = relay.var("b", shape=(16, 16))
    add = a + b
    log = relay.log(add)
    exp = relay.exp(add)
    concat = relay.concatenate([log, exp], axis=0)
    f2 = relay.Function([a, b], concat)
    mod["subfunction"] = f2
    mod = AllowedListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
    mod = transform.PartitionGraph()(mod)

    fused_mod = transform.FuseOps(2)(mod)
    expected_mod = expected()
    assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)

    x_data = np.random.rand(8, 8).astype("float32")
    y_data = np.random.rand(8, 8).astype("float32")
    np_add = x_data + y_data
    res = np.concatenate([np.log(np_add), np.exp(np_add)])
    check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res)


def test_extern_ccompiler():
    x = relay.var("x", shape=(2, 2))
    y = relay.var("y", shape=(2, 2))
    z = x + x
    p = y * y
    f = relay.Function([x, y], p - z)
    x_data = np.random.rand(2, 2).astype("float32")
    y_data = np.random.rand(2, 2).astype("float32")
    mod = tvm.IRModule()
    mod["main"] = f
    mod = AllowedListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
    mod = transform.PartitionGraph()(mod)

    check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))
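
These helpers mirror TVM's own graph-partitioning tests and can be invoked directly; the last two call check_result, so they also compile and run the generated C code:

test_extern_compiler_sanitized_ops()
test_extern_ccompiler_multiple_functions()
test_extern_ccompiler()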

Next, annotate an entire graph for a single external compiler (DNNL below), rather than annotating operator by operator:

class WholeGraphAnnotator(ExprMutator):
    """
    An annotator that creates a compiler for an entire graph.
    """

    def __init__(self, compiler):
        super().__init__()
        self.compiler = compiler
        self.last_call = True

    def visit_call(self, call):
        curr_last = self.last_call
        self.last_call = False

        params = []
        for arg in call.args:
            param = super().visit(arg)
            if isinstance(param, relay.expr.Var):
                param = compiler_begin(param, self.compiler)
            params.append(param)

        new_call = relay.Call(call.op, params, call.attrs)
        if curr_last:
            new_call = compiler_end(new_call, self.compiler)
        return new_call

Build a simple workload: two depthwise conv2d layers sharing one weight, followed by an add:

dtype = "float32"
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
def get_func():
    data = relay.var("data", shape=(ishape), dtype=dtype)
    weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
    depthwise_conv2d_1 = relay.nn.conv2d(
        data, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
    )
    depthwise_conv2d_2 = relay.nn.conv2d(
        depthwise_conv2d_1, weight1, kernel_size=(3, 3), padding=(1, 1), groups=32
    )
    out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)

    return relay.Function([data, weight1], out)

Annotate the whole graph for the "dnnl" compiler, partition it, and inspect the result:

func = get_func()
mod = tvm.IRModule()
mod["main"] = WholeGraphAnnotator("dnnl").visit(func)
mod = relay.transform.PartitionGraph()(mod)
mod = relay.transform.InferType()(mod)
mod.show()
def @main(%data: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %weight1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */) -> Tensor[(1, 32, 14, 14), float32] {
  @tvmgen_default_dnnl_main_0(%data, %weight1) /* ty=Tensor[(1, 32, 14, 14), float32] */
}

def @tvmgen_default_dnnl_main_0(%dnnl_0_i0: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %dnnl_0_i1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */, Compiler="dnnl", Primitive=1, Inline=1, global_symbol="tvmgen_default_dnnl_main_0") -> Tensor[(1, 32, 14, 14), float32] {
  %0 = nn.conv2d(%dnnl_0_i0, %dnnl_0_i1, padding=[1, 1, 1, 1], groups=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 14, 14), float32] */;
  %1 = nn.conv2d(%0, %dnnl_0_i1, padding=[1, 1, 1, 1], groups=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 14, 14), float32] */;
  add(%0, %1) /* ty=Tensor[(1, 32, 14, 14), float32] */
}
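
To actually run the DNNL-partitioned module, TVM must be built with the DNNL/oneDNN codegen enabled (the relay.ext.dnnl registration). A minimal end-to-end check, sketched under that assumption, compares the partitioned module against a reference computed from the unpartitioned function:

# Sketch: requires a TVM build with the DNNL codegen enabled and oneDNN installed.
if tvm.get_global_func("relay.ext.dnnl", allow_missing=True):
    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
    w_data = np.random.uniform(0, 1, w1shape).astype(dtype)

    # Reference result from the unpartitioned function.
    ref_mod = tvm.IRModule.from_expr(get_func())
    ref_res = relay.create_executor(
        "graph", mod=ref_mod, device=tvm.cpu(), target="llvm"
    ).evaluate()(i_data, w_data)

    check_result(
        mod, {"data": i_data, "weight1": w_data}, (1, 32, 14, 14), ref_res.numpy()
    )
else:
    print("DNNL codegen is not enabled in this TVM build; skipping the numeric check")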