tvm.te.tag.tag_scope()

tvm.te.tag.tag_scope()

tag_scope() 函数接受字符串参数 tag，并返回一个 TagScope 对象。TagScope 类用于创建带有特定标签的算子作用域：在该作用域内创建的算子（如通过 te.compute 创建的算子）会被打上该标签，从而便于识别和管理。TagScope 既可以作为 with 上下文管理器使用，也可以作为装饰器使用。

import set_env
import tvm
from tvm import te

以 with 上下文管理器的形式构建：

# Symbolic extents: A is (n, l), B is (m, l); the product contracts over l.
n, m, l = (te.var(dim_name) for dim_name in ("n", "m", "l"))
A = te.placeholder((n, l), name="A")
B = te.placeholder((m, l), name="B")
k = te.reduce_axis((0, l), name="k")

# Ops built inside the scope receive the tag "matmul".
# NOTE: the lambda argument names (i, j) become the loop/axis names shown
# in the printed TIR, so they are intentionally kept as-is.
with tvm.te.tag_scope(tag="matmul"):
    C = te.compute(
        (n, m),
        lambda i, j: te.sum(A[i, k] * B[j, k], axis=k),
    )

matmul_func = te.create_prim_func([A, B, C])
matmul_func.show()
# from tvm.script import tir as T

@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_compute: T.handle):
    T.func_attr({"tir.noalias": T.bool(True)})
    n, l = T.int32(), T.int32()
    A = T.match_buffer(var_A, (n, l))
    m = T.int32()
    B = T.match_buffer(var_B, (m, l))
    compute = T.match_buffer(var_compute, (n, m))
    # with T.block("root"):
    for i, j, k in T.grid(n, m, l):
        with T.block("compute"):
            v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
            T.reads(A[v_i, v_k], B[v_j, v_k])
            T.writes(compute[v_i, v_j])
            with T.init():
                compute[v_i, v_j] = T.float32(0)
            compute[v_i, v_j] = compute[v_i, v_j] + A[v_i, v_k] * B[v_j, v_k]

或者以装饰器的形式构建：

from tvm.topi import tag
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def compute_relu(data):
    """Compute the element-wise ReLU of ``data``, i.e. ``max(data, 0)``.

    Parameters
    ----------
    data : tvm.te.Tensor
        Input argument.

    Returns
    -------
    y : tvm.te.Tensor
        The result, with the same shape and dtype as ``data``.
    """
    # Demo output: shows that the decorated function receives a te.Tensor.
    print(type(data))
    # Build the zero in the input's own dtype. A bare Python ``0.0`` is turned
    # into a float32 constant, so Select(cond, 0.0, x) would have mismatched
    # branch dtypes (and be rejected) for any non-float32 input.
    zero = tvm.tir.const(0, data.dtype)
    return te.compute(
        data.shape,
        lambda *i: tvm.tir.Select(data(*i) < zero, zero, data(*i)),
    )

# A length-2 float32 input; calling compute_relu prints the argument's type.
data = te.placeholder((2,), dtype="float32", name="data")
out = compute_relu(data)
relu_func = te.create_prim_func([data, out])
relu_func.show()
<class 'tvm.te.tensor.Tensor'>
# from tvm.script import tir as T

@T.prim_func
def main(data: T.Buffer((2,), "float32"), compute: T.Buffer((2,), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i0 in range(2):
        with T.block("compute"):
            v_i0 = T.axis.spatial(2, i0)
            T.reads(data[v_i0])
            T.writes(compute[v_i0])
            compute[v_i0] = T.Select(data[v_i0] < T.float32(0), T.float32(0), data[v_i0])

tag_scope conv

@tvm.te.tag_scope(tag="conv")
def compute_conv(data, weight):
    """Direct 2-D convolution (stride 1, no padding), tagged ``"conv"``.

    Parameters
    ----------
    data : tvm.te.Tensor
        4-D input, indexed as ``data[batch, in_channel, y, x]``.
    weight : tvm.te.Tensor
        4-D kernel, indexed as ``weight[out_channel, in_channel, ky, kx]``.

    Returns
    -------
    tvm.te.Tensor
        Output of shape ``(batch, out_channel, H - KH + 1, W - KW + 1)``.
    """
    batch, _, in_h, in_w = data.shape
    # The input-channel reduction extent is taken from the kernel's second
    # dimension; data's channel count is assumed to match it (not checked).
    out_c, in_c, k_h, k_w = weight.shape
    out_shape = (batch, out_c, in_h - k_h + 1, in_w - k_w + 1)

    ic = te.reduce_axis((0, in_c), name="ic")
    dh = te.reduce_axis((0, k_h), name="dh")
    dw = te.reduce_axis((0, k_w), name="dw")

    # Parameter names (i, oc, h, w) become the output axis names.
    def conv_body(i, oc, h, w):
        window = data[i, ic, h + dh, w + dw]
        return te.sum(window * weight[oc, ic, dh, dw], axis=[ic, dh, dw])

    return te.compute(out_shape, conv_body)
import json
n, m, l = (te.size_var(dim) for dim in ("n", "m", "l"))

A = te.placeholder((n, l), name="A")
B = te.placeholder((m, l), name="B")
with tvm.te.tag_scope(tag="gemm"):
    k = te.reduce_axis((0, l), name="k")
    C = te.compute(
        (n, m),
        lambda i, j: te.sum(A[i, k] * B[j, k], axis=k),
        attrs={"hello": 1, "arr": [10, 12]},
    )

# The scope's tag and the user-supplied attrs are both recorded on the op.
assert C.op.tag == "gemm"
assert "hello" in C.op.attrs and "xx" not in C.op.attrs
assert C.op.attrs["hello"].value == 1

# The attrs survive a JSON save/load round trip of the tensor.
CC = tvm.ir.load_json(tvm.ir.save_json(C))
assert CC.op.attrs["hello"].value == 1
assert CC.op.attrs["arr"][0].value == 10
# str() of the attrs map happens to be valid JSON, so it can be parsed directly.
assert json.loads(str(CC.op.attrs))["arr"][1] == 12
n, c, h, w, kh, kw = (
    te.size_var(dim) for dim in ("n", "c", "h", "w", "kh", "kw")
)

A = te.placeholder((n, c, h, w), name="A")
B = te.placeholder((c, c, kh, kw), name="B")
C = compute_conv(A, B)
# The decorator form tags the op; no attrs were passed to te.compute.
assert C.op.tag == "conv"
assert len(C.op.attrs) == 0

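嵌套（不支持）：在一个 tag_scope 内部再进入另一个 tag_scope 会抛出 ValueError；外层作用域由于没有标记任何算子，还会发出 “was not used” 的 UserWarning：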

n = te.size_var("n")
c = te.size_var("c")
h = te.size_var("h")
w = te.size_var("w")
kh = te.size_var("kh")
kw = te.size_var("kw")

A = te.placeholder((n, c, h, w), name="A")
B = te.placeholder((c, c, kh, kw), name="B")
# Nested tag scopes are rejected: compute_conv enters its own "conv" scope
# while this outer one is active, which is expected to raise ValueError.
# The outer scope is then left without having tagged anything, which
# presumably is why the captured output shows the
# "Tag 'conv' declared via TagScope was not used." UserWarning.
try:
    with te.tag_scope(tag="conv"):
        C = compute_conv(A, B)
except ValueError:
    pass  # expected: nesting is not allowed
else:
    # Raise explicitly instead of `assert False`, which `python -O` strips
    # and would let this negative test pass silently.
    raise AssertionError("nested tag_scope should raise ValueError")
/media/pc/data/lxw/ai/tvm/python/tvm/te/tag.py:50: UserWarning: Tag 'conv' declared via TagScope was not used.
  warnings.warn(f"Tag '{self.tag}' declared via TagScope was not used.")