张量形状

张量形状#

%cd ../..
import set_env
import numpy as np
import tvm
from tvm import te

重构形状#

备注

\(n\)-D 数组在内存中实际上是作为 \(1\)-D 数组列出的,重构形状生成的代码并不重新排列数据序列,以提高效率。

reshape 运算可以抽象为以下数学形式。

对于任意的 \(\mathbf{x}_i = (x_0^{i}, \cdots, x_{n-1}^{i})^T \in \mathbb{R}^n\),有 \(\mathbf{X} = (\mathbf{x}_0, \cdots, \mathbf{x}_{m-1})^T \in \mathbb{R}^{m \times n}\),即

\[\begin{split} \mathbf{X} = \begin{bmatrix} x_0^{0} & x_1^{0} & \cdots & x_{n-1}^{0}\\ x_0^{1} & x_1^{1} & \cdots & x_{n-1}^{1}\\ \vdots & \vdots & \ddots & \vdots \\ x_0^{m-1} & x_1^{m-1} & \cdots & x_{n-1}^{m-1}\\ \end{bmatrix} = (X_{ij})_{m \times n} \end{split}\]

可以倒过来思考:存在 \(\mathbf{y} = (x_0, \cdots, x_{k-1})^T \in \mathbb{R}^{k}\),将其分成 \(m\) 份,便有 \(\{\mathbf{x}_i\}_0^{m-1}\),这样,\(\mathbf{y}\)\(\mathbf{X}\) 便建立映射关系:

\[ \mathbf{y}_t = \mathbf{X}_{ni+j} \]

或者索引表示为 i, j = t//n, t%n

比如,将 \((m, n)\) 矩阵重构为 \((mn,)\)

n = te.var('n')
m = te.var('m')
A = te.placeholder((m, n), name='A')
B = te.compute((m*n,), lambda i: A[i//n, i%n], 'B')
te_func = te.create_prim_func([A, B])
mod = tvm.build(te_func, target="llvm")
a_np = np.arange(12, dtype='float32').reshape((3, 4))
b_np = a_np.reshape(-1) # 基准结果
a_np, b_np
(array([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]], dtype=float32),
 array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.],
       dtype=float32))
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.empty(b_np.shape)
mod(a_nd, b_nd)
b_nd
<tvm.nd.NDArray shape=(12,), cpu(0)>
array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.],
      dtype=float32)

也可以实现一般的二维重构函数。

p, q = te.var('p'), te.var('q')
B = te.compute((p, q), lambda i, j: A[(i*q+j)//n, (i*q+j)%n], name='B')
te_func = te.create_prim_func([A, B])
rt_lib = tvm.build(te_func, target="llvm")
te_func.show()
# from tvm.script import tir as T
@T.prim_func
def func(var_A: T.handle, var_B: T.handle):
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    p = T.var("int32")
    q = T.var("int32")
    A = T.match_buffer(var_A, [m, n], dtype="float32")
    B = T.match_buffer(var_B, [p, q], dtype="float32")
    # body
    # with T.block("root")
    for i0, i1 in T.grid(p, q):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0, i1])
            T.reads(A[(i * q + j) // n, (i * q + j) % n])
            T.writes(B[i, j])
            B[i, j] = A[(i * q + j) // n, (i * q + j) % n]
b_np = a_np.reshape(4, 3) # 基准结果
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.empty(b_np.shape, dtype="float32")

rt_lib(a_nd, b_nd)
b_nd
<tvm.nd.NDArray shape=(4, 3), cpu(0)>
array([[ 0.,  1.,  2.],
       [ 3.,  4.,  5.],
       [ 6.,  7.,  8.],
       [ 9., 10., 11.]], dtype=float32)

警告

在测试结果时,应该意识到,没有对输出形状施加约束,它可以有任意形状 (p, q),因此 TVM 将无法检查 \(qp = nm\)。例如,在下面的例子中,创建了 b,其尺寸 (20) 比 a (12) 大,那么 b 中只有前 12 个元素来自 a ,其他的都是未初始化的值。

a_np = np.arange(12, dtype='float32').reshape((3, 4))
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.empty((5, 4), dtype="float32")
rt_lib(a_nd, b_nd)
print(b_nd)
[[0.0000000e+00 1.0000000e+00 2.0000000e+00 3.0000000e+00]
 [4.0000000e+00 5.0000000e+00 6.0000000e+00 7.0000000e+00]
 [8.0000000e+00 9.0000000e+00 1.0000000e+01 1.1000000e+01]
 [2.7418247e-27 3.0614168e-41 1.5834673e-43 0.0000000e+00]
 [2.8213425e-27 3.0614168e-41 9.9344688e+32 4.5815453e-41]]

切片#

考虑特殊的切片算子 a[bi::si, bj::sj],其中 bibjsisj 可以稍后指定。现在需要根据参数计算输出形状。此外,需要在编译模块时将变量 bibjsisj 作为参数传递。

bi, bj = te.var("bi"), te.var("bj")
si, sj = te.var("si"), te.var("sj")
B = te.compute(((m-bi)//si, (n-bj)//sj),
               lambda i, j: A[i*si+bi, j*sj+bj],
               name='B')
te_func = te.create_prim_func([A, B])
te_func.show()
# from tvm.script import tir as T
@T.prim_func
def func(var_A: T.handle, var_B: T.handle):
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bi = T.var("int32")
    bj = T.var("int32")
    m = T.var("int32")
    n = T.var("int32")
    si = T.var("int32")
    sj = T.var("int32")
    A = T.match_buffer(var_A, [m, n], dtype="float32")
    B = T.match_buffer(var_B, [(m - bi) // si, (n - bj) // sj], dtype="float32")
    # body
    # with T.block("root")
    for i0, i1 in T.grid((m - bi) // si, (n - bj) // sj):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0, i1])
            T.reads(A[i * si + bi, j * sj + bj])
            T.writes(B[i, j])
            B[i, j] = A[i * si + bi, j * sj + bj]
sch = te.create_schedule(B.op)
mod = tvm.build(sch, [A, B, bi, si, bj, sj])
b_nd = tvm.nd.array(np.empty((1, 3), dtype='float32'))
mod(a_nd, b_nd, 1, 2, 1, 1)
np.testing.assert_equal(b_nd.numpy(), a_nd.numpy()[1::2, 1::1])

b_nd = tvm.nd.array(np.empty((1, 2), dtype='float32'))
mod(a_nd, b_nd, 2, 1, 0, 2)
np.testing.assert_equal(b_nd.numpy(), a_nd.numpy()[2::1, 0::2])