注解目标设备

import os
import sys
import numpy as np

import tvm
import tvm.relay.testing
from tvm.relay import transform
from tvm import relay
from tvm import runtime
from tvm.contrib import utils

注解 DNNL

def annotated(dtype, ishape, w1shape):
    """Build a module with two chained grouped conv2d ops whose outputs are added.

    Both convolutions use the same ``weight1`` variable; the second consumes
    the output of the first, and the function body is their sum.

    Parameters
    ----------
    dtype : str
        Element type given to both input variables.
    ishape : tuple of int
        Shape of the ``data`` variable.
    w1shape : tuple of int
        Shape of the ``weight1`` variable.

    Returns
    -------
    tvm.IRModule
        Module created from the function over ``data`` and ``weight1``.
    """
    # Keyword arguments shared by the two convolutions.
    conv_attrs = dict(kernel_size=(3, 3), padding=(1, 1), groups=32)

    inp = relay.var("data", shape=(ishape), dtype=dtype)
    kernel = relay.var("weight1", shape=(w1shape), dtype=dtype)

    first = relay.nn.conv2d(inp, kernel, **conv_attrs)
    second = relay.nn.conv2d(first, kernel, **conv_attrs)

    func = relay.Function([inp, kernel], relay.add(first, second))
    return tvm.IRModule.from_expr(func)
# Input configuration for the example network.
dtype, ishape, w1shape = "float32", (1, 32, 14, 14), (32, 1, 3, 3)

# Annotate for "dnnl", infer types, then partition the graph (applied inside-out).
mod = annotated(dtype, ishape, w1shape)
mod = transform.PartitionGraph()(
    relay.transform.InferType()(transform.AnnotateTarget("dnnl")(mod))
)
print(mod)
def @main(%data: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %weight1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */) -> Tensor[(1, 32, 14, 14), float32] {
  %0 = @tvmgen_default_dnnl_main_0(%data, %weight1) /* ty=Tensor[(1, 32, 14, 14), float32] */;
  %1 = @tvmgen_default_dnnl_main_3(%0, %weight1) /* ty=Tensor[(1, 32, 14, 14), float32] */;
  @tvmgen_default_dnnl_main_2(%0, %1) /* ty=Tensor[(1, 32, 14, 14), float32] */
}

def @tvmgen_default_dnnl_main_0(%dnnl_0_i0: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %dnnl_0_i1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */, Compiler="dnnl", Primitive=1, Inline=1, global_symbol="tvmgen_default_dnnl_main_0") -> Tensor[(1, 32, 14, 14), float32] {
  nn.conv2d(%dnnl_0_i0, %dnnl_0_i1, padding=[1, 1, 1, 1], groups=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 14, 14), float32] */
}

def @tvmgen_default_dnnl_main_2(%dnnl_2_i0: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %dnnl_2_i1: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, Compiler="dnnl", Primitive=1, Inline=1, global_symbol="tvmgen_default_dnnl_main_2") -> Tensor[(1, 32, 14, 14), float32] {
  add(%dnnl_2_i0, %dnnl_2_i1) /* ty=Tensor[(1, 32, 14, 14), float32] */
}

def @tvmgen_default_dnnl_main_3(%dnnl_3_i0: Tensor[(1, 32, 14, 14), float32] /* ty=Tensor[(1, 32, 14, 14), float32] */, %dnnl_3_i1: Tensor[(32, 1, 3, 3), float32] /* ty=Tensor[(32, 1, 3, 3), float32] */, Compiler="dnnl", Primitive=1, Inline=1, global_symbol="tvmgen_default_dnnl_main_3") -> Tensor[(1, 32, 14, 14), float32] {
  nn.conv2d(%dnnl_3_i0, %dnnl_3_i1, padding=[1, 1, 1, 1], groups=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 14, 14), float32] */
}

注解多后端设备

@tvm.ir.register_op_attr("nn.relu", "target.test")
def relu(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``target.test`` attribute of ``nn.relu``; always returns True."""
    return True
def before():
    """Build ``abs(relu(x)) + abs(relu(x))`` where both ``abs`` calls share one relu node."""
    inp = relay.var("x", shape=(10, 10))
    shared_relu = relay.nn.relu(inp)
    # Two separate abs calls over the same relu expression.
    lhs, rhs = relay.abs(shared_relu), relay.abs(shared_relu)
    func = relay.Function([inp], relay.add(lhs, rhs))
    return tvm.IRModule.from_expr(func)

# Run AnnotateTarget for the "test" target on the module from before() and print the result.
result = transform.AnnotateTarget("test")(before())
print(result)
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %0 = annotation.compiler_begin(%x, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %1 = nn.relu(%0) /* ty=Tensor[(10, 10), float32] */;
  %2 = annotation.compiler_end(%1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %3 = annotation.compiler_begin(%2, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %4 = abs(%3) /* ty=Tensor[(10, 10), float32] */;
  %5 = annotation.compiler_end(%4, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %6 = annotation.compiler_end(%1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %7 = annotation.compiler_begin(%6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %8 = abs(%7) /* ty=Tensor[(10, 10), float32] */;
  %9 = annotation.compiler_end(%8, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %10 = annotation.compiler_begin(%5, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %11 = annotation.compiler_begin(%9, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %12 = add(%10, %11) /* ty=Tensor[(10, 10), float32] */;
  annotation.compiler_end(%12, compiler="default") /* ty=Tensor[(10, 10), float32] */
}
# Target name; concatenated with "target." to form the op attribute key below.
target = "test_type_propagation"

@tvm.ir.register_op_attr("nn.relu", "target." + target)
def relu(expr):  # pylint: disable=unused-variable
    """Return True when the checked type of the call's first argument has dtype ``float32``."""
    return expr.args[0].checked_type.dtype == "float32"

def before():
    """Build a module that applies ``nn.relu`` twice to a (10, 10) input ``x``."""
    inp = relay.var("x", shape=(10, 10))
    func = relay.Function([inp], relay.nn.relu(relay.nn.relu(inp)))
    return tvm.IRModule.from_expr(func)
    
# The second positional argument (True) is the flag other snippets in this file
# pass under the name annotate_non_call_ops.
res = transform.AnnotateTarget(target, True)(before())
# Bare expression; presumably left so an interactive session displays the module.
res
def @main(%x: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %0 = annotation.compiler_begin(%x, compiler="test_type_propagation") /* ty=Tensor[(10, 10), float32] */;
  %1 = nn.relu(%0) /* ty=Tensor[(10, 10), float32] */;
  %2 = annotation.compiler_end(%1, compiler="test_type_propagation") /* ty=Tensor[(10, 10), float32] */;
  %3 = annotation.compiler_begin(%2, compiler="test_type_propagation") /* ty=Tensor[(10, 10), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(10, 10), float32] */;
  annotation.compiler_end(%4, compiler="test_type_propagation") /* ty=Tensor[(10, 10), float32] */
}

read & write

# Target name; concatenated with "target." to form the op attribute key below.
target = "relu"

@tvm.ir.register_op_attr("nn.relu", "target." + target)
def annotate(expr):
    """Callback registered as the ``target.relu`` attribute of ``nn.relu``; always returns True."""
    return True

def before():
    """Build a module that reads a reference, applies relu, and writes the result back."""
    cell = relay.expr.RefCreate(relay.const(1.0))
    current = relay.expr.RefRead(cell)
    updated = relay.nn.relu(current)
    return tvm.IRModule.from_expr(relay.expr.RefWrite(cell, updated))

def after(annotate_non_call_ops):
    """Build the hand-annotated counterpart of ``before()`` for the ref read/write case.

    Parameters
    ----------
    annotate_non_call_ops : bool
        When true, the constant, the reference, the read and the write are
        additionally wrapped in ``"default"`` compiler_begin/compiler_end
        annotations; the relu is always wrapped for the module-level ``target``.

    Returns
    -------
    tvm.IRModule
        Module created from the annotated ``RefWrite`` expression.
    """
    begin = relay.annotation.compiler_begin
    end = relay.annotation.compiler_end

    def default_begin(expr):
        # Only annotate non-call nodes when requested.
        return begin(expr, "default") if annotate_non_call_ops else expr

    def default_end(expr):
        return end(expr, "default") if annotate_non_call_ops else expr

    created = relay.expr.RefCreate(default_begin(relay.const(1.0)))
    # The write side and the read side each get their own end/begin pair
    # around the same RefCreate node.
    write_ref = default_begin(default_end(created))
    read_ref = default_begin(default_end(created))

    loaded = default_end(relay.expr.RefRead(read_ref))

    relu_out = end(relay.nn.relu(begin(loaded, target)), target)

    stored = relay.expr.RefWrite(write_ref, default_begin(relu_out))
    return tvm.IRModule.from_expr(default_end(stored))


# Annotate with only the target name given (no explicit second argument).
result = transform.AnnotateTarget(target)(before())
# Bare expression; presumably left so an interactive session displays the module.
result
def @main() -> () {
  %0 = annotation.compiler_begin(1f /* ty=float32 */, compiler="default") /* ty=float32 */;
  %1 = ref(%0);
  %2 = annotation.compiler_end(%1, compiler="default") /* ty=ref(float32) */;
  %3 = annotation.compiler_begin(%2, compiler="default") /* ty=ref(float32) */;
  %4 = annotation.compiler_end(%1, compiler="default") /* ty=ref(float32) */;
  %5 = annotation.compiler_begin(%4, compiler="default") /* ty=ref(float32) */;
  %6 = ref_read(%5);
  %7 = annotation.compiler_end(%6, compiler="default") /* ty=float32 */;
  %8 = annotation.compiler_begin(%7, compiler="relu") /* ty=float32 */;
  %9 = nn.relu(%8) /* ty=float32 */;
  %10 = annotation.compiler_end(%9, compiler="relu") /* ty=float32 */;
  %11 = annotation.compiler_begin(%10, compiler="default") /* ty=float32 */;
  %12 = ref_write(%3, %11);
  annotation.compiler_end(%12, compiler="default") /* ty=() */
}
# Same annotation, this time passing False as the second positional argument.
result = transform.AnnotateTarget(target, False)(before())
# Bare expression; presumably left so an interactive session displays the module.
result
def @main() -> () {
  %0 = ref(1f /* ty=float32 */);
  %1 = ref_read(%0);
  %2 = annotation.compiler_begin(%1, compiler="relu") /* ty=float32 */;
  %3 = nn.relu(%2) /* ty=float32 */;
  %4 = annotation.compiler_end(%3, compiler="relu") /* ty=float32 */;
  ref_write(%0, %4)
}

tuple

# Target name; concatenated with "target." to form the op attribute keys below.
target = "test_tuple"

@tvm.ir.register_op_attr("nn.relu", "target." + target)
def relu(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``target.test_tuple`` attribute of ``nn.relu``; always returns True."""
    return True

@tvm.ir.register_op_attr("concatenate", "target." + target)
def concatenate(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``"target." + target`` attribute of ``concatenate``; always returns True."""
    return True
def before():
    """Build ``concatenate((relu(x), relu(y)), axis=1)`` over two (10, 5) inputs."""
    lhs = relay.var("x", shape=(10, 5))
    rhs = relay.var("y", shape=(10, 5))
    joined = relay.concatenate((relay.nn.relu(lhs), relay.nn.relu(rhs)), axis=1)
    return tvm.IRModule.from_expr(relay.Function([lhs, rhs], joined))
def after(annotate_non_call_ops):
    """Build the hand-annotated counterpart of ``before()`` for the tuple case.

    Parameters
    ----------
    annotate_non_call_ops : bool
        When true, the tuple feeding ``concatenate`` is itself wrapped in
        compiler_begin/compiler_end annotations; otherwise the tuple is built
        directly from the annotated relu outputs.

    Returns
    -------
    tvm.IRModule
        Module created from the annotated function over ``x`` and ``y``.
    """
    begin = relay.annotation.compiler_begin
    end = relay.annotation.compiler_end

    inputs = [relay.var(name, shape=(10, 5)) for name in ("x", "y")]
    # Each input goes through begin -> relu -> end for the module-level target.
    relu_outs = [end(relay.nn.relu(begin(var, target)), target) for var in inputs]

    if not annotate_non_call_ops:
        fields = relay.Tuple(relu_outs)
    else:
        wrapped = [begin(out, target) for out in relu_outs]
        fields = end(relay.Tuple(wrapped), target)

    concat = relay.op._make.concatenate(begin(fields, target), 1)
    func = relay.Function(inputs, end(concat, target))
    return tvm.IRModule.from_expr(func)

# Compare the pass output with the hand-built expectation for both flag values.
for annotate_non_call_ops in [False, True]:
    result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
    # InferType is run on the expected module before the structural comparison.
    expected = transform.InferType()(after(annotate_non_call_ops))
    assert tvm.ir.structural_equal(expected, result)

composite_function

def before():
    """Build a module whose main function calls an inline ``test.add_relu`` composite function."""
    # Composite function: relu(in_1 + in_2), tagged with the "Composite" attribute.
    lhs = relay.var("in_1", shape=(10, 10))
    rhs = relay.var("in_2", shape=(10, 10))
    composite = relay.Function([lhs, rhs], relay.nn.relu(relay.add(lhs, rhs)))
    composite = composite.with_attr("Composite", "test.add_relu")

    # Outer function: call the composite on the two module inputs.
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    outer = relay.Function([a, b], relay.Call(composite, [a, b]))
    return tvm.IRModule.from_expr(outer)

def after():
    """Build the hand-annotated counterpart of the composite-function ``before()``.

    The call to the composite function is wrapped in ``"test"``
    compiler_begin annotations on its arguments and a compiler_end on its result.
    """
    begin = relay.annotation.compiler_begin
    end = relay.annotation.compiler_end

    # Composite function: relu(in_1 + in_2), tagged with the "Composite" attribute.
    lhs = relay.var("in_1", shape=(10, 10))
    rhs = relay.var("in_2", shape=(10, 10))
    composite = relay.Function([lhs, rhs], relay.nn.relu(relay.add(lhs, rhs)))
    composite = composite.with_attr("Composite", "test.add_relu")

    # Outer function: annotated call of the composite on the two module inputs.
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    call = relay.Call(composite, [begin(a, "test"), begin(b, "test")])
    outer = relay.Function([a, b], end(call, "test"))
    return tvm.IRModule.from_expr(outer)

# Annotate for "test" and compare against the type-inferred hand-built expectation.
result = transform.AnnotateTarget("test")(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)

double_target

@tvm.ir.register_op_attr("nn.relu", "target.double.A")
def relu(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``target.double.A`` attribute of ``nn.relu``; always returns True."""
    return True

def before():
    """Build a module directly from the expression ``relu(x)`` with ``x`` of shape (10, 5)."""
    # The bare call expression (not a relay.Function) is handed to from_expr.
    relu_of_x = relay.nn.relu(relay.var("x", shape=(10, 5)))
    return tvm.IRModule.from_expr(relu_of_x)

# Applying AnnotateTarget with the same target a second time must leave the module unchanged.
for annotate_non_call_ops in [True, False]:
    mod = before()
    mod1 = transform.AnnotateTarget("double.A", annotate_non_call_ops)(mod)
    mod2 = transform.AnnotateTarget("double.A", annotate_non_call_ops)(mod1)
    assert tvm.ir.structural_equal(mod1, mod2)

different_target

@tvm.ir.register_op_attr("nn.relu", "target.different.A")
def relu(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``target.different.A`` attribute of ``nn.relu``; always returns True."""
    return True

# NOTE(review): this callback is registered for the "add" op but is named `relu`,
# which rebinds the module-level `relu` defined just above. Consider renaming to `add`.
@tvm.ir.register_op_attr("add", "target.different.B")
def relu(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``target.different.B`` attribute of ``add``; always returns True."""
    return True

def before():
    """Build a module directly from ``add(r, r)`` where ``r = relu(x)`` is a single shared node."""
    shared_relu = relay.nn.relu(relay.var("x", shape=(10, 5)))
    # Both operands of add are the same relu expression.
    doubled = relay.add(shared_relu, shared_relu)
    return tvm.IRModule.from_expr(doubled)

# Annotating for A and then B in two runs must equal a single run given both targets.
for annotate_non_call_ops in [True, False]:
    mod = before()
    mod1 = transform.AnnotateTarget("different.A", annotate_non_call_ops)(mod)
    mod1 = transform.AnnotateTarget("different.B", annotate_non_call_ops)(mod1)
    mod2 = transform.AnnotateTarget(["different.A", "different.B"], annotate_non_call_ops)(mod)
    assert tvm.ir.structural_equal(mod1, mod2)

multiple_runs

@tvm.ir.register_op_attr("nn.relu", "target.A")
def relu(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``target.A`` attribute of ``nn.relu``; always returns True."""
    return True

@tvm.ir.register_op_attr("add", "target.B")
def add(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``target.B`` attribute of ``add``; always returns True."""
    return True

def before():
    """Build ``add(abs(r), relu(r))`` where ``r = relu(x)`` is shared by both branches."""
    inp = relay.var("x", shape=(10, 5))
    first_relu = relay.nn.relu(inp)
    # Both branches consume the same first relu node.
    abs_branch = relay.abs(first_relu)
    relu_branch = relay.nn.relu(first_relu)
    func = relay.Function([inp], relay.add(abs_branch, relu_branch))
    return tvm.IRModule.from_expr(func)

# Running the pass for "A" and then "B" must equal one run given the list ["A", "B"].
for annotate_non_call_ops in [True, False]:
    mod = transform.AnnotateTarget("A", annotate_non_call_ops)(before())
    mod = transform.AnnotateTarget("B", annotate_non_call_ops)(mod)
    expected = transform.AnnotateTarget(["A", "B"], annotate_non_call_ops)(before())
    assert tvm.ir.structural_equal(expected, mod)

ends_with_tuple

# Target name; concatenated with "target." to form the op attribute key below.
trgt = "clip"

# NOTE(review): this callback is registered for the "clip" op although it is named `relu`.
@tvm.ir.register_op_attr("clip", "target." + trgt)
def relu(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``target.clip`` attribute of ``clip``; always returns True."""
    return True

def get_model(get_item):
    """Return a module of clip ops whose body ends in a tuple or a tuple field.

    Parameters
    ----------
    get_item : bool
        When true the function body is ``TupleGetItem(t, 1)``; otherwise it is
        the tuple ``t`` itself.

    Returns
    -------
    tvm.IRModule
        Module created from the function over the single input ``a``.
    """
    a = relay.var("a", shape=(1, 16, 16, 4), dtype="uint8")
    z = relay.op.clip(a, 0, 255)
    # Both second-level clips consume the same first clip node.
    b = relay.op.clip(z, 0, 15)
    c = relay.op.clip(z, 16, 31)
    t = relay.Tuple((c, b))
    tgi = relay.TupleGetItem(t, 1) if get_item else t
    foo = relay.Function([a], tgi)
    # Build the module from the explicit function rather than from the bare
    # body expression, so the declared parameter list is what the module uses.
    return tvm.IRModule.from_expr(foo)

def get_expected(annotate_non_call_ops, get_item):
    """Build the hand-annotated counterpart of ``get_model(get_item)``.

    Parameters
    ----------
    annotate_non_call_ops : bool
        When true, the tuple (and the tuple-get-item, if present) are wrapped
        in compiler_begin/compiler_end annotations in addition to the clip calls.
    get_item : bool
        When true the body ends in ``TupleGetItem(t, 1)``; otherwise in the tuple.

    Returns
    -------
    tvm.IRModule
        Module created from the annotated function over the input ``a``.
    """

    def begin(expr):
        return relay.annotation.compiler_begin(expr, trgt)

    def end(expr):
        return relay.annotation.compiler_end(expr, trgt)

    def begin_if_annotating(expr):
        return begin(expr) if annotate_non_call_ops else expr

    def end_if_annotating(expr):
        return end(expr) if annotate_non_call_ops else expr

    inp = relay.var("a", shape=(1, 16, 16, 4), dtype="uint8")
    first_clip = relay.op.clip(begin(inp), 0, 255)

    def clip_branch(a_min, a_max):
        # Each branch gets its own end/begin pair around the shared first clip.
        clipped = relay.op.clip(begin(end(first_clip)), a_min, a_max)
        return begin_if_annotating(end(clipped))

    low = clip_branch(0, 15)
    high = clip_branch(16, 31)

    body = end_if_annotating(relay.Tuple((high, low)))
    if get_item:
        item = relay.TupleGetItem(begin_if_annotating(body), 1)
        body = end_if_annotating(item)

    return tvm.IRModule.from_expr(relay.Function([inp], body))

# Check every combination of get_item and annotate_non_call_ops.
for get_item in [True, False]:
    for annotate_non_call_ops in [False, True]:
        mod = get_model(get_item)
        mod = transform.AnnotateTarget("clip", annotate_non_call_ops)(mod)
        # InferType is run on the expected module before the structural comparison.
        expected = transform.InferType()(get_expected(annotate_non_call_ops, get_item))
        assert tvm.ir.structural_equal(expected, mod)

注解目标-其他

def test_if_else():
    """Test that If-else nodes compile correctly when surrounded by supported nodes.

    ``equal``, ``tanh``, ``sigmoid`` and ``erf`` are all registered as supported
    for the test target, and the annotated result is compared with a hand-built
    expectation for both values of ``annotate_non_call_ops``.
    """
    target = "test_if_else"

    @tvm.ir.register_op_attr("equal", "target." + target)
    def equal(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("tanh", "target." + target)
    def tanh(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("sigmoid", "target." + target)
    def sigmoid(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("erf", "target." + target)
    def erf(expr):  # pylint: disable=unused-variable
        return True

    def before():
        """Build ``erf(if e1 == e2 then tanh(data) else sigmoid(data))``."""
        data = relay.var("data", shape=(1, 32))
        eq1 = relay.var("e1", shape=[], dtype="float32")
        eq2 = relay.var("e2", shape=[], dtype="float32")
        eq = relay.equal(eq1, eq2)

        true_branch = relay.tanh(data)
        false_branch = relay.sigmoid(data)
        ife = relay.If(eq, true_branch, false_branch)
        out = relay.erf(ife)
        func = relay.Function([data, eq1, eq2], out)
        mod = tvm.IRModule.from_expr(func)

        return mod

    def after():
        """Build the same graph with every call wrapped in begin/end annotations."""
        data = relay.var("data", shape=(1, 32))
        eq1 = relay.var("e1", shape=[], dtype="float32")
        eq2 = relay.var("e2", shape=[], dtype="float32")

        cb_1 = relay.annotation.compiler_begin(eq1, target)
        cb_2 = relay.annotation.compiler_begin(eq2, target)

        equality_condition = relay.equal(cb_1, cb_2)
        ce_1 = relay.annotation.compiler_end(equality_condition, target)

        # if condition
        cb_3 = relay.annotation.compiler_begin(data, target)
        true_branch = relay.tanh(cb_3)
        ce_2 = relay.annotation.compiler_end(true_branch, target)

        # else condition
        cb_4 = relay.annotation.compiler_begin(data, target)
        false_branch = relay.sigmoid(cb_4)
        ce_3 = relay.annotation.compiler_end(false_branch, target)

        if_condition = relay.If(ce_1, ce_2, ce_3)
        cb_5 = relay.annotation.compiler_begin(if_condition, target)
        erf_out = relay.erf(cb_5)
        ce_4 = relay.annotation.compiler_end(erf_out, target)
        func = relay.Function([data, eq1, eq2], ce_4)
        mod = tvm.IRModule.from_expr(func)
        return mod

    # The expectation does not depend on annotate_non_call_ops, so build it once.
    expected = transform.InferType()(after())
    for annotate_non_call_ops in [True, False]:
        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
        assert tvm.ir.structural_equal(expected, result)


def test_while_let():
    """Test that let nodes compile correctly when surrounded by other nodes.

    ``less``, ``add`` and ``zeros_like`` are registered as supported for the
    test target; the annotated while-loop (expressed with ``Let`` and ``If``)
    is compared with a hand-built expectation for both flag values.
    """
    target = "test_while_let"

    @tvm.ir.register_op_attr("less", "target." + target)
    def less(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("add", "target." + target)
    def add(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("zeros_like", "target." + target)
    def zeros_like(expr):  # pylint: disable=unused-variable
        return True

    def before():
        """Build the un-annotated loop: recurse while ``var2 < 10``, accumulating ``var3 + var1``."""
        var1 = relay.var("var1", shape=(2,))
        var2 = relay.var("var2", shape=(), dtype="int32")
        var3 = relay.var("var3", shape=(2,))
        cond = relay.less(var2, relay.const(10, dtype="int32"))

        loop = relay.var("while_loop")
        ii = var2 + relay.const(1, dtype="int32")
        ss = var3 + var1
        true_branch = loop(ii, ss)
        ife = relay.If(cond, true_branch, var3)
        func_1 = relay.Function([var2, var3], ife)

        ret = relay.Let(loop, func_1, loop(relay.const(0, dtype="int32"), relay.zeros_like(var1)))
        func_2 = relay.Function([var1], ret)
        mod = tvm.IRModule.from_expr(func_2)
        return mod

    def after(annotate_non_call_ops):
        """Build the annotated loop; calls to ``loop`` get "default" annotations only when requested."""
        var1 = relay.var("var1", shape=(2,))
        var2 = relay.var("var2", shape=(), dtype="int32")
        var3 = relay.var("var3", shape=(2,))
        var4 = relay.const(10, dtype="int32")

        cb_1 = relay.annotation.compiler_begin(var2, target)
        cb_2 = relay.annotation.compiler_begin(var4, target)

        less_condition = relay.less(cb_1, cb_2)
        ce_1 = relay.annotation.compiler_end(less_condition, target)

        loop = relay.var("while_loop")

        # if condition
        cb_3 = relay.annotation.compiler_begin(var2, target)
        cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target)
        add_op_1 = relay.add(cb_3, cb_4)
        ce_2 = relay.annotation.compiler_end(add_op_1, target)

        cb_5 = relay.annotation.compiler_begin(ce_2, "default") if annotate_non_call_ops else ce_2

        cb_6 = relay.annotation.compiler_begin(var3, target)
        cb_7 = relay.annotation.compiler_begin(var1, target)
        add_op_2 = relay.add(cb_6, cb_7)
        ce_3 = relay.annotation.compiler_end(add_op_2, target)

        cb_8 = relay.annotation.compiler_begin(ce_3, "default") if annotate_non_call_ops else ce_3

        true_branch = loop(cb_5, cb_8)  # while loop
        ce_4 = (
            relay.annotation.compiler_end(true_branch, "default")
            if annotate_non_call_ops
            else true_branch
        )
        if_condition = relay.If(ce_1, ce_4, var3)
        const_1 = relay.const(0, dtype="int32")
        cb_9 = (
            relay.annotation.compiler_begin(const_1, "default")
            if annotate_non_call_ops
            else const_1
        )
        cb_10 = relay.annotation.compiler_begin(var1, target)
        zeros_like = relay.zeros_like(cb_10)
        ce_5 = relay.annotation.compiler_end(zeros_like, target)
        cb_11 = relay.annotation.compiler_begin(ce_5, "default") if annotate_non_call_ops else ce_5
        while_condition = loop(cb_9, cb_11)
        ce_6 = (
            relay.annotation.compiler_end(while_condition, "default")
            if annotate_non_call_ops
            else while_condition
        )

        func_1 = relay.Function([var2, var3], if_condition)
        ret = relay.Let(loop, func_1, ce_6)
        func_2 = relay.Function([var1], ret)
        mod = tvm.IRModule.from_expr(func_2)
        return mod

    for annotate_non_call_ops in [False, True]:
        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
        expected = transform.InferType()(after(annotate_non_call_ops))
        assert tvm.ir.structural_equal(expected, result)


def test_if_free_vars():
    """Test that If-else nodes compile correctly when surrounded by free variables.

    The true branch is a ``zeros`` call with no inputs; ``zeros`` is not
    registered for the test target here, and the hand-built expectation leaves
    that branch without annotations.
    """
    target = "test_if_free_vars"

    @tvm.ir.register_op_attr("equal", "target." + target)
    def equal(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("sigmoid", "target." + target)
    def sigmoid(expr):  # pylint: disable=unused-variable
        return True

    @tvm.ir.register_op_attr("erf", "target." + target)
    def erf(expr):  # pylint: disable=unused-variable
        return True

    def before():
        """Build ``erf(if e1 == e2 then zeros((1, 32)) else sigmoid(data))``."""
        data = relay.var("data", shape=(1, 32))
        eq1 = relay.var("e1", shape=[], dtype="float32")
        eq2 = relay.var("e2", shape=[], dtype="float32")
        eq = relay.equal(eq1, eq2)

        true_branch = relay.zeros(shape=(1, 32), dtype="float32")
        false_branch = relay.sigmoid(data)
        ife = relay.If(eq, true_branch, false_branch)
        out = relay.erf(ife)

        func = relay.Function([data, eq1, eq2], out)
        mod = tvm.IRModule.from_expr(func)

        return mod

    def after():
        """Build the same graph with the equal, sigmoid and erf calls annotated."""
        data = relay.var("data", shape=(1, 32))
        eq1 = relay.var("e1", shape=[], dtype="float32")
        eq2 = relay.var("e2", shape=[], dtype="float32")

        cb_1 = relay.annotation.compiler_begin(eq1, target)
        cb_2 = relay.annotation.compiler_begin(eq2, target)

        equality_condition = relay.equal(cb_1, cb_2)
        ce_1 = relay.annotation.compiler_end(equality_condition, target)

        # if condition
        true_branch = relay.zeros(shape=(1, 32), dtype="float32")

        # else condition
        cb_3 = relay.annotation.compiler_begin(data, target)
        false_branch = relay.sigmoid(cb_3)
        ce_2 = relay.annotation.compiler_end(false_branch, target)

        if_condition = relay.If(ce_1, true_branch, ce_2)
        cb_4 = relay.annotation.compiler_begin(if_condition, target)
        erf_out = relay.erf(cb_4)
        ce_3 = relay.annotation.compiler_end(erf_out, target)
        func = relay.Function([data, eq1, eq2], ce_3)
        mod = tvm.IRModule.from_expr(func)
        return mod

    # The expectation does not depend on annotate_non_call_ops, so build it once.
    expected = transform.InferType()(after())
    for annotate_non_call_ops in [True, False]:
        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
        assert tvm.ir.structural_equal(expected, result)


def test_free_vars_zeros():
    """Test that free variables compile correctly on their own.

    The module is a zero-argument function returning ``zeros``; the expected
    module after annotation is built identically, i.e. with no annotations.
    """
    target = "test_free_vars_zeros"

    def before():
        func = relay.Function([], relay.zeros(shape=(0), dtype="float32"))
        mod = tvm.IRModule.from_expr(func)
        return mod

    def after():
        # Same construction as before(): no annotations are expected.
        func = relay.Function([], relay.zeros(shape=(0), dtype="float32"))
        mod = tvm.IRModule.from_expr(func)
        return mod

    result = transform.AnnotateTarget(target)(before())
    expected = transform.InferType()(after())
    assert tvm.ir.structural_equal(expected, result)


def test_empty_tuple():
    """An empty tuple should behave just like a call with no args.

    The module is a zero-argument function returning an empty tuple; the
    expected module after annotation is built identically, for both values of
    ``annotate_non_call_ops``.
    """
    target = "test_empty_tuple"

    def before():
        func = relay.Function([], relay.Tuple([]))
        mod = tvm.IRModule.from_expr(func)
        return mod

    def after():
        # Same construction as before(): no annotations are expected.
        func = relay.Function([], relay.Tuple([]))
        mod = tvm.IRModule.from_expr(func)
        return mod

    # The expectation does not depend on annotate_non_call_ops, so build it once.
    expected = transform.InferType()(after())
    for annotate_non_call_ops in [True, False]:
        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
        assert tvm.ir.structural_equal(expected, result)