ReLU on ALU

import tvm
from tvm import te
import numpy as np
from tvm import topi
from tvm.contrib.utils import tempdir

import vta
import vta.testing
from vta.testing import simulator

# Pin NumPy's global RNG so the random input drawn inside _run is the same
# on every execution of this script.
np.random.seed(seed=0xDEADB)
def _run(env, remote):
    """Compile and check a ReLU-then-clip kernel that runs on the VTA ALU.

    The kernel copies an ``env.acc_dtype`` tensor into ``env.acc_scope``,
    applies ``max(x, 0)`` and then ``min(x, (1 << (env.INP_WIDTH - 1)) - 1)``
    as two ALU-pragma stages, casts the result to ``env.inp_dtype`` and copies
    it back out. The device result is compared against ``np.clip`` applied to
    the same random input.

    Parameters
    ----------
    env :
        VTA environment object. This function reads its tensor geometry
        (``BATCH``, ``BLOCK_OUT``), dtypes, scopes, pragma names,
        ``target_host`` and ``TARGET``. Presumably supplied by
        ``vta.testing.run`` -- confirm against the harness.
    remote :
        RPC session used to upload, load and run the compiled module. When
        falsy, the kernel is only built and the function returns early.
    """
    m = 8
    n = 10
    # Every stage uses the same 4-D layout, so define it once.
    shape = (m, n, env.BATCH, env.BLOCK_OUT)
    # Upper clip bound shared by the kernel and the NumPy reference so the two
    # cannot drift apart. It is the largest signed INP_WIDTH-bit value.
    # NOTE(review): assumes env.inp_dtype is a signed INP_WIDTH-bit integer.
    clip_max = (1 << (env.INP_WIDTH - 1)) - 1

    # compute
    # Tensor names mirror the Python variable names so lowered IR is readable.
    a = te.placeholder(shape, name="a", dtype=env.acc_dtype)
    a_buf = te.compute(shape, lambda *i: a(*i), "a_buf")  # DRAM->SRAM
    max_buf = te.compute(
        shape, lambda *i: tvm.te.max(a_buf(*i), 0), "max_buf"
    )  # relu: max(x, 0)
    min_buf = te.compute(
        shape, lambda *i: tvm.te.min(max_buf(*i), clip_max), "min_buf"
    )  # clip to clip_max before narrowing to inp_dtype
    res = te.compute(
        shape, lambda *i: min_buf(*i).astype(env.inp_dtype), "res"
    )  # SRAM->DRAM

    # schedule
    s = te.create_schedule(res.op)
    s[a_buf].set_scope(env.acc_scope)  # SRAM
    s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy)  # DRAM->SRAM
    s[max_buf].set_scope(env.acc_scope)  # SRAM
    s[min_buf].set_scope(env.acc_scope)  # SRAM
    s[max_buf].pragma(max_buf.op.axis[0], env.alu)  # compute
    s[min_buf].pragma(min_buf.op.axis[0], env.alu)  # compute
    s[res].pragma(res.op.axis[0], env.dma_copy)  # SRAM->DRAM

    # build
    with vta.build_config():
        mod = vta.build(s, [a, res], tvm.target.Target("ext_dev", host=env.target_host))
    # Without a remote session only the compile step is exercised.
    if not remote:
        return
    temp = tempdir()
    lib_name = "relu.o"
    lib_path = temp.relpath(lib_name)
    mod.save(lib_path)
    remote.upload(lib_path)
    f = remote.load_module(lib_name)

    # verify
    dev = remote.ext_dev(0)
    a_np = np.random.randint(-256, 256, size=shape).astype(a.dtype)
    res_np = np.clip(a_np, 0, clip_max).astype(res.dtype)
    a_nd = tvm.nd.array(a_np, dev)
    res_nd = tvm.nd.array(np.zeros(shape).astype(res.dtype), dev)

    # Simulator statistics are only available on the sim/tsim targets.
    is_sim = env.TARGET in ["sim", "tsim"]
    if is_sim:
        simulator.clear_stats()

    f(a_nd, res_nd)

    np.testing.assert_equal(res_np, res_nd.numpy())

    if is_sim:
        sim_stats = simulator.stats()
        print("Relu execution statistics:")
        for k, v in sim_stats.items():
            print("\t{:<16}: {:>16}".format(k, v))

# Execute the ReLU check through VTA's test harness; the harness is expected
# to call _run(env, remote) with an environment and (possibly falsy) remote.
vta.testing.run(_run)
Relu execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :             5120
	uop_load_nbytes :                8
	out_store_nbytes:             1280
	gemm_counter    :                0
	alu_counter     :              160
2023-09-25 13:13:49.037 INFO load_module /tmp/tmppyrqqdcq/load_act.o