VTA save/store 输出命令

VTA save/store 输出命令

import numpy as np
import tvm
from tvm import te
from tvm.contrib.utils import tempdir

import vta
import vta.testing
from vta.testing import simulator

# Seed NumPy's global RNG so the random input drawn inside `_run` is the same on every run.
np.random.seed(0xDEADB)
def _run(env, remote):
    """Build a load -> no-op -> store pipeline, run it remotely and verify it.

    A ``(6, 6, env.BATCH, env.BLOCK_OUT)`` placeholder of ``env.acc_dtype`` is
    copied into a buffer in ``env.acc_scope``, passed through a shift-by-zero
    stage tagged with the ``env.alu`` pragma, cast to ``env.inp_dtype`` and
    written out with the ``env.dma_copy`` pragma. The module is built, saved,
    uploaded to ``remote`` and executed; its output is compared against a
    NumPy cast of the same random input. Simulator statistics are printed at
    the end.

    Parameters
    ----------
    env
        VTA environment object; this function reads ``BATCH``, ``BLOCK_OUT``,
        ``acc_dtype``, ``inp_dtype``, ``acc_scope``, ``dma_copy``, ``alu``,
        ``target_host`` and ``TARGET`` from it.
    remote
        Remote session used for ``upload``, ``load_module`` and ``ext_dev``.

    Raises
    ------
    AssertionError
        If ``env.TARGET`` is neither ``"sim"`` nor ``"tsim"``, or if the
        device output differs from the NumPy reference.
    """
    side = 6
    shape = (side, side, env.BATCH, env.BLOCK_OUT)

    # Compute graph: placeholder -> copy -> no-op shift -> cast to output type.
    src = te.placeholder(shape, name="x", dtype=env.acc_dtype)
    loaded = te.compute(shape, lambda *idx: src(*idx), name="x_buf")
    # Insert a no-op (shift by zero) that will not be optimized away.
    shifted = te.compute(shape, lambda *idx: loaded(*idx) >> 0, name="y_buf")
    dst = te.compute(
        shape,
        lambda *idx: shifted(*idx).astype(env.inp_dtype),
        name="y",
    )

    # Schedule: both intermediate buffers go to env.acc_scope; the outermost
    # axis of each stage carries the pragma selecting how it is lowered.
    sched = te.create_schedule(dst.op)
    sched[loaded].set_scope(env.acc_scope)
    sched[loaded].pragma(loaded.op.axis[0], env.dma_copy)
    sched[shifted].set_scope(env.acc_scope)
    sched[shifted].pragma(shifted.op.axis[0], env.alu)
    sched[dst].pragma(dst.op.axis[0], env.dma_copy)

    # Build the library.
    target = tvm.target.Target("ext_dev", host=env.target_host)
    with vta.build_config():
        module = vta.build(sched, [src, dst], target)

    # Save the object file locally, ship it to the remote and load it there.
    workdir = tempdir()
    obj_path = workdir.relpath("load_act.o")
    module.save(obj_path)
    remote.upload(obj_path)
    kernel = remote.load_module("load_act.o")

    # Verification data: random input plus the expected (cast-only) output.
    device = remote.ext_dev(0)
    input_np = np.random.randint(1, 10, size=shape).astype(src.dtype)
    expected_np = input_np.astype(dst.dtype)
    input_nd = tvm.nd.array(input_np, device)
    output_nd = tvm.nd.empty(expected_np.shape, device=device, dtype=expected_np.dtype)

    # Statistics are collected through the simulator, so require a sim target.
    assert env.TARGET in ("sim", "tsim")
    simulator.clear_stats()

    kernel(input_nd, output_nd)

    np.testing.assert_equal(expected_np, output_nd.numpy())

    stats = simulator.stats()
    print("Save load execution statistics:")
    render_row = "\t{:<16}: {:>16}".format
    for counter, value in stats.items():
        print(render_row(counter, value))

# Hand `_run` to the VTA testing harness; it is expected to call it with the
# (env, remote) pair that `_run`'s signature takes.
vta.testing.run(_run)
Save load execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :             2304
	uop_load_nbytes :                4
	out_store_nbytes:              576
	gemm_counter    :                0
	alu_counter     :               36
2023-09-25 11:06:04.521 INFO load_module /tmp/tmp6wck92fv/load_act.o
# Element count of the (n, n, env.BATCH, env.BLOCK_OUT) tensor; the second
# displayed value scales it by 4. The recorded result (576, 2304) equals the
# out_store_nbytes and acc_load_nbytes counters printed by the run above --
# presumably 1 byte per stored output element and 4 bytes per loaded
# accumulator element (TODO confirm against env.inp_dtype / env.acc_dtype).
# NOTE(review): `n` and `env` are only visible as locals of `_run` in this
# file, so this cell needs them defined at module scope elsewhere -- confirm.
out_store_nbytes = np.prod([n, n, env.BATCH, env.BLOCK_OUT])
out_store_nbytes, out_store_nbytes*4
(576, 2304)