测试 Pass

SplitArgs

from tvm.ir.transform import Pass
from tvm.ir import IRModule
from tvm.relay import transform
from tvm import relay

def run_opt_pass(expr, opt_pass):
    """Type-infer ``expr`` and then apply ``opt_pass`` to it.

    ``expr`` is wrapped into an ``IRModule`` (becoming its ``main``), run
    through ``relay.transform.InferType`` and then through ``opt_pass``.

    Returns the transformed ``main`` function when ``expr`` is a
    ``relay.Function``, otherwise only the body of that function.
    """
    assert isinstance(opt_pass, Pass)
    module = relay.transform.InferType()(IRModule.from_expr(expr))
    main_fn = opt_pass(module)["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body
# Only submodule names were imported above; `tvm` itself must be bound here.
import tvm

target = tvm.target.Target("metal")

shape = (1, 1, 1, 3)
dtype = "float32"
axis = 1
inputs = [relay.var(f"p{i}", shape=shape, dtype=dtype) for i in range(100)]


def before():
    """Concatenate all ``inputs`` along ``axis`` in a single call."""
    inp = relay.Tuple(inputs)
    return relay.op.concatenate(inp, axis)


res = run_opt_pass(before(), transform.SplitArgs(target.max_function_args))

# Build the expected expression by hand. One argument slot is reserved for
# the output buffer, so each partial concatenate takes at most `limit` inputs;
# slicing handles the shorter final chunk.
limit = target.max_function_args - 1
splitted = [
    relay.annotation.stop_fusion(
        relay.op.concatenate(relay.Tuple(inputs[start : start + limit]), axis)
    )
    for start in range(0, len(inputs), limit)
]
expr = relay.op.concatenate(relay.Tuple(splitted), axis)

# `res` went through InferType inside run_opt_pass, so type-infer the expected
# expression the same way, and actually assert the comparison result.
expected = run_opt_pass(expr, transform.InferType())
assert tvm.ir.structural_equal(res, expected)

Pass Instrument

import tvm
from tvm.ir import IRModule
from tvm import relay
from tvm.relay import op
from tvm import transform
from tvm.ir.instrument import PassTimingInstrument, pass_instrument
from tvm.ir.transform import PassContext

def get_test_model():
    """Return an IRModule computing ``(x + y) * ((x + y) / (x - z)) + (x - z)``.

    ``x``, ``y`` and ``z`` are ``float32`` variables of shape ``(3, 4)``.
    """
    x = relay.var("x", shape=(3, 4), dtype="float32")
    y = relay.var("y", shape=(3, 4), dtype="float32")
    z = relay.var("z", shape=(3, 4), dtype="float32")
    total = op.add(x, y)
    diff = op.subtract(x, z)
    product = op.multiply(total, total / diff)
    return IRModule.from_expr(product + diff)
pass_timing = PassTimingInstrument()

seq = transform.Sequential(
    [
        relay.transform.AnnotateSpans(),
        relay.transform.ToANormalForm(),
        relay.transform.InferType(),
    ]
)

# Install the timing instrument on the *current* PassContext. This mutates
# global state, so it is reset in `finally` even if a pass raises; otherwise
# every later pass in the process would stay instrumented.
PassContext.current().override_instruments([pass_timing])
try:
    mod = seq(get_test_model())
    # Render before resetting: the last assertion below expects an empty
    # profile once the instruments have been overridden with None.
    profiles = pass_timing.render()
finally:
    # Reset the current PassContext's instruments to None.
    PassContext.current().override_instruments(None)

assert "AnnotateSpans" in profiles
assert "ToANormalForm" in profiles
assert "InferType" in profiles

# With the instruments reset, running the passes again records nothing.
mod = seq(get_test_model())
profiles = pass_timing.render()
assert profiles == ""
# `tvm.testing` is a submodule; import it explicitly instead of relying on
# another import having loaded it as a side effect.
import tvm.testing

# Parametrizes test_custom_instrument over both ways of defining an instrument.
instrument_definition_type = tvm.testing.parameter("decorator", "subclass")

def test_custom_instrument(instrument_definition_type):
    """Check the hook order of a user-defined pass instrument.

    The same hooks are turned into a ``PassInstrument`` either with the
    ``pass_instrument`` decorator (``"decorator"``) or by subclassing
    ``tvm.ir.instrument.PassInstrument`` (``"subclass"``).

    Raises:
        ValueError: if ``instrument_definition_type`` is neither value.
    """

    class BaseTest:
        def __init__(self):
            self.events = []

        def enter_pass_ctx(self):
            self.events.append("enter ctx")

        def exit_pass_ctx(self):
            self.events.append("exit ctx")

        def run_before_pass(self, mod, info):
            self.events.append("run before " + info.name)

        def run_after_pass(self, mod, info):
            self.events.append("run after " + info.name)

    if instrument_definition_type == "decorator":
        MyTest = pass_instrument(BaseTest)
    elif instrument_definition_type == "subclass":

        class MyTest(BaseTest, tvm.ir.instrument.PassInstrument):
            def __init__(self):
                BaseTest.__init__(self)
                tvm.ir.instrument.PassInstrument.__init__(self)

    else:
        # Fail with a clear message instead of an UnboundLocalError on MyTest.
        raise ValueError(
            f"unknown instrument_definition_type: {instrument_definition_type!r}"
        )

    mod = get_test_model()
    my_test = MyTest()
    with tvm.transform.PassContext(instruments=[my_test]):
        mod = tvm.relay.transform.InferType()(mod)

    # Compare the list itself (not its concatenation) so event boundaries are
    # checked too and a failure shows which hook is out of place.
    assert my_test.events == [
        "enter ctx",
        "run before InferType",
        "run after InferType",
        "exit ctx",
    ]

禁用 Pass

@pass_instrument
class CustomPI:
    """Instrument that only lets passes whose name contains "InferType" run.

    ``run_before_pass`` appends ``info.name`` to ``events``.
    """

    def __init__(self):
        self.events = []

    def should_run(self, mod, info):
        # Allow only passes whose name contains "InferType".
        return "InferType" in info.name

    def run_before_pass(self, mod, info):
        self.events.append(info.name)


# Run three passes under CustomPI; its should_run rejects everything except
# InferType, so only that pass is expected in the recorded events.
mod = get_test_model()
custom_pi = CustomPI()

with PassContext(instruments=[custom_pi]):
    mod = tvm.relay.transform.AnnotateSpans()(mod)
    mod = tvm.relay.transform.ToANormalForm()(mod)
    mod = tvm.relay.transform.InferType()(mod)

assert "".join(custom_pi.events) == "InferType"
@pass_instrument
class SkipPass:
    """Instrument that skips every pass whose name contains ``skip_pass_name``."""

    def __init__(self, skip_pass_name):
        self.skip_pass_name = skip_pass_name

    def should_run(self, mod, info):
        # Run everything except passes matching the configured substring.
        return self.skip_pass_name not in info.name

# One skipping instrument per pass that should be disabled.
skip_annotate, skip_anf = SkipPass("AnnotateSpans"), SkipPass("ToANormalForm")

@pass_instrument
class PrintPassName:
    """Instrument that appends each pass name to ``events`` in ``run_before_pass``."""

    def __init__(self):
        self.events = list()

    def run_before_pass(self, mod, info):
        name = info.name
        self.events.append(name)

# AnnotateSpans and ToANormalForm are skipped by the SkipPass instruments,
# so the test expects only InferType to be recorded by PrintPassName.
mod = get_test_model()
print_pass_name = PrintPassName()
with tvm.transform.PassContext(
    instruments=[skip_annotate, skip_anf, print_pass_name]
):
    mod = tvm.relay.transform.AnnotateSpans()(mod)
    mod = tvm.relay.transform.ToANormalForm()(mod)
    mod = tvm.relay.transform.InferType()(mod)

assert "".join(print_pass_name.events) == "InferType"
@pass_instrument
class PassesCounter:
    """Count ``run_before_pass`` and ``run_after_pass`` invocations.

    Both counters are reset to zero on ``enter_pass_ctx`` and again on
    ``exit_pass_ctx``.
    """

    def __init__(self):
        self.run_before_count = 0
        self.run_after_count = 0

    def _reset_counts(self):
        self.run_before_count = 0
        self.run_after_count = 0

    def enter_pass_ctx(self):
        self._reset_counts()

    def exit_pass_ctx(self):
        self._reset_counts()

    def run_before_pass(self, mod, info):
        self.run_before_count += 1

    def run_after_pass(self, mod, info):
        self.run_after_count += 1

# Build under a counting instrument: inside the context the before/after
# hooks must have fired equally often, and at least once.
mod = get_test_model()
passes_counter = PassesCounter()
with tvm.transform.PassContext(instruments=[passes_counter]):
    tvm.relay.build(mod, "llvm")
    assert passes_counter.run_after_count != 0
    assert passes_counter.run_before_count == passes_counter.run_after_count

# Outside the context the counters should be back to zero
# (PassesCounter.exit_pass_ctx clears them).
assert passes_counter.run_before_count == 0
assert passes_counter.run_after_count == 0
# Registered PassContext config options, looked up by option name; each entry
# describes the option (here its "type").
configs = PassContext.list_configs()
assert len(configs) > 0
assert "relay.backend.use_auto_scheduler" in configs.keys()
assert configs["relay.backend.use_auto_scheduler"]["type"] == "IntImm"

# Shared log written by the PI instruments and the transform1/transform2 passes.
events = list()

@pass_instrument
class PI:
    """Instrument that logs each hook invocation into the shared ``events`` list.

    Context hooks are logged without indentation and per-pass hooks with two
    leading spaces; every entry is prefixed with this instrument's ``id``.
    """

    def __init__(self, id):
        self.id = id

    def _log(self, indent, hook):
        events.append(indent + self.id + " " + hook)

    def enter_pass_ctx(self):
        self._log("", "enter_pass_ctx")

    def exit_pass_ctx(self):
        self._log("", "exit_pass_ctx")

    def should_run(self, mod, info):
        self._log("  ", "should_run")
        return True

    def run_before_pass(self, mod, info):
        self._log("  ", "run_before_pass")

    def run_after_pass(self, mod, info):
        self._log("  ", "run_after_pass")

@tvm.transform.module_pass(opt_level=2)
def transform1(mod, ctx):
    """Identity module pass that records its execution in ``events``."""
    message = "    transform1 pass"
    events.append(message)
    return mod

@tvm.transform.module_pass(opt_level=2)
def transform2(mod, ctx):
    """Identity module pass that records its execution in ``events``."""
    message = "    transform2 pass"
    events.append(message)
    return mod

# Run two no-op module passes under two PI instruments and check the order in
# which instrument hooks and the passes themselves fire.
events.clear()  # start from an empty log even if this section is re-run

mod = get_test_model()
with PassContext(instruments=[PI("%1"), PI("%2")]):
    mod = transform1(mod)
    mod = transform2(mod)

# Compare the list rather than its concatenation so entry boundaries are
# checked and a failure points at the offending hook.
assert events == [
    "%1 enter_pass_ctx",
    "%2 enter_pass_ctx",
    "  %1 should_run",
    "  %2 should_run",
    "  %1 run_before_pass",
    "  %2 run_before_pass",
    "    transform1 pass",
    "  %1 run_after_pass",
    "  %2 run_after_pass",
    "  %1 should_run",
    "  %2 should_run",
    "  %1 run_before_pass",
    "  %2 run_before_pass",
    "    transform2 pass",
    "  %1 run_after_pass",
    "  %2 run_after_pass",
    "%1 exit_pass_ctx",
    "%2 exit_pass_ctx",
]

Pass 去函数化

from tvm.relay.backend.interpreter import ConstructorValue
from tvm.relay import transform, ExprVisitor, TypeVisitor
from tvm.relay.testing import Prelude
def has_func_type(t):
    """Return True if type ``t`` is a ``FuncType`` or contains one nested inside.

    ``t`` is walked with a ``TypeVisitor``; any ``FuncType`` reached during the
    walk sets the result flag.
    """

    class _FuncTypeFinder(TypeVisitor):
        def __init__(self):
            super().__init__()
            self.found = False

        def visit_func_type(self, ftt):
            # No need to descend further: one function type is enough.
            self.found = True

    finder = _FuncTypeFinder()
    finder.visit(t)
    return finder.found

断言程序中不存在高阶函数。高阶函数定义为满足以下任一条件的函数:

  • 参数类型是(或嵌套包含)函数类型

  • 返回类型是(或嵌套包含)函数类型,即返回函数

def assert_no_higher_order_functions(expr, mod):
    """Assert that ``expr``, and every global it reaches in ``mod``, is first order.

    A call counts as higher order when its result type or any argument type
    contains a function type (see ``has_func_type``). ``mod`` is run through
    ``InferType`` before visiting.

    Raises:
        AssertionError: listing every offending call.
    """

    class CheckFirstOrderVisitor(ExprVisitor):
        def __init__(self, mod):
            super().__init__()
            self.mod = mod
            self.hof = []
            self.visited_gv = set()

        def visit_call(self, call):
            # Evaluate every type check without short-circuiting, so an error
            # from any `checked_type` access still surfaces.
            flags = [has_func_type(call.checked_type)]
            flags.extend(has_func_type(arg.checked_type) for arg in call.args)
            if any(flags):
                # Keep offending calls so the failure message can print them.
                self.hof.append(call)
            super().visit_call(call)

        def visit_global_var(self, gv):
            # Follow each global once so the whole reachable program is checked.
            if gv in self.visited_gv:
                return
            self.visited_gv.add(gv)
            self.visit(self.mod[gv])

    mod = transform.InferType()(mod)
    visitor = CheckFirstOrderVisitor(mod)
    visitor.visit(expr)

    offenders = visitor.hof
    separator = "\n--------\n"
    errmsg = f"""found {len(offenders)} higher order functions:
  {separator.join(call.astext() for call in offenders)}"""

    assert len(offenders) == 0, errmsg

对模块执行去函数化,断言结果中不再包含高阶函数,并返回去函数化后的模块(假设程序入口为 mod['main']):

def defunctionalized(mod):
    """Defunctionalize ``mod["main"]`` and return the resulting module.

    The module is type-inferred, its ``main`` is replaced by the output of
    ``transform.Defunctionalization``, and the result is type-inferred again.
    Finally it is checked to contain no higher-order calls.
    """
    typed = transform.InferType()(mod)
    typed["main"] = transform.Defunctionalization(typed["main"], typed)
    result = transform.InferType()(typed)
    assert_no_higher_order_functions(result["main"], result)
    return result
# ADT list -> Python list
def to_list(mod, l):
    """Convert an evaluated ``List`` ADT value into a Python list.

    ``mod`` must define the ``List`` type. Its first constructor is treated as
    ``Cons`` and its second as ``Nil``, matching the ``List`` definitions in
    this file. Each head field is converted with ``.numpy()``.
    """
    list_gtv = mod.get_global_type_var("List")
    list_adt = mod[list_gtv]
    cons, nil = list_adt.constructors[0], list_adt.constructors[1]

    assert isinstance(l, ConstructorValue)
    result = []
    node = l
    while node.tag == cons.tag:
        result.append(node.fields[0].numpy())
        node = node.fields[1]
    # Anything that is not a Cons cell must be the terminating Nil.
    assert node.tag == nil.tag
    return result


# Python list -> ADT list
def to_adt_list(mod, arr):
    """Convert the Python iterable ``arr`` into an evaluated ``List`` ADT value.

    Each element is wrapped with ``relay.const`` and linked using ``mod``'s
    ``List`` constructors (first = ``Cons``, second = ``Nil``). The resulting
    expression is then evaluated with ``relay.create_executor(mod=mod)``.

    The element order of ``arr`` is preserved, so
    ``to_list(mod, to_adt_list(mod, arr))`` returns the elements in order.
    ``mod["main"]`` is saved before evaluation and restored afterwards, even if
    evaluation raises.
    """
    expr = mod["main"]
    list_gtv = mod.get_global_type_var("List")
    list_adt = mod[list_gtv]
    cons = list_adt.constructors[0]
    nil = list_adt.constructors[1]

    li = nil()
    # Cons cells are prepended, so walk `arr` back to front to keep its order.
    for a in reversed(list(arr)):
        li = cons(relay.const(a), li)
    try:
        adt = relay.create_executor(mod=mod).evaluate(li)
    finally:
        # Presumably evaluating a bare expression can rebind mod["main"];
        # always put the original back.
        mod["main"] = expr
    return adt
import tvm
from tvm import relay
import numpy as np

code = """
#[version = "0.0.5"]
def @simple[A, B](%f: fn(A) -> B, %xs: A) -> B {
  %f(%xs)
}
def @main(%l: Tensor[(5, 5), float32]) -> Tensor[(5, 5), float32] {
  %0 = fn[A](%x: A) -> A {
    %x
  };
  @simple(%0, %l)
}
"""
mod = tvm.parser.fromtext(code)
defunc_mod = defunctionalized(mod)

# The original and defunctionalized programs must agree on the same input.
# Named `x_np` rather than `input` to avoid shadowing the builtin.
x_np = np.random.rand(5, 5).astype("float32")
out = relay.create_executor("debug", mod=mod).evaluate()(x_np)
defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()(x_np)
np.testing.assert_equal(out.numpy(), defunc_out.numpy())
code = """
#[version = "0.0.5"]
type List[A] {
  Cons(A, List[A]),
  Nil,
}
def @id[A](%x: A) -> A {
  %x
}
def @map[A, B](%f: fn(A) -> B, %xs: List[A]) -> List[B] {
  match (%xs) {
    Cons(%x, %rest) => Cons(%f(%x), @map(%f, %rest)),
    Nil => Nil,
  }
}
def @main(%l: List[float32]) -> List[float32] {
  @map(@id, %l)
}
"""
mod = tvm.parser.fromtext(code)
defunc_mod = defunctionalized(mod)

# Named `x_np` rather than `input` to avoid shadowing the builtin.
x_np = np.random.rand(10).astype("float32")

out = relay.create_executor("debug", mod=mod).evaluate(mod["main"])(to_adt_list(mod, x_np))

defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()(
    to_adt_list(defunc_mod, x_np)
)

# Both programs received the same ADT-encoded input; compare element-wise.
np.testing.assert_array_equal(to_list(mod, out), to_list(defunc_mod, defunc_out))
def test_recursive_datatype():
    """Defunctionalize a CPS-style ``@sum`` and check the result is unchanged.

    ``@sum`` builds a new closure for every list element, which after
    defunctionalization requires a recursive datatype.
    """
    # `tvm.testing` is a submodule and must be imported explicitly. This binds
    # the local name `tvm` to the same package used below.
    import tvm.testing

    # CPS will create recursive datatype
    code = """
#[version = "0.0.5"]
type List[A] {
  Cons(A, List[A]),
  Nil,
}
def @sum(%f: fn(int32) -> int32, %k: List[int32]) -> int32 {
  match (%k) {
    Cons(%x, %rest) => %0 = fn(%n) {
      %x + %f(%n)
    };
    @sum(%0, %rest),
    Nil => %f(0),
  }
}
def @id[A](%x: A) -> A {
  %x
}
def @main(%l: List[int32]) -> int32 {
  @sum(@id, %l)
}
"""
    mod = tvm.parser.fromtext(code)
    defunc_mod = defunctionalized(mod)

    # Named `values` rather than `input` to avoid shadowing the builtin.
    values = np.random.randint(1, 100, 10)

    out = relay.create_executor("debug", mod=mod).evaluate(mod["main"])(to_adt_list(mod, values))

    defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()(
        to_adt_list(defunc_mod, values)
    )

    tvm.testing.assert_allclose(out.numpy(), defunc_out.numpy())