Relay 模块级 Pass

Relay 模块级 Pass#

模块级 Pass module_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False)

当提供 pass_func 时,此函数返回回调函数。否则,它将充当装饰器函数。

import numpy as np
import tvm
from tvm.ir.transform import module_pass
from tvm import relay
from tvm.relay.testing import run_infer_type

构建 pass:

@module_pass(opt_level=2)
def transform(mod, ctx):
    new_mod = tvm.IRModule()
    x = relay.var("x", shape=(5, 10), dtype="float32")
    new_mod["abs"] = relay.Function([x], relay.abs(x))
    new_mod.update(mod)
    return new_mod
type(transform)
tvm.ir.transform.ModulePass

可以打印此变换的基本信息:

transform
Run Module pass: transform at the optimization level 2
transform.info
The meta data of the pass - pass name: transform, opt_level: 2, required passes: []
transform.pass_info
The meta data of the pass - pass name: transform, opt_level: 2, required passes: []
transform.handle
c_void_p(94634432288232)

这里的 transform 函数向输入模块添加了 abs 函数,但它也可以是模块级的任何定制优化。创建这个 module_pass 之后,用户可以将它应用到任意 Relay 模块上。例如,可以构建空模块,并应用此传递来添加 abs 函数。

mod = tvm.IRModule()
mod = transform(mod)
mod
#[version = "0.0.5"]
def @abs(%x: Tensor[(5, 10), float32]) {
  abs(%x)
}

module_pass 作为类装饰器#

pass_func 也可以是带有 transform_module 方法的类类型。这个函数将使用 transform_module 作为 pass 函数来创建装饰过的 ModulePass

@module_pass(opt_level=1)
class TestPipeline:
    """简单的测试函数,将一个参数替换为另一个参数。"""
    def __init__(self, new_mod, replace):
        self.new_mod = new_mod
        self.replace = replace

    def transform_module(self, mod, ctx):
        if self.replace:
            return self.new_mod
        return mod

创建定制管道的实例:

x = relay.var("x", shape=(10, 20))
m1 = tvm.IRModule.from_expr(relay.Function([x], x))
m2 = tvm.IRModule.from_expr(relay.Function([x], relay.log(x)))
fpass = TestPipeline(m2, replace=True)
assert fpass.info.name == "TestPipeline"
mod3 = fpass(m1)
assert mod3.same_as(m2)
mod4 = TestPipeline(m2, replace=False)(m1)
assert mod4.same_as(m1)
mod3
#[version = "0.0.5"]
def @main(%x: Tensor[(10, 20), float32]) {
  log(%x)
}
mod4
#[version = "0.0.5"]
def @main(%x: Tensor[(10, 20), float32]) {
  %x
}