VTA topi.dense

VTA topi.dense

import numpy as np

import tvm
from tvm import te
from tvm import autotvm, rpc
from tvm.contrib.utils import tempdir
# from tvm.contrib.pickle_memoize import memoize
from tvm import topi
import tvm.topi.testing
import vta
import vta.testing
from vta.testing import simulator

# FIXME: a custom clip operator is needed to work around a limitation of
# some pattern detection.
@tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
    """Clip ``x`` to ``[a_min, a_max]`` using two separate compute stages.

    Unlike topi's current clip, the upper and lower bounds are applied in
    two stages: ``clipA`` takes the element-wise minimum with ``a_max`` and
    ``clipB`` then takes the element-wise maximum with ``a_min``.

    Both bounds are converted to constants of ``x.dtype``. The tensor
    produced by the ``clipB`` stage is returned.
    """
    upper = tvm.tir.const(a_max, x.dtype)
    lower = tvm.tir.const(a_min, x.dtype)
    # Stage 1: cap every element from above.
    capped = te.compute(x.shape, lambda *idx: tvm.te.min(x(*idx), upper), name="clipA")
    # Stage 2: cap the stage-1 result from below.
    return te.compute(
        capped.shape, lambda *idx: tvm.te.max(capped(*idx), lower), name="clipB"
    )
# VTA environment, a local RPC session and the compilation target.
env = vta.get_env()
remote = rpc.LocalSession()
target = env.target

# Logical (unpacked) problem size.
batch_size = 16
in_feat = 512
out_feat = 1008
a_shape = (batch_size, in_feat)
w_shape = (out_feat, in_feat)

# Packed shapes: each logical axis is split into an outer part and a tile
# part, and the two tile axes are placed last.
data_shape = (
    batch_size // env.BATCH,
    in_feat // env.BLOCK_IN,
    env.BATCH,
    env.BLOCK_IN,
)
kernel_shape = (
    out_feat // env.BLOCK_OUT,
    in_feat // env.BLOCK_IN,
    env.BLOCK_OUT,
    env.BLOCK_IN,
)

# Compute and schedule implementations used below.
fcompute = vta.top.dense_packed
fschedule = vta.top.schedule_dense_packed

# Declare the computation.
data = te.placeholder(data_shape, dtype=env.inp_dtype, name="data")
kernel = te.placeholder(kernel_shape, dtype=env.wgt_dtype, name="kernel")

# Define the schedule.
with target:
    res = fcompute(data, kernel, None, env.acc_dtype)
    res = topi.right_shift(res, 8)
    # Clip to [0, 2**(OUT_WIDTH - 1) - 1]; `-` binds tighter than `<<`, the
    # parentheses only make that explicit.
    res = my_clip(res, 0, (1 << (env.OUT_WIDTH - 1)) - 1)
    res = topi.cast(res, env.out_dtype)
    # Derive the base schedule.
    s = fschedule([res])
    # print(vta.lower(s, [data, kernel, res], simple_mode=True))

# Operation count used later to derive the GOPS figure.
num_ops = 2 * batch_size * in_feat * out_feat
# @memoize("vta.tests.test_benchmark_topi.dense.verify")
def get_ref_data():
    """Generate random inputs and the matching NumPy reference product.

    Activations of shape ``a_shape`` and weights of shape ``w_shape`` are
    drawn with ``np.random.randint`` from the signed ranges implied by
    ``env.INP_WIDTH`` and ``env.WGT_WIDTH`` (upper bound exclusive) and cast
    to the dtypes of the ``data`` and ``kernel`` placeholders.

    Returns
    -------
    tuple
        ``(a_np, w_np, r_np)`` where ``r_np`` is ``a_np @ w_np.T`` computed
        in ``env.acc_dtype``.
    """
    # Half-range of each signed type; randint samples from [-bound, bound).
    act_bound = 1 << (env.INP_WIDTH - 1)
    wgt_bound = 1 << (env.WGT_WIDTH - 1)

    act = np.random.randint(0 - act_bound, act_bound, size=a_shape).astype(data.dtype)
    wgt = np.random.randint(0 - wgt_bound, wgt_bound, size=w_shape).astype(kernel.dtype)

    # Widen both operands to the accumulator dtype before multiplying.
    acc = env.acc_dtype
    ref = np.dot(act.astype(acc), wgt.T.astype(acc)).astype(acc)
    return act, wgt, ref

# Raw (unpacked) inputs together with the reference result.
data_np, kernel_np, res_ref = get_ref_data()

# Pack the inputs: split each axis into (outer, tile) and move both tile
# axes to the end so the arrays match data_shape / kernel_shape.
data_np = np.transpose(
    data_np.reshape(
        batch_size // env.BATCH, env.BATCH, in_feat // env.BLOCK_IN, env.BLOCK_IN
    ),
    (0, 2, 1, 3),
)
kernel_np = np.transpose(
    kernel_np.reshape(
        out_feat // env.BLOCK_OUT, env.BLOCK_OUT, in_feat // env.BLOCK_IN, env.BLOCK_IN
    ),
    (0, 2, 1, 3),
)

# Build the library.
mod = vta.build(
    s,
    [data, kernel, res],
    name="dense",
    target=tvm.target.Target(target, host=env.target_host),
)
[17:03:51] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
# Save the compiled object, upload it through the RPC session and load it
# back as a runnable module.
temp = tempdir()
mod.save(temp.relpath("dense.o"))
remote.upload(temp.relpath("dense.o"))
f = remote.load_module("dense.o")
dev = remote.device(str(target))

# Zero-filled output buffer with the shape and dtype of the result tensor.
res_np = np.zeros(topi.utils.get_const_tuple(res.shape), dtype=res.dtype)

# Copy the host arrays onto the device.
data_arr, kernel_arr, res_arr = [
    tvm.nd.array(host_array, dev) for host_array in (data_np, kernel_np, res_np)
]

# Timer for the "dense" entry point.
time_f = f.time_evaluator("dense", dev, number=4)
2023-09-25 17:03:52.712 INFO load_module /tmp/tmpmci4zmeb/dense.o
# In VTA sim mode, collect simulator runtime statistics.
simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, res_arr)
stats = simulator.stats()

# Verify correctness.
res_orig = res_arr.numpy()
# Unpack the result. NOTE(review): this assumes the packed output is laid out
# as (batch // BATCH, out // BLOCK_OUT, BATCH, BLOCK_OUT), mirroring how the
# inputs are packed -- confirm against vta.top.dense_packed. The tile axes
# must be moved next to their outer axes before flattening; a bare reshape
# only yields the right layout when env.BATCH == 1.
res_orig = res_orig.transpose((0, 2, 1, 3)).reshape(batch_size, out_feat)

# Apply the same post-processing as the scheduled graph to the reference:
# shift right by 8, clip to [0, 2**(OUT_WIDTH - 1) - 1], cast to out_dtype.
res_ref = res_ref >> 8
res_ref = np.clip(res_ref, 0, (1 << (env.OUT_WIDTH - 1)) - 1)
res_ref = res_ref.astype(env.out_dtype)
correct = np.allclose(res_orig, res_ref)

gops = (num_ops / cost.mean) / float(10**9)
print(f"VTA DENSE TEST: Time cost = {cost.mean:g} sec/op, {gops: g} GOPS")

# Surface a mismatch instead of silently discarding the comparison result.
# Raised after the timing line so the measurement is still reported.
if not correct:
    raise AssertionError("VTA dense result does not match the NumPy reference")
VTA DENSE TEST: Time cost = 0.0932854 sec/op,  0.177038 GOPS
[17:03:53] /media/pc/data/lxw/ai/tvm/src/runtime/profiling.cc:101: Warning: No timer implementation for ext_dev, using default timer instead. It may be inaccurate or have extra overhead.