VTA padded load

VTA padded load

import tvm
from tvm import te
import numpy as np
from tvm import topi
from tvm.contrib.utils import tempdir

import vta
import vta.testing
from vta.testing import simulator

# Fixed seed so the randomly generated input tensors are reproducible run to run.
np.random.seed(0xDEADB)


def _run(env, remote):
    """Build, and optionally execute, padded-load kernels on VTA.

    For each padding configuration, a 4-D accumulator-typed tensor is padded
    with ``topi.nn.pad``, routed through a no-op ALU stage (``>> 0``), cast to
    ``env.inp_dtype`` and copied out. When ``remote`` is truthy, the compiled
    module is uploaded, executed and compared against a NumPy reference.

    Parameters
    ----------
    env
        VTA environment object; this function reads ``BATCH``, ``BLOCK_OUT``,
        ``acc_dtype``, ``inp_dtype``, ``acc_scope``, ``dma_copy``, ``alu``,
        ``target_host`` and ``TARGET`` from it.
    remote
        Presumably an RPC session (``upload``/``load_module``/``ext_dev`` are
        used). If falsy, each kernel is only built, never run.
    """

    def check_padded_load(pad_before, pad_after, test_name=None):
        """Pad a (3, 5, BATCH, BLOCK_OUT) tensor and check the loaded result.

        Only entries 0 and 1 of ``pad_before``/``pad_after`` enter the output
        shape below. This helper assumes entries 2 and 3 are zero.
        ``test_name`` is only used in the simulator statistics banner.
        """
        rows, cols = 3, 5
        # Shape after padding: only the two outermost axes grow.
        out_shape = (
            rows + pad_before[0] + pad_after[0],
            cols + pad_before[1] + pad_after[1],
            env.BATCH,
            env.BLOCK_OUT,
        )

        # --- declare the computation ---
        x = te.placeholder(
            (rows, cols, env.BATCH, env.BLOCK_OUT), name="x", dtype=env.acc_dtype
        )
        # NOTE(review): the pad stage is named "y", the same as the final
        # output stage below. This matches the original test; confirm it is intentional.
        x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
        # Shift-by-zero is an identity op that is not optimized away, which
        # forces an ALU stage between the padded load and the store.
        y_buf = te.compute(out_shape, lambda *idx: x_buf(*idx) >> 0, name="y_buf")
        y = te.compute(
            out_shape, lambda *idx: y_buf(*idx).astype(env.inp_dtype), name="y"
        )

        # --- schedule: pad into the accumulator via DMA, no-op on the ALU,
        # then DMA-copy the cast result out ---
        sched = te.create_schedule(y.op)
        sched[x_buf].set_scope(env.acc_scope)
        sched[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
        sched[y_buf].set_scope(env.acc_scope)
        sched[y_buf].pragma(y_buf.op.axis[0], env.alu)
        sched[y].pragma(y.op.axis[0], env.dma_copy)

        # --- build ---
        with vta.build_config():
            mod = vta.build(
                sched, [x, y], tvm.target.Target("ext_dev", host=env.target_host)
            )

        # Without a remote there is nothing to execute; building was the test.
        if not remote:
            return

        # --- ship the compiled object to the remote and load it ---
        workdir = tempdir()
        obj_path = workdir.relpath("padded_load.o")
        mod.save(obj_path)
        remote.upload(obj_path)
        f = remote.load_module("padded_load.o")

        # --- verify against a NumPy reference ---
        dev = remote.ext_dev(0)
        x_np = np.random.randint(
            0, 10, size=(rows, cols, env.BATCH, env.BLOCK_OUT)
        ).astype(x.dtype)
        expected = np.zeros(out_shape).astype(y.dtype)
        row_lo, col_lo = pad_before[0], pad_before[1]
        expected[row_lo : row_lo + rows, col_lo : col_lo + cols] = x_np

        x_nd = tvm.nd.array(x_np, dev)
        y_nd = tvm.nd.empty(expected.shape, device=dev, dtype=expected.dtype)

        is_simulated = env.TARGET in ["sim", "tsim"]
        if is_simulated:
            simulator.clear_stats()

        f(x_nd, y_nd)

        np.testing.assert_equal(expected, y_nd.numpy())

        if not is_simulated:
            return
        print(f"Padded {test_name} load execution statistics:")
        for key, value in simulator.stats().items():
            print("\t{:<16}: {:>16}".format(key, value))

    # (pad_before, pad_after, label) for each configuration exercised.
    cases = (
        ([2, 0, 0, 0], [0, 0, 0, 0], "Y0"),
        ([0, 2, 0, 0], [0, 0, 0, 0], "Y1"),
        ([0, 0, 0, 0], [2, 0, 0, 0], "X0"),
        ([0, 0, 0, 0], [0, 2, 0, 0], "X1"),
        ([1, 1, 0, 0], [1, 1, 0, 0], "all"),
    )
    for before, after, label in cases:
        check_padded_load(before, after, test_name=label)

# Hand the test body to VTA's test harness. It presumably sets up the target
# environment and RPC remote (or simulator) and calls _run(env, remote) — confirm
# against vta.testing.run.
vta.testing.run(_run)
Padded Y0 load execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :              960
	uop_load_nbytes :                4
	out_store_nbytes:              400
	gemm_counter    :                0
	alu_counter     :               25
Padded Y1 load execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :              960
	uop_load_nbytes :                4
	out_store_nbytes:              336
	gemm_counter    :                0
	alu_counter     :               21
Padded X0 load execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :              960
	uop_load_nbytes :                4
	out_store_nbytes:              400
	gemm_counter    :                0
	alu_counter     :               25
Padded X1 load execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :              960
	uop_load_nbytes :                4
	out_store_nbytes:              336
	gemm_counter    :                0
	alu_counter     :               21
Padded all load execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :              960
	uop_load_nbytes :                4
	out_store_nbytes:              560
	gemm_counter    :                0
	alu_counter     :               35
2023-09-25 13:10:39.193 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o
2023-09-25 13:10:39.453 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o
2023-09-25 13:10:39.737 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o
2023-09-25 13:10:39.993 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o
2023-09-25 13:10:40.310 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o