VTA GEMM

VTA GEMM

import tvm
from tvm import te
import numpy as np
from tvm import topi
from tvm.contrib.utils import tempdir

import vta
import vta.testing
from vta.testing import simulator

# Seed NumPy's global RNG so the randomly drawn test operands are reproducible.
np.random.seed(seed=0xDEADB)
def _run(env, remote):
    """Declare a small tiled GEMM for VTA and check it under two schedules.

    Per output element the workload computes
    ``y = min(max((sum_{ko, ki} x * w) >> 8, 0), 2**(INP_WIDTH - 1) - 1)``,
    with the reduction carried out in ``env.acc_dtype`` and the result cast to
    ``env.inp_dtype``.  Layouts are ``x: (o, n, BATCH, BLOCK_IN)``,
    ``w: (m, n, BLOCK_OUT, BLOCK_IN)`` and ``y: (o, m, BATCH, BLOCK_OUT)``.

    Parameters
    ----------
    env
        VTA environment object; supplies the tile sizes (``BATCH``,
        ``BLOCK_IN``, ``BLOCK_OUT``), bit widths, dtypes, memory scopes,
        pragma names and the ``gemm`` tensor intrinsic used below.
    remote
        Session used to upload, load and run the built module (``upload``,
        ``load_module``, ``ext_dev``).  When falsy, only the computation is
        declared and the function returns without building or running.
    """
    # declare
    o = 4  # number of BATCH-row tiles in x / y
    n = 1  # number of BLOCK_IN reduction tiles
    m = 4  # number of BLOCK_OUT output-column tiles
    x = te.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
    w = te.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
    # Explicit copy stages so the operands can be given SRAM scopes and DMA pragmas.
    x_buf = te.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
    w_buf = te.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
    ko = te.reduce_axis((0, n), name="ko")
    ki = te.reduce_axis((0, env.BLOCK_IN), name="ki")
    y_gem = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT),
        lambda bo, co, bi, ci: te.sum(
            x_buf[bo, ko, bi, ki].astype(env.acc_dtype)
            * w_buf[co, ko, ci, ki].astype(env.acc_dtype),
            axis=[ko, ki],
        ),
        name="y_gem",
    )
    y_shf = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: y_gem(*i) >> 8, name="y_shf"
    )
    # relu: clamp negative values to zero
    y_max = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: tvm.te.max(y_shf(*i), 0), "y_max"
    )
    # Saturate at the largest positive value of a signed INP_WIDTH-bit integer.
    # Shared with the NumPy reference below so the two cannot drift apart.
    out_max = (1 << (env.INP_WIDTH - 1)) - 1
    y_min = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT),
        lambda *i: tvm.te.min(y_max(*i), out_max),
        "y_min",
    )
    y = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: y_min(*i).astype(env.inp_dtype), name="y"
    )

    if not remote:
        return

    def verify(s, name=None):
        """Build schedule ``s``, run it via ``remote`` and compare with NumPy.

        ``name`` only labels the simulator statistics printed afterwards.
        """
        # Build with the CSE pass disabled as otherwise it would complicate the test
        with vta.build_config(disabled_pass={"tir.CommonSubexprElimTIR"}):
            mod = vta.build(s, [x, w, y], tvm.target.Target("ext_dev", host=env.target_host))
        temp = tempdir()
        mod.save(temp.relpath("gemm.o"))
        remote.upload(temp.relpath("gemm.o"))
        f = remote.load_module("gemm.o")
        # verify
        dev = remote.ext_dev(0)
        # Draw operands over the full signed range of their bit widths; for a
        # width of 8 this is exactly the previously hard-coded [-128, 128).
        inp_lim = 1 << (env.INP_WIDTH - 1)
        wgt_lim = 1 << (env.WGT_WIDTH - 1)
        x_np = np.random.randint(
            -inp_lim, inp_lim, size=(o, n, env.BATCH, env.BLOCK_IN)
        ).astype(x.dtype)
        w_np = np.random.randint(
            -wgt_lim, wgt_lim, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)
        ).astype(w.dtype)
        x_nd = tvm.nd.array(x_np, dev)
        w_nd = tvm.nd.array(w_np, dev)
        y_nd = tvm.nd.array(np.zeros((o, m, env.BATCH, env.BLOCK_OUT), dtype=y.dtype), dev)

        # NumPy reference, accumulated in acc_dtype like the y_gem stage.
        y_ref = np.zeros((o, m, env.BATCH, env.BLOCK_OUT), dtype=env.acc_dtype)
        for b in range(o):
            for i in range(m):
                for j in range(n):
                    y_ref[b, i, :] += np.dot(
                        x_np[b, j, :].astype(env.acc_dtype), w_np[i, j].T.astype(env.acc_dtype)
                    )
        y_ref = np.right_shift(y_ref, 8)
        y_ref = np.clip(y_ref, 0, out_max).astype(y.dtype)

        if env.TARGET in ["sim", "tsim"]:
            simulator.clear_stats()

        f(x_nd, w_nd, y_nd)

        np.testing.assert_equal(y_ref, y_nd.numpy())

        if env.TARGET in ["sim", "tsim"]:
            sim_stats = simulator.stats()
            print("GEMM schedule:{} execution statistics:".format(name))
            for k, v in sim_stats.items():
                print("\t{:<16}: {:>16}".format(k, v))

    def set_sram_scopes(s):
        """Give the operand copies and all accumulator-side stages SRAM scopes."""
        s[x_buf].set_scope(env.inp_scope)
        s[w_buf].set_scope(env.wgt_scope)
        for stage in (y_gem, y_shf, y_max, y_min):
            s[stage].set_scope(env.acc_scope)

    def load_operands_with_dma(s):
        """Load x/w tiles inside each ``ko`` step of the GEMM using DMA copies."""
        for buf in (x_buf, w_buf):
            s[buf].compute_at(s[y_gem], ko)
            s[buf].pragma(s[buf].op.axis[0], env.dma_copy)

    def mark_alu_stages(s):
        """Map the shift / max / min element-wise stages onto the ALU."""
        for stage in (y_shf, y_max, y_min):
            s[stage].pragma(s[stage].op.axis[0], env.alu)

    def tensorize_gemm(s):
        """Hoist ``ko`` outermost and replace the inner (bi, ci, ki) loops with env.gemm."""
        s[y_gem].reorder(ko, *s[y_gem].op.axis, ki)
        s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)

    def test_schedule1():
        """Default schedule with no smt (no thread binding)."""
        s = te.create_schedule(y.op)
        set_sram_scopes(s)
        load_operands_with_dma(s)
        mark_alu_stages(s)
        s[y].pragma(s[y].op.axis[0], env.dma_copy)
        tensorize_gemm(s)
        verify(s, name="default")

    def test_smt():
        """smt schedule: the outer batch axis is split in two and bound to ``cthread``."""
        s = te.create_schedule(y.op)
        set_sram_scopes(s)
        abo = s[y].op.axis[0]
        abo1, abo2 = s[y].split(abo, nparts=2)
        s[y].bind(abo1, te.thread_axis("cthread"))
        for stage in (y_gem, y_shf, y_max, y_min):
            s[stage].compute_at(s[y], abo1)
        tensorize_gemm(s)
        mark_alu_stages(s)
        load_operands_with_dma(s)
        s[y].pragma(abo2, env.dma_copy)
        verify(s, name="smt")

    test_schedule1()
    test_smt()

if __name__ == "__main__":
    # Build and run the GEMM test only when executed as a script or notebook,
    # so that merely importing this module does not trigger it.
    vta.testing.run(_run)
GEMM schedule:default execution statistics:
	inp_load_nbytes :               64
	wgt_load_nbytes :             1024
	acc_load_nbytes :                0
	uop_load_nbytes :               20
	out_store_nbytes:              256
	gemm_counter    :               16
	alu_counter     :               48
GEMM schedule:smt execution statistics:
	inp_load_nbytes :               64
	wgt_load_nbytes :             2048
	acc_load_nbytes :                0
	uop_load_nbytes :               40
	out_store_nbytes:              256
	gemm_counter    :               16
	alu_counter     :               48
[13:11:21] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-09-25 13:11:21.765 INFO load_module /tmp/tmpmaomvjg0/gemm.o
[13:11:21] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-09-25 13:11:22.029 INFO load_module /tmp/tmpmaomvjg0/gemm.o