Relay 模块

Relay 模块#

Relay 保留称为 “module” 的全局数据结构(在其他函数式编程语言中通常称为 “environment”),以跟踪全局函数的定义。特别地,该模块保持全局变量到它们所表示的函数表达式的全局可访问映射。模块的实用之处在于,它允许全局函数递归地引用它们自己或任何其他全局函数(例如,在 mutual 递归中)。

from tvm import relay

# 定义变量
names = "xy"
x, y = [relay.var(name) for name in names]
# 定义函数
add_op = x + y
add_func = relay.Function([x, y], add_op)

声明全局变量:

add_gvar = relay.GlobalVar("AddFunc")
print(add_gvar)
@AddFunc

定义将 add_func 提升为全局变量:

from tvm import IRModule

mod = IRModule({add_gvar: add_func})
print(mod)
def @AddFunc(%x, %y) {
  add(%x, %y)
}

获取模块的全局变量内容:

mod[add_gvar]
fn (%x, %y) {
  add(%x, %y)
}

也可以直接借助全局变量的名字获取其内容:

mod["AddFunc"]
fn (%x, %y) {
  add(%x, %y)
}

也可以分配新的全局变量给模块:

names = "xy"
x, y = [relay.var(name) for name in names]
# 定义函数
mul_op = x * y
mul_func = relay.Function([x, y], mul_op)
mod["MulFunc"] = mul_func

print(mod)
def @AddFunc(%x, %y) {
  add(%x, %y)
}

def @MulFunc(%x1, %y1) {
  multiply(%x1, %y1)
}

也可以通过 Python 字典更新全局变量:

names = "xyz"
x, y, z = [relay.var(name) for name in names]
# 定义函数
v1 = x * y
muladd_op = v1 + z
muladd_func = relay.Function([x, y, z], muladd_op)

mod.update({"MulAddFunc": muladd_func})
print(mod)
def @AddFunc(%x, %y) {
  add(%x, %y)
}

def @MulAddFunc(%x1, %y1, %z) {
  %0 = multiply(%x1, %y1);
  add(%0, %z)
}

def @MulFunc(%x2, %y2) {
  multiply(%x2, %y2)
}

查看所有全局变量:

mod.get_global_vars()
[GlobalVar(AddFunc), GlobalVar(MulFunc), GlobalVar(MulAddFunc)]