条件表达式

条件表达式#

%cd ../..
import set_env
import numpy as np
import tvm
from tvm import te

在 numpy 中使用 where() 处理数组的条件表达式:

a = np.arange(10)
np.where(a < 5, a, 10*a)
array([ 0,  1,  2,  3,  4, 50, 60, 70, 80, 90])
a = np.array([[0, 1, 2],
              [0, 2, 4],
              [0, 3, 6]])
np.where(a < 4, a, -1)  # -1 被广播
array([[ 0,  1,  2],
       [ 0,  2, -1],
       [ 0,  3, -1]])

在 TVM 中使用 if_then_else 实现它。与 where() 类似,它接受三个参数,第一个是条件,如果为真返回第二个参数,否则返回第三个参数。

下面以实现上三角矩阵为例:

n, m = te.var('n'), te.var('m')
A = te.placeholder((m, n))
B = te.compute(A.shape,
               lambda i, j: te.if_then_else(i >= j, A[i, j], 0.0))
te_func = te.create_prim_func([A, B])
te_func.show()
mod = tvm.build(te_func, target="llvm")
# from tvm.script import tir as T
@T.prim_func
def func(var_placeholder: T.handle, var_compute: T.handle):
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    placeholder = T.match_buffer(var_placeholder, [m, n], dtype="float32")
    compute = T.match_buffer(var_compute, [m, n], dtype="float32")
    # body
    # with T.block("root")
    for i0, i1 in T.grid(m, n):
        with T.block("compute"):
            i, j = T.axis.remap("SS", [i0, i1])
            T.reads(placeholder[i, j])
            T.writes(compute[i, j])
            compute[i, j] = T.if_then_else(j <= i, placeholder[i, j], T.float32(0), dtype="float32")
a_np = np.arange(1, 13, dtype='float32').reshape((3, 4))
b_np = np.tril(a_np)
b_np
array([[ 1.,  0.,  0.,  0.],
       [ 5.,  6.,  0.,  0.],
       [ 9., 10., 11.,  0.]], dtype=float32)
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(np.empty_like(a_np))
mod(a_nd, b_nd)
b_nd
<tvm.nd.NDArray shape=(3, 4), cpu(0)>
array([[ 1.,  0.,  0.,  0.],
       [ 5.,  6.,  0.,  0.],
       [ 9., 10., 11.,  0.]], dtype=float32)