VTA ALU

VTA ALU

import tvm
from tvm import te
import numpy as np
from tvm import topi
from tvm.contrib.utils import tempdir

import vta
import vta.testing
from vta.testing import simulator

# Seed NumPy's global RNG so the random immediates and operand arrays drawn
# by the ALU checks below are repeatable from run to run.
np.random.seed(0xDEADB)
def _run(env, remote):
    """Run the ALU operator checks against a VTA environment.

    Parameters
    ----------
    env
        VTA environment object. This function reads ``BATCH``, ``BLOCK_OUT``,
        ``acc_dtype``, ``inp_dtype``, ``acc_scope``, ``dma_copy``, ``alu``,
        ``target_host`` and ``TARGET`` from it.
    remote
        Session used to upload, load and execute the built module. When it
        is falsy, each check stops right after the schedule is constructed.
    """

    def check_alu(tvm_op, np_op=None, use_imm=False, test_name=None):
        """Build, run and verify one element-wise ALU operation.

        Parameters
        ----------
        tvm_op
            Binary callable used inside the tensor expression.
        np_op
            Binary callable producing the NumPy reference result. When it is
            omitted, ``tvm_op`` itself is applied to the NumPy operands.
        use_imm
            If true, the second operand is a random scalar immediate;
            otherwise it is a second input tensor ``b``.
        test_name
            Label inserted into the statistics header printed on simulators.
        """
        rows, cols = 8, 8
        # Drawn on every call, including the tensor-tensor case, so the
        # sequence of random numbers consumed stays the same.
        imm = np.random.randint(1, 5)
        shape = (rows, cols, env.BATCH, env.BLOCK_OUT)

        # ---- compute ----
        a = te.placeholder(shape, name="a", dtype=env.acc_dtype)
        a_buf = te.compute(shape, lambda *i: a(*i), "a_buf")  # DRAM->SRAM
        if not use_imm:
            b = te.placeholder(shape, name="b", dtype=env.acc_dtype)
            b_buf = te.compute(shape, lambda *i: b(*i), "b_buf")  # DRAM->SRAM
            res_buf = te.compute(
                shape, lambda *i: tvm_op(a_buf(*i), b_buf(*i)), "res_buf"
            )  # compute
        else:
            res_buf = te.compute(
                shape, lambda *i: tvm_op(a_buf(*i), imm), "res_buf"
            )  # compute
        res = te.compute(
            shape, lambda *i: res_buf(*i).astype(env.inp_dtype), "res"
        )  # SRAM->DRAM

        # ---- schedule ----
        s = te.create_schedule(res.op)

        def place_in_acc(tensor, pragma):
            # Put the stage in the accumulator scope and tag its outermost
            # axis with the given pragma.
            s[tensor].set_scope(env.acc_scope)  # SRAM
            s[tensor].pragma(tensor.op.axis[0], pragma)

        place_in_acc(a_buf, env.dma_copy)  # DRAM->SRAM
        place_in_acc(res_buf, env.alu)  # compute
        s[res].pragma(res.op.axis[0], env.dma_copy)  # SRAM->DRAM
        if not use_imm:
            place_in_acc(b_buf, env.dma_copy)  # DRAM->SRAM

        # Without a remote there is nothing to build or execute.
        if not remote:
            return

        # ---- build ----
        with vta.build_config():
            build_args = [a, res] if use_imm else [a, b, res]
            mod = vta.build(
                s, build_args, tvm.target.Target("ext_dev", host=env.target_host)
            )
        temp = tempdir()
        obj_path = temp.relpath("load_act.o")
        mod.save(obj_path)
        remote.upload(obj_path)
        f = remote.load_module("load_act.o")

        # ---- verify ----
        dev = remote.ext_dev(0)
        reference_op = np_op if np_op else tvm_op
        a_np = np.random.randint(-16, 16, size=shape).astype(a.dtype)
        if use_imm:
            res_np = reference_op(a_np, imm)
        else:
            b_np = np.random.randint(-16, 16, size=shape).astype(b.dtype)
            res_np = reference_op(a_np, b_np)
        res_np = res_np.astype(res.dtype)

        a_nd = tvm.nd.array(a_np, dev)
        res_nd = tvm.nd.array(np.zeros(shape).astype(res.dtype), dev)

        on_simulator = env.TARGET in ["sim", "tsim"]
        if on_simulator:
            simulator.clear_stats()

        # The second operand array is only created for the tensor-tensor case.
        call_args = [a_nd]
        if not use_imm:
            call_args.append(tvm.nd.array(b_np, dev))
        call_args.append(res_nd)
        f(*call_args)

        np.testing.assert_equal(res_np, res_nd.numpy())

        if on_simulator:
            sim_stats = simulator.stats()
            print("ALU {} execution statistics:".format(test_name))
            for key, value in sim_stats.items():
                print("\t{:<16}: {:>16}".format(key, value))

    # (tensor-expression op, NumPy reference op or None, use immediate, label)
    cases = (
        (lambda x, y: x << y, np.left_shift, True, "SHL"),
        (tvm.te.max, np.maximum, True, "MAX"),
        (tvm.te.max, np.maximum, False, "MAX"),
        (lambda x, y: x + y, None, True, "ADD"),
        (lambda x, y: x + y, None, False, "ADD"),
        (lambda x, y: x >> y, np.right_shift, True, "SHR"),
    )
    for case_op, case_ref, case_imm, case_label in cases:
        check_alu(case_op, case_ref, use_imm=case_imm, test_name=case_label)

# Hand the test body to the VTA testing harness. Presumably the harness calls
# _run(env, remote) for the configured target -- confirm against
# vta.testing.run.
vta.testing.run(_run)
ALU SHL execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :             4096
	uop_load_nbytes :                4
	out_store_nbytes:             1024
	gemm_counter    :                0
	alu_counter     :               64
ALU MAX execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :             4096
	uop_load_nbytes :                4
	out_store_nbytes:             1024
	gemm_counter    :                0
	alu_counter     :               64
ALU MAX execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :             8192
	uop_load_nbytes :                4
	out_store_nbytes:             1024
	gemm_counter    :                0
	alu_counter     :               64
ALU ADD execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :             4096
	uop_load_nbytes :                4
	out_store_nbytes:             1024
	gemm_counter    :                0
	alu_counter     :               64
ALU ADD execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :             8192
	uop_load_nbytes :                4
	out_store_nbytes:             1024
	gemm_counter    :                0
	alu_counter     :               64
ALU SHR execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :             4096
	uop_load_nbytes :                4
	out_store_nbytes:             1024
	gemm_counter    :                0
	alu_counter     :               64
2023-09-25 13:12:35.136 INFO load_module /tmp/tmp7oy9i8lt/load_act.o
2023-09-25 13:12:35.287 INFO load_module /tmp/tmp7oy9i8lt/load_act.o
2023-09-25 13:12:35.458 INFO load_module /tmp/tmp7oy9i8lt/load_act.o
2023-09-25 13:12:35.609 INFO load_module /tmp/tmp7oy9i8lt/load_act.o
2023-09-25 13:12:35.780 INFO load_module /tmp/tmp7oy9i8lt/load_act.o
2023-09-25 13:12:35.933 INFO load_module /tmp/tmp7oy9i8lt/load_act.o