张量程序抽象

张量程序抽象#

备注

张量 (Tensor) 是表示神经网络模型执行的输入、输出和中间结果的多维数组。

张量函数 (Tensor functions) 是神经网络的“知识”的计算序列集合。

机器学习编译(MLC)的过程可以看作是张量函数的变换过程。

元张量函数(primitive tensor function)是指模型执行中的单个计算单元。MLC 过程可以选择变换元张量函数的实现方式。

张量程序(tensor program)是表示元张量函数的有效抽象。

  • 关键成分包括: 存储数据的多维数组,驱动张量计算的循环嵌套,计算语句。

  • 程序变换可以被用于加速张量程序的执行。

  • 张量程序中额外的结构能够为程序变换提供更多的信息。

TensorIR#

下面着重接受一类张量程序抽象:TensorIRTensorIR 常用来表示循环和相关的硬件加速选择,如多线程、特殊硬件指令的使用和内存访问。

探讨如下数学表示:

\[\begin{split} Y_{i, j} = \sum_k A_{i, k} \times B_{k, j}\\ C_{i, j} = \operatorname{relu}(Y_{i, j}) = \max(Y_{i, j}, 0) \end{split}\]
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
print(f'tvm 版本: {tvm.__version__}')
tvm 版本: 0.14.dev0

使用 NumPy 实现:

dtype = "float32"
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)
# a @ b 等价于 np.matmul(a, b)
c_mm_relu = np.maximum(a_np @ b_np, 0)

为了更清晰理解张量程序的变换底层逻辑,约定使用低级 NumPy 实现:

  1. 在必要时使用循环而不是数组函数来展示可能的循环计算。

  2. 如果可能,总是通过 numpy.empty() 显式地分配数组并传递它们。

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

lnumpy_mm_relu() 翻译为 TVMScript:

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
MyModule.show()
/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install "black==22.3.0" --upgrade --user
  warnings.warn(
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(Y[vi, vj])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Y[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
  1. T.Buffer(输入与输出)与 T.alloc_buffer(中间结果) 定义了缓冲区;

  2. T.grid 用于循环嵌套;

  3. T.block 管理着计算块。

是 TensorIR 中的基本计算单位。

T.block 包含一组块轴(vi、vj、vk)和围绕它们定义的计算。

[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])

T.axis.spatial 标注的信息指示循环变量是独立的,其对应的轴称为 空间轴T.axis.reduce 对应的轴称为 归约轴

可以使用 T.axis.remap 简写块轴语法,比如:

vi, vj, vk = T.axis.remap("SSR", [i, j, k])
vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce(range_of_k, k)

备注

  1. tir.noalias 表示所有的缓冲存储器不重叠。

  2. @tvm.script.ir_module 表示 MyModuleIRModuleIRModule 是在机器学习编译中保存张量函数集合的容器对象。

type(MyModule)
tvm.ir.module.IRModule
type(MyModule["mm_relu"])
tvm.tir.function.PrimFunc

变换#

lnumpy_mm_relu 可以改写为如下实现:

def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j0 in range(32):
            for k in range(128):
                for j1 in range(4):
                    j = j0 * 4 + j1
                    if k == 0:
                        Y[i, j] = 0
                    Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

c_np = np.empty((128, 128), dtype=dtype)
lnumpy_mm_relu_v2(a_np, b_np, c_np)
np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)

TVMScript 无需重写代码,只需要:

sch = tvm.tir.Schedule(MyModule)
block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)
j0, j1 = sch.split(j, factors=[None, 4])
sch.reorder(j0, k, j1) # 改变循环的顺序
sch.mod.show()
/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install "black==22.3.0" --upgrade --user
  warnings.warn(
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j_0, k, j_1 in T.grid(128, 32, 128, 4):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j_0 * 4 + j_1)
                vk = T.axis.reduce(128, k)
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(Y[vi, vj])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Y[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

reverse_compute_at 的原语将块 C 移动到 Y 的内循环里:

block_C = sch.get_block("C", "mm_relu")
sch.reverse_compute_at(block_C, j0)
sch.mod.show()
/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install "black==22.3.0" --upgrade --user
  warnings.warn(
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j_0 in T.grid(128, 32):
            for k, j_1 in T.grid(128, 4):
                with T.block("Y"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    with T.init():
                        Y[vi, vj] = T.float32(0)
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in range(4):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + ax0)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

在循环变换之后,我们可以将 Y 元素的初始化与归约更新分开。我们可以通过 decompose_reduction 原语来做到这一点。(注意:这也是 TVM 在以后编译的时候隐式做的,所以这一步的主要目的是让它显式,看看最终效果)。

block_Y = sch.get_block("Y", "mm_relu")
sch.decompose_reduction(block_Y, k)
sch.mod.show()
/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install "black==22.3.0" --upgrade --user
  warnings.warn(
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j_0 in T.grid(128, 32):
            for j_1_init in range(4):
                with T.block("Y_init"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + j_1_init)
                    T.reads()
                    T.writes(Y[vi, vj])
                    Y[vi, vj] = T.float32(0)
            for k, j_1 in T.grid(128, 4):
                with T.block("Y_update"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in range(4):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 4 + ax0)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

构建与运行#

调用构建函数将 IRModule 变换为 runtime.Module (表示可运行函数的集合)。 这里 target 指定了部署环境的详细信息。对于现在这种特殊情况,我们将使用 llvm,它可以帮助我们编译到本机 CPU 平台。

rt_lib = tvm.build(MyModule, target="llvm")

创建三个用于保存输入和输出的 TVM NDArray:

a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype="float32")
type(c_nd)
tvm.runtime.ndarray.NDArray

构建变换后的程序:

rt_lib_after = tvm.build(sch.mod, target="llvm")
rt_lib_after["mm_relu"](a_nd, b_nd, c_nd)
np.testing.assert_allclose(c_mm_relu, c_nd.numpy(), rtol=1e-5)

比较两者的时间差:

f_timer_before = rt_lib.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of MyModule %g sec" % f_timer_before(a_nd, b_nd, c_nd).mean)
f_timer_after = rt_lib_after.time_evaluator("mm_relu", tvm.cpu())
print("Time cost of transformed sch.mod %g sec" % f_timer_after(a_nd, b_nd, c_nd).mean)
Time cost of MyModule 0.00184356 sec
Time cost of transformed sch.mod 0.000445022 sec