Implementing Matrix Multiplication with TE (Tensor Expression)

Implementing the Original Program with TE

Differences from TVMScript:

  1. TE is a higher-level abstraction and is simpler to use.

  2. TVMScript is lower level and gives more control, but details such as tiling must be implemented by the developer (see the TVMScript sketch right after this list for comparison).
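
For comparison, below is a minimal sketch of the same vector add written directly in TVMScript (the function name vector_add is illustrative; the body mirrors the module that te.create_prim_func prints further down):

from tvm.script import tir as T

@T.prim_func
def vector_add(A: T.Buffer((128,), "float32"),
               B: T.Buffer((128,), "float32"),
               C: T.Buffer((128,), "float32")):
    T.func_attr({"global_symbol": "main"})
    # In TVMScript you spell out the loop, the block, and the axis binding yourself.
    for i in range(128):
        with T.block("C"):
            v_i = T.axis.spatial(128, i)
            C[v_i] = A[v_i] + B[v_i]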

import tvm
from tvm import te

# Data: declare the input and output tensors (symbolic shapes; no memory is allocated yet)
A = te.placeholder((128,), name="A")
B = te.placeholder((128,), name="B")

# Vector addition (note: this warm-up example is an elementwise add, not matmul)
C = te.compute((128,), lambda i: A[i] + B[i], name="C")

# Lower the TE description into a PrimFunc so TVM can optimize it.
func = te.create_prim_func([A, B, C])
# Mark the function name as "main"; TVM requires a "main" function as the
# entry point of the IRModule.
func = func.with_attr("global_symbol", "main")
ir_mod_from_te = tvm.IRModule({"main": func})

ir_mod_from_te.show()
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), C: T.Buffer((128,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i in range(128):
            with T.block("C"):
                v_i = T.axis.spatial(128, i)
                T.reads(A[v_i], B[v_i])
                T.writes(C[v_i])
                C[v_i] = A[v_i] + B[v_i]
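
As a quick sanity check (a sketch added here, not part of the original cell), the module can be built for CPU, executed, and verified against NumPy:

import numpy as np

rt_mod = tvm.build(ir_mod_from_te, target="llvm")
a = tvm.nd.array(np.random.rand(128).astype("float32"))
b = tvm.nd.array(np.random.rand(128).astype("float32"))
c = tvm.nd.array(np.zeros(128, dtype="float32"))
rt_mod["main"](a, b, c)  # call the entry function by name
np.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy(), rtol=1e-5)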

Get the result before optimization (the baseline):

import numpy as np
M = 1024
K = 1024
N = 1024

# The default tensor dtype in TVM
dtype = "float32"

target = "llvm"
dev = tvm.device(target, 0)

# Algorithm: C[m, n] = sum over k of A[m, k] * B[k, n]
k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")

# Default schedule
func = te.create_prim_func([A, B, C])
func = func.with_attr("global_symbol", "main")
ir_module = tvm.IRModule({"main": func})
ir_module.show()

# build and run
func = tvm.build(ir_module, target="llvm")  # The module for CPU backends.

a = tvm.nd.array(np.random.rand(M, K).astype(dtype), dev)
b = tvm.nd.array(np.random.rand(K, N).astype(dtype), dev)
c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)

evaluator = func.time_evaluator(func.entry_name, dev, number=1)
t_baseline = evaluator(a, b, c).mean
print("Baseline: %f" % t_baseline)
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for m, n, k in T.grid(1024, 1024, 1024):
            with T.block("C"):
                v_m, v_n, v_k = T.axis.remap("SSR", [m, n, k])
                T.reads(A[v_m, v_k], B[v_k, v_n])
                T.writes(C[v_m, v_n])
                with T.init():
                    C[v_m, v_n] = T.float32(0)
                C[v_m, v_n] = C[v_m, v_n] + A[v_m, v_k] * B[v_k, v_n]
Baseline: 2.256611
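
Before optimizing, it is worth confirming the result is numerically correct (reusing the a, b, c arrays from above):

answer = np.dot(a.numpy(), b.numpy())
np.testing.assert_allclose(c.numpy(), answer, rtol=1e-5)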

Optimize the program with the TVM scheduler. The transformed loop nest is more cache-friendly: after tiling, each 32x32 float32 tile of C occupies only 4 KB, small enough to stay resident in cache while the entire k reduction runs over it. Measured here, this alone gives roughly a 9x speedup.

sch = tvm.tir.Schedule(ir_module)
block_c = sch.get_block("C")
# Get the loops surrounding the block
(y, x, k) = sch.get_loops(block_c)

# Step 1: tile (split) each spatial loop into an outer loop and a 32-wide inner loop
block_size = 32
yo, yi = sch.split(y, [None, block_size])
xo, xi = sch.split(x, [None, block_size])

# Step 2: reorder so the 32x32 tile loops are innermost and k sits above them
sch.reorder(yo, xo, k, yi, xi)
sch.mod.show()

# build and run
func = tvm.build(sch.mod, target="llvm")  # The module for CPU backends.

c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)
func(a, b, c)

evaluator = func.time_evaluator(func.entry_name, dev, number=1)
t_new = evaluator(a, b, c).mean
print("after transformation: %f. baseline: %f, improved: %.2fx" % (
    t_new, t_baseline, t_baseline/t_new))
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for m_0, n_0, k, m_1, n_1 in T.grid(32, 32, 1024, 32, 32):
            with T.block("C"):
                v_m = T.axis.spatial(1024, m_0 * 32 + m_1)
                v_n = T.axis.spatial(1024, n_0 * 32 + n_1)
                v_k = T.axis.reduce(1024, k)
                T.reads(A[v_m, v_k], B[v_k, v_n])
                T.writes(C[v_m, v_n])
                with T.init():
                    C[v_m, v_n] = T.float32(0)
                C[v_m, v_n] = C[v_m, v_n] + A[v_m, v_k] * B[v_k, v_n]
after transformation: 0.246108. baseline: 2.256611, improved: 9.17x
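
A natural next step (a sketch, not measured here) is to vectorize the innermost spatial loop. vectorize is a standard TIR schedule primitive; the inner n loop binds only to the data-parallel axis v_n, so its 32 iterations are independent and can map to SIMD lanes:

sch2 = tvm.tir.Schedule(ir_module)
block_c = sch2.get_block("C")
y, x, k = sch2.get_loops(block_c)
yo, yi = sch2.split(y, [None, 32])
xo, xi = sch2.split(x, [None, 32])
sch2.reorder(yo, xo, k, yi, xi)
# Legal because xi binds only to the spatial (data-parallel) axis v_n.
sch2.vectorize(xi)

func2 = tvm.build(sch2.mod, target="llvm")
c2 = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)
func2(a, b, c2)
t_vec = func2.time_evaluator(func2.entry_name, dev, number=1)(a, b, c2).mean
print("with vectorize: %f" % t_vec)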