Tensor Inner Product#

%cd ../..
import set_env
/media/pc/data/4tb/lxw/home/lxw/tvm-book/doc/tutorials
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm import te, topi

Blocked Matrix Multiplication#

For \(\mathbf{x} = (x_1, \cdots, x_w)^T, \mathbf{y} = (y_1, \cdots, y_w)^T \in \mathbb{R}^w\), their inner product is

\[ \langle \mathbf{x}, \mathbf{y} \rangle = \sum_{i=1}^{w} x_i y_i = \mathbf{x}^T \mathbf{y} \in \mathbb{R} \]
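
As a quick NumPy sanity check of this identity (the vectors x and y below are arbitrary illustrations):

x = np.arange(8)
y = np.arange(8, 16)
# <x, y> = sum_i x_i y_i = x^T y; all three expressions agree
assert np.dot(x, y) == (x * y).sum() == x.T @ y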

Going further, for \(\mathbf{X} = (\mathbf{x}_1, \cdots, \mathbf{x}_h)^T \in \mathbb{R}^{h \times w}\) and \(\mathbf{Y} = (\mathbf{y}_1, \cdots, \mathbf{y}_{h_o})^T \in \mathbb{R}^{h_o \times w}\), we have

\[ \langle \mathbf{X}, \mathbf{Y} \rangle = \mathbf{X} \cdot \mathbf{Y}^T = (\langle \mathbf{x}_i, \mathbf{y}_j \rangle)_{i=1, j=1}^{i=h, j=h_o} \in \mathbb{R}^{h \times h_o} \]
a_np = np.arange(24).reshape(3, 8)
b_np = np.arange(16).reshape(2, 8)
print(f"a_np:\n{a_np}\nb_np:\n{b_np}")
a_np:
[[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]]
b_np:
[[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]]

Reference result for the inner product:

c_np = a_np @ b_np.T
c_np
array([[ 140,  364],
       [ 364, 1100],
       [ 588, 1836]])

TVM arrays and the TE matrix-multiply definition:

a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty(c_np.shape, dtype=c_np.dtype)
m = te.var("m")
n = te.var("n")
o = te.var("o")
A = te.placeholder((m, n), "int64", "X")
B = te.placeholder((o, n), "int64", "Y")
C = topi.matmul(A, B, transp_b=True)  # matrix multiplication: C = A @ B.T
te_func = te.create_prim_func([A, B, C])  # lower the TE computation to a TensorIR PrimFunc
te_func.show()
mod = tvm.build(te_func, target="llvm")
mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)
# from tvm.script import tir as T
@T.prim_func
def func(var_X: T.handle, var_Y: T.handle, var_T_matmul: T.handle):
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    o = T.var("int32")
    X = T.match_buffer(var_X, [m, n], dtype="int64")
    Y = T.match_buffer(var_Y, [o, n], dtype="int64")
    T_matmul = T.match_buffer(var_T_matmul, [m, o], dtype="int64")
    # body
    # with T.block("root")
    for i0, i1, i2 in T.grid(m, o, n):
        with T.block("T_matmul"):
            ax0, ax1, k = T.axis.remap("SSR", [i0, i1, i2])
            T.reads(X[ax0, k], Y[ax1, k])
            T.writes(T_matmul[ax0, ax1])
            with T.init():
                T_matmul[ax0, ax1] = T.int64(0)
            T_matmul[ax0, ax1] = T_matmul[ax0, ax1] + X[ax0, k] * Y[ax1, k]
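
The same computation can also be spelled out by hand with te.compute and te.reduce_axis; a minimal sketch, reusing the placeholders A, B and the symbolic shapes m, n, o defined above (C_manual and manual_mod are illustrative names):

k = te.reduce_axis((0, n), name="k")
C_manual = te.compute(
    (m, o),
    lambda i, j: te.sum(A[i, k] * B[j, k], axis=k),  # reduce over the shared axis of length n
    name="T_matmul",
)
manual_mod = tvm.build(te.create_prim_func([A, B, C_manual]), target="llvm")
manual_mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)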

Three-Dimensional Tensor Inner Product#

For three-dimensional tensors \(\mathsf{X} = (\mathbf{X}_1, \cdots, \mathbf{X}_{c_i})^T \in \mathbb{R}^{c_i \times h \times w}\) and \(\mathsf{Y} = (\mathbf{Y}_1, \cdots, \mathbf{Y}_{c_o})^T \in \mathbb{R}^{c_o \times h_o \times w}\), we have

\[ \langle \mathsf{X}, \mathsf{Y} \rangle = \mathsf{X} \cdot \mathsf{Y}^T = (\langle \mathbf{X}_i, \mathbf{Y}_j \rangle)_{i=1, j=1}^{i=c_i, j=c_o} \in \mathbb{R}^{c_i \times h \times h_o \times c_o} \]
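
A small NumPy sketch of this definition (the shapes and the names X3, Y3, Z3 are arbitrary illustrations):

ci, h, w = 2, 3, 8
co, ho = 4, 5
X3 = np.arange(ci * h * w).reshape(ci, h, w)
Y3 = np.arange(co * ho * w).reshape(co, ho, w)
Z3 = np.einsum("ihw,jow->ihoj", X3, Y3)  # Z3[i, :, :, j] == X3[i] @ Y3[j].T
assert Z3.shape == (ci, h, ho, co)
assert np.array_equal(Z3[1, :, :, 2], X3[1] @ Y3[2].T)

The cells below reuse te_func from the previous section and split its reduction loop into blocks of 4: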
mod = IRModule({"mm": te_func})
sch = tvm.tir.Schedule(mod)
block_Z = sch.get_block("T_matmul", func_name="mm")
ax0, ax1, k = sch.get_loops(block_Z)
k0, k1 = sch.split(k, factors=[None, 4])
sch.mod.show()
mod = tvm.build(sch.mod, target="llvm")
mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def mm(var_X: T.handle, var_Y: T.handle, var_T_matmul: T.handle):
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        n = T.var("int32")
        o = T.var("int32")
        X = T.match_buffer(var_X, [m, n], dtype="int64")
        Y = T.match_buffer(var_Y, [o, n], dtype="int64")
        T_matmul = T.match_buffer(var_T_matmul, [m, o], dtype="int64")
        # body
        # with T.block("root")
        for i0, i1, i2_0, i2_1 in T.grid(m, o, (n + 3) // 4, 4):
            with T.block("T_matmul"):
                T.where(i2_0 * 4 + i2_1 < n)
                ax0, ax1 = T.axis.remap("SS", [i0, i1])
                k = T.axis.reduce(n, i2_0 * 4 + i2_1)
                T.reads(X[ax0, k], Y[ax1, k])
                T.writes(T_matmul[ax0, ax1])
                with T.init():
                    T_matmul[ax0, ax1] = T.int64(0)
                T_matmul[ax0, ax1] = T_matmul[ax0, ax1] + X[ax0, k] * Y[ax1, k]
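
Because n is symbolic, sch.split cannot prove that n is divisible by 4, so it inserts the T.where predicate above to mask out-of-range iterations. With a fixed extent that divides evenly, the guard disappears; a small sketch under that assumption (A_f, B_f, fixed_mod, fixed_sch are illustrative names):

A_f = te.placeholder((3, 8), "int64", "X")
B_f = te.placeholder((2, 8), "int64", "Y")
fixed_mod = IRModule({"mm": te.create_prim_func([A_f, B_f, topi.matmul(A_f, B_f, transp_b=True)])})
fixed_sch = tvm.tir.Schedule(fixed_mod)
*_, k_fixed = fixed_sch.get_loops(fixed_sch.get_block("T_matmul", func_name="mm"))
fixed_sch.split(k_fixed, factors=[None, 4])  # 8 is divisible by 4, so no T.where is generated
fixed_sch.mod.show()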
    

Four-Dimensional Tensor Inner Product#

For four-dimensional tensors \(\mathop{X} = (\mathsf{X}_1, \cdots, \mathsf{X}_{b_i})^T \in \mathbb{R}^{b_i \times c_i \times h \times w}\) and \(\mathop{Y} = (\mathsf{Y}_1, \cdots, \mathsf{Y}_{b_o})^T \in \mathbb{R}^{b_o \times c_o \times h_o \times w}\), we have

\[ \langle \mathop{X}, \mathop{Y} \rangle = \mathop{X} \cdot \mathop{Y}^T = (\langle \mathsf{X}_i, \mathsf{Y}_j \rangle)_{i=1, j=1}^{i=b_i, j=b_o} \in \mathbb{R}^{b_i \times c_i \times h \times h_o \times c_o \times b_o} \]
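
Again as a NumPy sketch (arbitrary shapes; X4, Y4, Z4 are illustrative names):

bi, ci, h, w = 2, 3, 4, 8
bo, co, ho = 5, 6, 7
X4 = np.random.rand(bi, ci, h, w)
Y4 = np.random.rand(bo, co, ho, w)
Z4 = np.einsum("pihw,qjow->pihojq", X4, Y4)
assert Z4.shape == (bi, ci, h, ho, co, bo)
# Each (p, q) slice matches the three-dimensional definition above:
np.testing.assert_allclose(Z4[1, ..., 3], np.einsum("ihw,jow->ihoj", X4[1], Y4[3]))

The cells below again reuse te_func, this time splitting the reduction loop twice with factors [None, 2, 2]: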
mod = IRModule({"mm": te_func})
sch = tvm.tir.Schedule(mod)
block_Z = sch.get_block("T_matmul", func_name="mm")
ax0, ax1, k = sch.get_loops(block_Z)
k0, k1, k2 = sch.split(k, factors=[None, 2, 2])
sch.mod.show()
mod = tvm.build(sch.mod, target="llvm")
mod(a_nd, b_nd, c_nd)
np.testing.assert_equal(c_nd.numpy(), c_np)
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def mm(var_X: T.handle, var_Y: T.handle, var_T_matmul: T.handle):
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        n = T.var("int32")
        o = T.var("int32")
        X = T.match_buffer(var_X, [m, n], dtype="int64")
        Y = T.match_buffer(var_Y, [o, n], dtype="int64")
        T_matmul = T.match_buffer(var_T_matmul, [m, o], dtype="int64")
        # body
        # with T.block("root")
        for i0, i1, i2_0, i2_1, i2_2 in T.grid(m, o, (n + 3) // 4, 2, 2):
            with T.block("T_matmul"):
                T.where((i2_0 * 2 + i2_1) * 2 + i2_2 < n)
                ax0, ax1 = T.axis.remap("SS", [i0, i1])
                k = T.axis.reduce(n, i2_0 * 4 + i2_1 * 2 + i2_2)
                T.reads(X[ax0, k], Y[ax1, k])
                T.writes(T_matmul[ax0, ax1])
                with T.init():
                    T_matmul[ax0, ax1] = T.int64(0)
                T_matmul[ax0, ax1] = T_matmul[ax0, ax1] + X[ax0, k] * Y[ax1, k]
The remainder of this section uses a fixed-size 1024 × 1024 matrix multiplication to demonstrate tiling, blockize, and cache staging with the TensorIR schedule:
@tvm.script.ir_module
class MatmulModule:
    @T.prim_func
    def main(
        A: T.Buffer[(1024, 1024), "float32"],
        B: T.Buffer[(1024, 1024), "float32"],
        C: T.Buffer[(1024, 1024), "float32"],
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] += A[vi, vk] * B[vj, vk]
sch = tvm.tir.Schedule(MatmulModule)
i, j, k = sch.get_loops("matmul")
i, ii = sch.split(i, factors=[None, 16])  # tile each loop with an inner extent of 16
j, ji = sch.split(j, factors=[None, 16])
k, ki = sch.split(k, factors=[None, 16])
sch.reorder(i, j, k, ii, ji, ki)  # outer tile loops first, then the inner 16×16×16 loop nest
sch.mod.show()
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(64, 64, 64, 16, 16, 16):
            with T.block("matmul"):
                vi = T.axis.spatial(1024, i_0 * 16 + i_1)
                vj = T.axis.spatial(1024, j_0 * 16 + j_1)
                vk = T.axis.reduce(1024, k_0 * 16 + k_1)
                T.reads(A[vi, vk], B[vj, vk])
                T.writes(C[vi, vj])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
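
At this stage the module is merely tiled and reordered, so it can still be built and verified directly. A minimal sketch with random inputs (run before the blockize step below; a_f32, b_f32, tiled_mod are illustrative names); note that this TIR computes C = A @ B.T, since B is indexed as B[vj, vk]:

a_f32 = np.random.rand(1024, 1024).astype("float32")
b_f32 = np.random.rand(1024, 1024).astype("float32")
tiled_mod = tvm.build(sch.mod, target="llvm")
a_tvm, b_tvm = tvm.nd.array(a_f32), tvm.nd.array(b_f32)
c_tvm = tvm.nd.empty((1024, 1024), dtype="float32")
tiled_mod(a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.numpy(), a_f32 @ b_f32.T, rtol=1e-4)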
    
block_mm = sch.blockize(ii)  # group the inner 16×16×16 loop nest into a single computation block
sch.mod.show()
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        for i_0, j_0, k_0 in T.grid(64, 64, 64):
            with T.block("matmul_o"):
                vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
                T.reads(A[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], B[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                with T.init():
                    for i_1, j_1 in T.grid(16, 16):
                        with T.block("matmul_init"):
                            vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                            T.reads()
                            T.writes(C[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                            C[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                for i_1, j_1, k_1 in T.grid(16, 16, 16):
                    with T.block("matmul"):
                        vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                        T.reads(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i], A[vi_o * 16 + vi_i, vk_o * 16 + vk_i], B[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                        T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                        C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + A[vi_o * 16 + vi_i, vk_o * 16 + vk_i] * B[vj_o * 16 + vj_i, vk_o * 16 + vk_i]
    
# Stage the A and B tiles read by the blockized matmul into dedicated storage scopes.
A_reg = sch.cache_read(block_mm, 0, storage_scope="global.A_reg")
B_reg = sch.cache_read(block_mm, 1, storage_scope="global.B_reg")
sch.compute_at(A_reg, k)  # load each A tile inside the outer reduction loop k_0
sch.compute_at(B_reg, k)  # load each B tile inside the outer reduction loop k_0

# Accumulate into a staging buffer and write it back once per (i_0, j_0) tile.
write_back_block = sch.cache_write(block_mm, 0, storage_scope="global.accumulator")
sch.reverse_compute_at(write_back_block, j)
sch.mod.show()
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        # with T.block("root")
        A_global_A_reg = T.alloc_buffer([1024, 1024], dtype="float32", scope="global.A_reg")
        B_global_B_reg = T.alloc_buffer([1024, 1024], dtype="float32", scope="global.B_reg")
        C_global_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", scope="global.accumulator")
        for i_0, j_0 in T.grid(64, 64):
            for k_0 in T.serial(64):
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("A_global.A_reg"):
                        v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
                        v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                        T.reads(A[v0, v1])
                        T.writes(A_global_A_reg[v0, v1])
                        A_global_A_reg[v0, v1] = A[v0, v1]
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("B_global.B_reg"):
                        v0 = T.axis.spatial(1024, j_0 * 16 + ax0)
                        v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                        T.reads(B[v0, v1])
                        T.writes(B_global_B_reg[v0, v1])
                        B_global_B_reg[v0, v1] = B[v0, v1]
                with T.block("matmul_o"):
                    vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
                    T.reads(A_global_A_reg[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], B_global_B_reg[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                    T.writes(C_global_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                    with T.init():
                        for i_1, j_1 in T.grid(16, 16):
                            with T.block("matmul_init"):
                                vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                                T.reads()
                                T.writes(C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                                C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                    for i_1, j_1, k_1 in T.grid(16, 16, 16):
                        with T.block("matmul"):
                            vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                            T.reads(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i], A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i], B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                            T.writes(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                            C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i] * B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i]
            for ax0, ax1 in T.grid(16, 16):
                with T.block("C_global.accumulator"):
                    v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
                    v1 = T.axis.spatial(1024, j_0 * 16 + ax1)
                    T.reads(C_global_accumulator[v0, v1])
                    T.writes(C[v0, v1])
                    C[v0, v1] = C_global_accumulator[v0, v1]