填充运算

%cd ../..
import set_env
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm import te

简单示例

下面的例子在矩阵 \(A\) 的四周各填充一圈 \(0\):

# NumPy reference: a 3 x 4 matrix holding 1..12, surrounded by a one-element
# border of zeros on every side (result is 5 x 6).
a_np = np.arange(1, 13, dtype='float32').reshape((3, 4))
b_np = np.pad(a_np, pad_width=1, mode='constant', constant_values=0)
print(b_np)
[[ 0.  0.  0.  0.  0.  0.]
 [ 0.  1.  2.  3.  4.  0.]
 [ 0.  5.  6.  7.  8.  0.]
 [ 0.  9. 10. 11. 12.  0.]
 [ 0.  0.  0.  0.  0.  0.]]
# Symbolic version of the padding: A is m x n and B grows by p on every side.
p = 1  # padding size
n = te.var('n')
m = te.var('m')
A = te.placeholder(shape=(m, n), name='A')
B = te.compute(
    shape=(m + p * 2, n + p * 2),
    # The lambda parameter names are kept as i, j on purpose.
    fcompute=lambda i, j: te.if_then_else(
        # True when (i, j) lies in the p-wide border around the copy of A.
        te.any(i < p, i >= m + p, j < p, j >= n + p),
        0,
        A[i - p, j - p],
    ),
    name='B',
)
te_func = te.create_prim_func([A, B])
te_func.show()
mod = tvm.build(te_func, target="llvm")
# from tvm.script import tir as T
@T.prim_func
def func(var_A: T.handle, var_B: T.handle):
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    A = T.match_buffer(var_A, [m, n], dtype="float32")
    B = T.match_buffer(var_B, [m + 2, n + 2], dtype="float32")
    # body
    # with T.block("root")
    for i0, i1 in T.grid(m + 2, n + 2):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0, i1])
            T.reads(A[i - 1, j - 1])
            T.writes(B[i, j])
            B[i, j] = T.if_then_else(i < 1 or m + 1 <= i or j < 1 or n + 1 <= j, T.float32(0), A[i - 1, j - 1], dtype="float32")

验证结果:

# Run the compiled kernel on the NumPy input and check it against the
# NumPy reference result instead of only displaying the output.
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.empty(b_np.shape)  # output buffer with the padded shape
mod(a_nd, b_nd)
# Fails loudly if the TVM result differs from the reference padding.
np.testing.assert_allclose(b_nd.numpy(), b_np)
b_nd
<tvm.nd.NDArray shape=(5, 6), cpu(0)>
array([[ 0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  2.,  3.,  4.,  0.],
       [ 0.,  5.,  6.,  7.,  8.,  0.],
       [ 0.,  9., 10., 11., 12.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)

通用 2D 填充

# Fill value and element type used by the padded tensor.
val = 0
dtype = "float32"

# Every extent is symbolic, including the height/width pad sizes.
ph = te.var("hpad")
pw = te.var("wpad")
batch_size = te.var("batch_size")
kernel_size = te.var("kernel_size")
height = te.var("height")
width = te.var("width")

shape = (batch_size, kernel_size, height, width)
pad_shape = (batch_size, kernel_size, height+2*ph, width+2*pw)
data = te.placeholder(shape, dtype=dtype)
pad_data = te.compute(
    pad_shape,
    lambda *idx: te.if_then_else(
        # Only the last two axes are padded; the border test looks at them alone.
        te.any(
            idx[-2] < ph,
            idx[-2] >= height+ph,
            idx[-1] < pw,
            idx[-1] >= width+pw,
        ),
        val,
        # Leading indices pass through; the last two are shifted back by the pad.
        data[idx[:-2] + (idx[-2]-ph, idx[-1]-pw)],
    ),
    name='pad_data',
)
te_func = te.create_prim_func([data, pad_data])
te_func.show()

# Build through a TE schedule; the symbolic sizes are passed as extra arguments.
sch = te.create_schedule(pad_data.op)
mod = tvm.build(sch, [data, pad_data, *shape, ph, pw], target="llvm")
# from tvm.script import tir as T
@T.prim_func
def func(var_placeholder: T.handle, var_pad_data: T.handle):
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    batch_size = T.var("int32")
    height = T.var("int32")
    hpad = T.var("int32")
    kernel_size = T.var("int32")
    width = T.var("int32")
    wpad = T.var("int32")
    placeholder = T.match_buffer(var_placeholder, [batch_size, kernel_size, height, width], dtype="float32")
    pad_data = T.match_buffer(var_pad_data, [batch_size, kernel_size, height + 2 * hpad, width + 2 * wpad], dtype="float32")
    # body
    # with T.block("root")
    for i0, i1, i2, i3 in T.grid(batch_size, kernel_size, hpad * 2 + height, wpad * 2 + width):
        with T.block("pad_data"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(placeholder[i0_1, i1_1, i2_1 - hpad, i3_1 - wpad])
            T.writes(pad_data[i0_1, i1_1, i2_1, i3_1])
            pad_data[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 < hpad or height + hpad <= i2_1 or i3_1 < wpad or width + wpad <= i3_1, T.float32(0), placeholder[i0_1, i1_1, i2_1 - hpad, i3_1 - wpad], dtype="float32")
def pad2d(X, ph, pw, val=0, name="pad_data"):
    """Pad the last two dimensions of X with a constant value.

    Parameters
    ----------
    X : tensor with at least two dimensions; the last two are treated as
        height and width, any leading dimensions are left unchanged.
    ph, pw : padding added on each side of the height and width dimension.
    val : value written into the padded border, default 0.
    name : name given to the resulting compute operation.

    Returns
    -------
    The result of ``te.compute`` with shape
    ``(*X.shape[:-2], nh + 2 * ph, nw + 2 * pw)``.

    Raises
    ------
    AssertionError if X has fewer than two dimensions.
    """
    assert len(X.shape) >= 2
    nh = X.shape[-2]
    nw = X.shape[-1]
    out_shape = (*X.shape[0:-2], nh+ph*2, nw+pw*2)

    def _fill(*idx):
        # Split the output index into the untouched leading part and (row, col).
        row, col = idx[-2], idx[-1]
        in_border = te.any(row<ph, row>=nh+ph, col<pw, col>=nw+pw)
        # Inside the border write val, otherwise read X at the un-shifted position.
        return te.if_then_else(in_border, val, X[idx[:-2]+(row-ph, col-pw)])

    return te.compute(out_shape, _fill, name=name)
# Exercise pad2d on a concrete 2 x 3 x 4 tensor, padding the last two axes
# by 1 and 2 elements per side respectively (output is 2 x 5 x 8).
A = te.placeholder(shape=(2, 3, 4), name="data")
B = pad2d(A, ph=1, pw=2)
te_func = te.create_prim_func([A, B])
te_func.show()
mod = tvm.build(te_func, target="llvm")

# Input of all ones, so the padded border is easy to see in the printout.
a = tvm.nd.array(np.ones(shape=(2, 3, 4), dtype='float32'))
b = tvm.nd.array(np.empty(shape=(2, 5, 8), dtype='float32'))
mod(a, b)
print(b)
# from tvm.script import tir as T
@T.prim_func
def func(data: T.Buffer[(2, 3, 4), "float32"], pad_data: T.Buffer[(2, 5, 8), "float32"]):
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2 in T.grid(2, 5, 8):
        with T.block("pad_data"):
            i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(data[i0_1, i1_1 - 1, i2_1 - 2])
            T.writes(pad_data[i0_1, i1_1, i2_1])
            pad_data[i0_1, i1_1, i2_1] = T.if_then_else(i1_1 < 1 or 4 <= i1_1 or i2_1 < 2 or 6 <= i2_1, T.float32(0), data[i0_1, i1_1 - 1, i2_1 - 2], dtype="float32")
[[[0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 1. 1. 1. 1. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0.]]]