PackedFunc#

Reference: Runtime System

PackedFunc plays the key role of bridging the frontend and the backend in TVM. A PackedFunc exposes a type-erased interface: you call it with positional arguments, without declaring argument types ahead of time.

  • Compiled modules return their functions as PackedFunc.

  • The TVM backend also registers and exposes its APIs as PackedFunc.

Common PackedFunc use cases:

  • Automatically exposing C++ APIs to Python.

  • Calling a PackedFunc from the Python side.

  • Calling Python callbacks from generated code to inspect results.

  • Bringing Python hooks into the C++ backend.
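
The global registry that backs this bridge can be inspected from Python. A minimal sketch, assuming list_global_func_names() from tvm._ffi.registry (the exact set of names depends on your TVM build):

import tvm

# The C++ backend registers its APIs as global PackedFuncs; the registry
# can be listed from Python.
names = tvm._ffi.registry.list_global_func_names()
print(len(names))                                     # typically several hundred
print([n for n in names if n.startswith("runtime.")][:3])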

Global Functions#

  • tvm.register_func() registers a global function.

The code below registers my_packed_func as a global function.

import tvm

targs = (10, 10.0, "hello")
@tvm.register_func
def my_packed_func(*args):
    assert tuple(args) == targs
    return 10

  • tvm.get_global_func(): retrieves a global function.

Note that this simply looks the function up in the global function table and calls it back from the Python side.

from tvm.runtime.packed_func import PackedFunc

f = tvm.get_global_func("my_packed_func")
assert isinstance(f, PackedFunc)
y = f(*targs)
assert y == 10
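
If you are unsure whether a name has been registered, tvm.get_global_func() also takes an allow_missing flag:

# A missing name returns None when allow_missing=True is passed;
# otherwise get_global_func raises an error.
assert tvm.get_global_func("no.such.func", allow_missing=True) is None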

However, the same function can also be called back from the C++ backend, or from within compiled TVM code.

Calling C++ from Python#

Define an addition function in C++ and provide a Makefile:

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
using namespace tvm::runtime;

void MyAdd(TVMArgs args, TVMRetValue* rv) {
  // Arguments are automatically converted to the requested type.
  int a = args[0];
  int b = args[1];
  // The result is automatically assigned to the return value rv.
  *rv = a + b;
}

// Register a global packed function.
TVM_REGISTER_GLOBAL("myadd").set_body(MyAdd);

# Minimum Makefile for the extension package
TVM_ROOT=$(shell cd /path/to/tvm; pwd)
PKG_CFLAGS = -std=c++17 -O2 -fPIC\
	-I${TVM_ROOT}/include\
	-I${TVM_ROOT}/3rdparty/dmlc-core/include\
	-I${TVM_ROOT}/3rdparty/dlpack/include\
	-DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\>


PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s)

ifeq ($(UNAME_S), Darwin)
	PKG_LDFLAGS += -undefined dynamic_lookup
endif

lib/libtvm_ext.so: src/tvm_ext.cc
	@mkdir -p $(@D)
	$(CXX) $(PKG_CFLAGS) -shared -o $@ $^ $(PKG_LDFLAGS)

Run make to build the shared library at lib/libtvm_ext.so. Then add the following on the Python side to call the C++ API directly:

from pathlib import Path
import ctypes

def load_lib():
    """Load the library; its functions are registered into TVM."""
    curr_dir = Path("tests").resolve()
    # Load with RTLD_GLOBAL so the extern symbols are visible to other dlls.
    curr_path = str(curr_dir/"lib/libtvm_ext.so")
    lib = ctypes.CDLL(curr_path, ctypes.RTLD_GLOBAL)
    return lib


_LIB = load_lib()

myadd = tvm.get_global_func("myadd")
myadd(4, 5)
9

A PackedFunc can also be passed as an argument.

Define it on the C++ side:

TVM_REGISTER_GLOBAL("callhello")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  PackedFunc f = args[0];
  f("hello world");
});

Then the Python side can do:

def f(msg):
    print(msg)


callhello = tvm.get_global_func("callhello")
callhello(f)
hello world

Calling Python from C++#

The convert function#

convert() converts a given value into a TVM object.

For example, a list:

a = tvm.runtime.convert([1, 2, 3])
assert len(a) == 3
assert a[-1].value == 3
a_slice = a[-3:-1]
assert (a_slice[0].value, a_slice[1].value) == (1, 2)
type(a)
tvm.ir.container.Array

The object can be serialized to JSON:

json_str = tvm.ir.save_json(a)
assert isinstance(json_str, str)
# load it back
a_loaded = tvm.ir.load_json(json_str)
tvm.ir.assert_structural_equal(a_loaded, a, map_free_vars=True)

A dictionary:

amap = tvm.runtime.convert({"a": 2, "b": 3})
type(amap)
tvm.ir.container.Map
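
The resulting Map behaves like a read-only dict. A small sketch (keys are stored as TVM strings, and the integer values arrive as IntImm nodes, so .value unwraps them):

# The converted Map supports the usual read-only dict operations.
assert len(amap) == 2
assert "a" in amap
assert amap["a"].value == 2
assert sorted(v.value for k, v in amap.items()) == [2, 3]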

Other values:

x = tvm.nd.array([1, 2, 3])
arr = tvm.runtime.convert([x, x])
arr
[runtime.NDArray(0x42f0470), runtime.NDArray(0x42f0470)]
tvm.runtime.convert(f)
<tvm.runtime.packed_func.PackedFunc at 0x7f389a84e280>
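
As the repr above shows, convert() wraps a Python callable into a PackedFunc, which can then be invoked like any other one. A minimal sketch:

from tvm.runtime.packed_func import PackedFunc

# A converted callable is a genuine PackedFunc; the call is type-erased,
# so mixed positional argument types are fine.
pf = tvm.runtime.convert(lambda *args: len(args))
assert isinstance(pf, PackedFunc)
assert pf(1, 2.0, "three") == 3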

Hook Python Functions as Externs#

The example below registers a Python function with the TVM runtime system and uses it to carry out one stage of the computation. This makes TVM much more flexible: for example, you can insert a frontend callback to inspect intermediate results, or mix customized code with TVM.

import numpy as np
from tvm import te

@tvm.register_func("tvm.contrib.my_tvm_addone")
def my_tvm_addone(x, y):
    print(f"my_tvm_addone signatures: {type(x)}, {type(y)}")
    tvm.nd.array(x.numpy() + 1).copyto(y)


n = 10
dev = tvm.cpu(0)
A = te.placeholder((n,), name="A")
B = te.extern(
    A.shape,
    [A],
    lambda ins, outs: tvm.tir.call_packed("tvm.contrib.my_tvm_addone",
                                          ins[0], outs[0]),
    name="C",
)
te_func = te.create_prim_func([A, B])
te_func.show()
f = tvm.build(te_func, "llvm")
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), dev)
f(a, b)
np.testing.assert_allclose(b.numpy(), a.numpy() + 1, rtol=1e-5)
# from tvm.script import tir as T

@T.prim_func
def main(var_A: T.handle, var_C: T.handle):
    T.func_attr({"tir.noalias": T.bool(True)})
    A = T.match_buffer(var_A, (10,), offset_factor=1)
    C = T.match_buffer(var_C, (10,), offset_factor=1)
    with T.block("C"):
        T.reads(A[0:10])
        T.writes(C[0:10])
        T.call_packed("tvm.contrib.my_tvm_addone", T.tvm_stack_make_array(A.data, T.tvm_stack_make_shape(10), 0, 1, T.float32(0), A.elem_offset), T.tvm_stack_make_array(C.data, T.tvm_stack_make_shape(10), 0, 1, T.float32(0), C.elem_offset))
my_tvm_addone signatures: <class 'tvm.runtime.ndarray.NDArray'>, <class 'tvm.runtime.ndarray.NDArray'>
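
Note how, in the lowered TIR above, call_packed packs A and C into on-stack DLTensor views (tvm_stack_make_array / tvm_stack_make_shape) before dispatching to the registered function; this is why the Python callback receives ordinary NDArray arguments.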

Calling TVM from PyTorch#

Reference: Integrating existing runtime libraries into your environment

DLPack data:

import torch
import torch.utils.dlpack
from tvm.contrib.dlpack import to_pytorch_func

a = np.random.randn(1337)
tvm_a = tvm.nd.array(a)
np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).numpy(), a)
x = torch.rand(56, 56)
tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
y = tvm.nd.from_dlpack(tvm_x)
np.testing.assert_equal(x.numpy(), tvm_x.numpy())
np.testing.assert_equal(y.numpy(), tvm_x.numpy())
np.testing.assert_equal(
    torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.numpy()
)

def tvm_func(n):
    XX = te.placeholder((n, n), name="X")
    YY = te.placeholder((n, n), name="Y")
    k = te.reduce_axis((0, n), name="k")
    ZZ = te.compute((n, n), lambda i, j: te.sum(XX[i, k] * YY[k, j], axis=k))
    return te.create_prim_func([XX, YY, ZZ])

te_func = tvm_func(tvm.runtime.convert(137))
te_func.show()
f = tvm.build(te_func, name="f")
# from tvm.script import tir as T

@T.prim_func
def main(X: T.Buffer((137, 137), "float32"), Y: T.Buffer((137, 137), "float32"), compute: T.Buffer((137, 137), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j, k in T.grid(137, 137, 137):
        with T.block("compute"):
            v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
            T.reads(X[v_i, v_k], Y[v_k, v_j])
            T.writes(compute[v_i, v_j])
            with T.init():
                compute[v_i, v_j] = T.float32(0)
            compute[v_i, v_j] = compute[v_i, v_j] + X[v_i, v_k] * Y[v_k, v_j]
xx = torch.rand(137, 137)
yy = torch.rand(137, 137)
zz = xx.mm(yy)
zz2 = torch.empty(137, 137)
f_pytorch = to_pytorch_func(f)
f_pytorch(xx, yy, zz2)
np.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-4, atol=1e-4)
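
Here to_pytorch_func wraps the compiled TVM function so that it can be called directly on torch tensors; the underlying buffers are exchanged through DLPack, so no copies are made along the way.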