VTA topi.conv2d

VTA topi.conv2d

import numpy as np
from collections import namedtuple

import tvm
from tvm import te
from tvm import relay
from tvm import autotvm
from tvm.contrib.utils import tempdir
# from tvm.contrib.pickle_memoize import memoize
from tvm import topi
import tvm.topi.testing
import vta
import vta.testing
from vta.testing import simulator

# FIXME: a custom clip operator is needed to work around a limitation in some
# pattern-detection logic.
@tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
    """Clamp ``x`` elementwise to ``[a_min, a_max]`` using two compute stages.

    Unlike the current topi ``clip``, the upper and lower bounds are applied as
    separate stages: ``clipA`` computes ``min(x, a_max)`` and ``clipB`` then
    computes ``max(clipA, a_min)``.

    Args:
        x: input tensor; both bounds are materialized as constants of ``x.dtype``.
        a_min: lower bound.
        a_max: upper bound.

    Returns:
        The ``clipB`` tensor.
    """
    lo = tvm.tir.const(a_min, x.dtype)
    hi = tvm.tir.const(a_max, x.dtype)
    upper_clipped = te.compute(
        x.shape, lambda *idx: tvm.te.min(x(*idx), hi), name="clipA"
    )
    return te.compute(
        upper_clipped.shape,
        lambda *idx: tvm.te.max(upper_clipped(*idx), lo),
        name="clipB",
    )
# Description of one 2D-convolution benchmark case: batch size, input spatial
# size, input/output channel counts, kernel size, padding and stride (each
# given separately for the height and width dimensions).
Workload = namedtuple(
    "Conv2DWorkload",
    "batch height width in_filter out_filter "
    "hkernel wkernel hpad wpad hstride wstride",
)
env = vta.get_env()
remote = tvm.rpc.LocalSession()
# ResNet18 workloads
resnet_wkls = [
    # Workloads of resnet18 on imagenet
    # ('resnet-18.C1',  Workload(env.BATCH, 224, 224, 3,   64,  7, 7, 3, 3, 2, 2)),
    ("resnet-18.C2", Workload(env.BATCH, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
    # ("resnet-18.C3", Workload(env.BATCH, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
    # ("resnet-18.C4", Workload(env.BATCH, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
    # ("resnet-18.C5", Workload(env.BATCH, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
    # ("resnet-18.C6", Workload(env.BATCH, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
    # ("resnet-18.C7", Workload(env.BATCH, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
    # ("resnet-18.C8", Workload(env.BATCH, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
    # ("resnet-18.C9", Workload(env.BATCH, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
    # ("resnet-18.C10", Workload(env.BATCH, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
    # ("resnet-18.C11", Workload(env.BATCH, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
]

# Right-shift applied to the conv2d accumulator before the bias add. Both the
# TVM compute and the numpy reference read this single value, so they cannot
# silently disagree.
OUT_SHIFT = 8
# Upper clip bound: the largest signed integer representable in env.OUT_WIDTH
# bits (explicit parentheses; `<<` binds looser than `-`).
OUT_MAX = (1 << (env.OUT_WIDTH - 1)) - 1

with autotvm.tophub.context(env.target):
    for _, wl in resnet_wkls:
        print(wl)
        # The numpy reference below is given a single padding value (wl.hpad),
        # so only symmetric padding is supported here.
        assert wl.hpad == wl.wpad
        layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
        conv2d_fcompute = vta.top.conv2d_packed
        conv2d_fschedule = vta.top.schedule_conv2d_packed
        # Derive shapes depending upon packing
        a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
        w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
        b_shape = (wl.batch, wl.out_filter, 1, 1)
        # Packed (blocked) layouts consumed by the VTA kernel
        data_shape = (
            wl.batch // env.BATCH,
            wl.in_filter // env.BLOCK_IN,
            wl.height,
            wl.width,
            env.BATCH,
            env.BLOCK_IN,
        )
        kernel_shape = (
            wl.out_filter // env.BLOCK_OUT,
            wl.in_filter // env.BLOCK_IN,
            wl.hkernel,
            wl.wkernel,
            env.BLOCK_OUT,
            env.BLOCK_IN,
        )
        bias_shape = (
            wl.batch // env.BATCH,
            wl.out_filter // env.BLOCK_OUT,
            1,
            1,
            env.BATCH,
            env.BLOCK_OUT,
        )
        data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
        kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
        bias = te.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
        padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))
        # Define base computation schedule
        with env.target:
            res = conv2d_fcompute(
                data, kernel, (wl.hstride, wl.wstride), padding, (1, 1), layout, env.acc_dtype
            )
            # Requantize: shift, add bias, clip, cast to the output dtype.
            res = topi.right_shift(res, OUT_SHIFT)
            res = topi.add(res, bias)
            res = my_clip(res, 0, OUT_MAX)
            res = topi.cast(res, env.out_dtype)
            # Derive base schedule
            s = conv2d_fschedule([res])
            # print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
        # Derive number of ops (one multiply + one add per MAC)
        fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
        fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
        num_ops = (
            2
            * wl.batch
            * fout_height
            * fout_width
            * wl.hkernel
            * wl.wkernel
            * wl.out_filter
            * wl.in_filter
        )

        # @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc")
        def get_ref_data():
            """Return random (data, weight, bias) in unpacked layouts plus the raw conv2d reference.

            Data uses ``a_shape`` (N, C, H, W) and weights use ``w_shape``
            (O, I, KH, KW). The reference is the accumulator-dtype output of
            ``conv2d_nchw_python``, before any shift/bias/clip is applied.
            """
            # derive min max for act, wgt, and bias types (max non inclusive)
            a_min, a_max = -(1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
            w_min, w_max = -(1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
            b_bits = env.INP_WIDTH + env.WGT_WIDTH - 2
            b_min, b_max = -(1 << b_bits), 1 << b_bits
            a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
            w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
            b_np = np.random.randint(b_min, b_max, size=b_shape).astype(env.acc_dtype)
            r_np = tvm.topi.testing.conv2d_nchw_python(
                a_np.astype(env.acc_dtype),
                w_np.astype(env.acc_dtype),
                (wl.hstride, wl.wstride),
                wl.hpad,
            ).astype(env.acc_dtype)
            return a_np, w_np, b_np, r_np

        # Data in original format
        data_np, kernel_np, bias_np, res_ref = get_ref_data()
        # Pack into the blocked layouts declared above (data_shape / kernel_shape / bias_shape)
        data_np = data_np.reshape(
            wl.batch // env.BATCH,
            env.BATCH,
            wl.in_filter // env.BLOCK_IN,
            env.BLOCK_IN,
            wl.height,
            wl.width,
        ).transpose((0, 2, 4, 5, 1, 3))
        kernel_np = kernel_np.reshape(
            wl.out_filter // env.BLOCK_OUT,
            env.BLOCK_OUT,
            wl.in_filter // env.BLOCK_IN,
            env.BLOCK_IN,
            wl.hkernel,
            wl.wkernel,
        ).transpose((0, 2, 4, 5, 1, 3))
        bias_np = bias_np.reshape(
            wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT
        )
        # build
        with vta.build_config(disabled_pass={"tir.CommonSubexprElimTIR"}):
            mod = vta.build(
                s,
                [data, kernel, bias, res],
                target=tvm.target.Target(env.target, host=env.target_host),
                name="conv2d",
            )

        temp = tempdir()
        mod.save(temp.relpath("conv2d.o"))
        remote.upload(temp.relpath("conv2d.o"))
        f = remote.load_module("conv2d.o")
        dev = remote.device(str(env.target))

        res_np = np.zeros(topi.utils.get_const_tuple(res.shape)).astype(res.dtype)
        data_arr = tvm.nd.array(data_np, dev)
        kernel_arr = tvm.nd.array(kernel_np, dev)
        bias_arr = tvm.nd.array(bias_np, dev)
        res_arr = tvm.nd.array(res_np, dev)
        time_f = f.time_evaluator("conv2d", dev, number=4)

        # Simulator stats are kept in `stats` for later inspection.
        simulator.clear_stats()
        cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
        stats = simulator.stats()

        # Correctness check
        res_orig = res_arr.numpy()
        # Unpack the blocked output (and bias) back to plain NCHW
        res_orig = res_orig.transpose((0, 4, 1, 5, 2, 3)).reshape(
            wl.batch, wl.out_filter, fout_height, fout_width
        )
        bias_np = bias_np.transpose((0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
        # Apply the same requantization as the TVM compute above.
        res_ref = res_ref >> OUT_SHIFT
        res_ref += bias_np
        res_ref = np.clip(res_ref, 0, OUT_MAX)
        res_ref = res_ref.astype(env.out_dtype)
        correct = np.allclose(res_orig, res_ref)

        # Report timing together with the correctness outcome
        gops = (num_ops / cost.mean) / float(10**9)
        status = "PASSED" if correct else "FAILED"
        print(f"CONV2D TEST {status}: Time cost = {cost.mean:g} sec/op, {gops:g} GOPS")
Conv2DWorkload(batch=1, height=56, width=56, in_filter=64, out_filter=64, hkernel=3, wkernel=3, hpad=1, wpad=1, hstride=1, wstride=1)
CONV2D TEST: Time cost = 0.0928819 sec/op, 2.4893 GOPS
[17:25:56] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-09-25 17:25:56.273 INFO load_module /tmp/tmper0fe63q/conv2d.o