通用矩阵乘法(VTA)

通用矩阵乘法(VTA)

import numpy as np
import tvm
from tvm import te
from tvm import rpc
from tvm.contrib.utils import tempdir
from vta.testing import simulator
import vta.testing
env = vta.get_env()
# This benchmark drives the VTA simulator through a local RPC session, so it
# only supports the "sim" target. The check is explicit rather than an
# `assert` so that it is not stripped when Python runs with `-O`.
if not (env.TARGET == "sim" and simulator.enabled()):
    raise RuntimeError(
        "This GEMM benchmark requires the VTA simulator "
        "(env.TARGET == 'sim' with the simulator enabled); got TARGET=%r" % (env.TARGET,)
    )
remote = rpc.LocalSession()

# Problem size: a (batch_size x channel) input times a (channel x channel)
# weight matrix. `block` is the tile size that run_schedule uses when tiling.
batch_size, channel, block = 128, 128, 128

# Packed 4-D layouts consumed by VTA:
#   data:   (batch_size // BATCH, channel // BLOCK_IN,  BATCH,     BLOCK_IN)
#   weight: (channel // BLOCK_OUT, channel // BLOCK_IN, BLOCK_OUT, BLOCK_IN)
#   res:    (batch_size // BATCH, channel // BLOCK_OUT, BATCH,     BLOCK_OUT)
data_shape = (batch_size // env.BATCH, channel // env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
weight_shape = (
    channel // env.BLOCK_OUT,
    channel // env.BLOCK_IN,
    env.BLOCK_OUT,
    env.BLOCK_IN,
)
res_shape = (batch_size // env.BATCH, channel // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT)
# To compute number of ops, use a x2 factor for FMA
num_ops = 2 * channel * channel * batch_size

# Reduction axes over the outer (ko) and inner (ki) input-channel dimensions.
ko = te.reduce_axis((0, channel // env.BLOCK_IN), name="ko")
ki = te.reduce_axis((0, env.BLOCK_IN), name="ki")

data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
weight = te.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype)
# Identity copy stages; the schedule places them in the VTA input/weight scopes.
data_buf = te.compute(data_shape, lambda *i: data(*i), "data_buf")
weight_buf = te.compute(weight_shape, lambda *i: weight(*i), "weight_buf")
# Blocked matrix multiply, accumulated in env.acc_dtype.
res_gem = te.compute(
    res_shape,
    lambda bo, co, bi, ci: te.sum(
        data_buf[bo, ko, bi, ki].astype(env.acc_dtype)
        * weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
        axis=[ko, ki],
    ),
    name="res_gem",
)
# Elementwise post-processing of the accumulator:
#   res_shf: right shift by 8 bits
#   res_max: clamp below at 0 (ReLU)
#   res_min: clamp above at (1 << (INP_WIDTH - 1)) - 1, so the final cast stays
#            in range (assuming env.inp_dtype is a signed INP_WIDTH-bit int)
#   res:     cast to env.inp_dtype
res_shf = te.compute(res_shape, lambda *i: res_gem(*i) >> 8, name="res_shf")
res_max = te.compute(res_shape, lambda *i: tvm.te.max(res_shf(*i), 0), "res_max")  # relu
res_min = te.compute(
    res_shape, lambda *i: tvm.te.min(res_max(*i), (1 << (env.INP_WIDTH - 1)) - 1), "res_min"
)  # clamp to the largest representable output value
res = te.compute(res_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res")
def verify(s):
    mod = vta.build(
        s,
        [data, weight, res],
        tvm.target.Target("ext_dev", host=env.target_host),
        name="gemm",
    )
    temp = tempdir()
    mod.save(temp.relpath("gemm.o"))
    remote.upload(temp.relpath("gemm.o"))
    f = remote.load_module("gemm.o")
    # verify
    dev = remote.ext_dev(0)
    # Data in original format
    data_orig = np.random.randint(-128, 128, size=(batch_size, channel)).astype(data.dtype)
    weight_orig = np.random.randint(-128, 128, size=(channel, channel)).astype(weight.dtype)
    data_packed = data_orig.reshape(
        batch_size // env.BATCH, env.BATCH, channel // env.BLOCK_IN, env.BLOCK_IN
    ).transpose((0, 2, 1, 3))
    weight_packed = weight_orig.reshape(
        channel // env.BLOCK_OUT, env.BLOCK_OUT, channel // env.BLOCK_IN, env.BLOCK_IN
    ).transpose((0, 2, 1, 3))
    res_np = np.zeros(res_shape).astype(res.dtype)
    data_arr = tvm.nd.array(data_packed, dev)
    weight_arr = tvm.nd.array(weight_packed, dev)
    res_arr = tvm.nd.array(res_np, dev)
    res_ref = np.zeros(res_shape).astype(env.acc_dtype)
    for b in range(batch_size // env.BATCH):
        for i in range(channel // env.BLOCK_OUT):
            for j in range(channel // env.BLOCK_IN):
                res_ref[b, i, :] += np.dot(
                    data_packed[b, j, :].astype(env.acc_dtype),
                    weight_packed[i, j].T.astype(env.acc_dtype),
                )
    res_ref = np.right_shift(res_ref, 8)
    res_ref = np.clip(res_ref, 0, (1 << (env.INP_WIDTH - 1)) - 1).astype(res.dtype)
    time_f = f.time_evaluator("gemm", dev, number=20)
    if env.TARGET in ["sim", "tsim"]:
        simulator.clear_stats()
    cost = time_f(data_arr, weight_arr, res_arr)
    if env.TARGET in ["sim", "tsim"]:
        stats = simulator.stats()
        print("Execution statistics:")
        for k, v in stats.items():
            print("\t{:<16}: {:>16}".format(k, v))
    res_unpack = res_arr.numpy().reshape(
        batch_size // env.BATCH, channel // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT
    )
    return cost

def run_schedule(load_inp, load_wgt, gemm, alu, store_out, print_ir, block):
    """Schedule the GEMM pipeline onto VTA, optionally tiled, then benchmark it.

    Parameters
    ----------
    load_inp, load_wgt, store_out :
        Pragmas attached to the input-load, weight-load and output-store stages
        (callers pass ``env.dma_copy`` or ``env.mock.dma_copy``).
    gemm :
        Tensor intrinsic used to tensorize the matmul stage.
    alu :
        Pragma attached to the shift / max / min stages.
    print_ir : bool
        If true, print the lowered IR before building.
    block : int
        Tile size; a falsy value disables tiling.

    Returns
    -------
    Whatever ``verify(s)`` returns for the resulting schedule.
    """
    s = te.create_schedule(res.op)

    # Put each intermediate stage into its VTA buffer scope.
    scopes = (
        (data_buf, env.inp_scope),
        (weight_buf, env.wgt_scope),
        (res_gem, env.acc_scope),
        (res_shf, env.acc_scope),
        (res_min, env.acc_scope),
        (res_max, env.acc_scope),
    )
    for stage, scope in scopes:
        s[stage].set_scope(scope)

    post_ops = (res_shf, res_min, res_max)

    if not block:
        store_axis = s[res].op.axis[0]
        g_bo, g_co, g_bi, g_ci = s[res_gem].op.axis
        s[res_gem].reorder(ko, g_bo, g_co, g_bi, g_ci, ki)
    else:
        # Tile the two outer output axes by block // BATCH and
        # block // BLOCK_OUT packed tiles.
        out_bo, out_co, _, _ = s[res].op.axis
        tile_b, tile_c, inner_b, _ = s[res].tile(
            out_bo, out_co, block // env.BATCH, block // env.BLOCK_OUT
        )
        store_axis = inner_b

        for stage in (res_gem,) + post_ops:
            s[stage].compute_at(s[res], tile_c)

        g_bo, g_co, g_bi, g_ci = s[res_gem].op.axis
        # Compute one line at a time
        k_outer, k_inner = s[res_gem].split(ko, block // env.BLOCK_IN)
        s[res_gem].reorder(k_outer, k_inner, g_bo, g_co, g_bi, g_ci, ki)
        s[data_buf].compute_at(s[res_gem], k_outer)
        s[weight_buf].compute_at(s[res_gem], k_outer)

    # Map every stage onto its VTA instruction (same order in both modes).
    s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
    s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
    s[res_gem].tensorize(g_bi, gemm)
    for stage in post_ops:
        s[stage].pragma(s[stage].op.axis[0], alu)
    s[res].pragma(store_axis, store_out)

    if print_ir:
        print(tvm.lower(s, [data, weight, res], simple_mode=True))
    return verify(s)

GEMM GOPS End-to-End Test:

mock = env.mock
with vta.build_config():
    # Every stage uses the real VTA intrinsic: a full end-to-end run.
    cost = run_schedule(
        load_inp=env.dma_copy,
        load_wgt=env.dma_copy,
        gemm=env.gemm,
        alu=env.alu,
        store_out=env.dma_copy,
        print_ir=False,
        block=block,
    )
    gops = num_ops / cost.mean / 1e9
    print("\tTime cost = {:g} sec/op, {:g} GOPS".format(cost.mean, gops))
Execution statistics:
	inp_load_nbytes :           344064
	wgt_load_nbytes :           344064
	acc_load_nbytes :                0
	uop_load_nbytes :             1008
	out_store_nbytes:           344064
	gemm_counter    :           172032
	alu_counter     :            64512
	Time cost = 0.00169099 sec/op, 2.48038 GOPS
[08:32:41] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-09-25 08:32:42.101 INFO load_module /tmp/tmp8u11kql8/gemm.o

GEMM Unit Test:

mock = env.mock
with vta.build_config():
    # Only the GEMM stage is real; loads, ALU and store are mocked out.
    cost = run_schedule(
        load_inp=mock.dma_copy,
        load_wgt=mock.dma_copy,
        gemm=env.gemm,
        alu=mock.alu,
        store_out=mock.dma_copy,
        print_ir=False,
        block=block,
    )
    gops = num_ops / cost.mean / 1e9
    print("\tTime cost = {:g} sec/op, {:g} GOPS".format(cost.mean, gops))
Execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :                0
	uop_load_nbytes :              756
	out_store_nbytes:                0
	gemm_counter    :           172032
	alu_counter     :                0
	Time cost = 0.00688763 sec/op, 0.608962 GOPS
[08:34:29] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-09-25 08:34:29.973 INFO load_module /tmp/tmp8u11kql8/gemm.o

ALU Unit Test:

mock = env.mock
with vta.build_config():
    # Only the ALU stages are real; loads, GEMM and store are mocked out.
    cost = run_schedule(
        load_inp=mock.dma_copy,
        load_wgt=mock.dma_copy,
        gemm=mock.gemm,
        alu=env.alu,
        store_out=mock.dma_copy,
        print_ir=False,
        block=block,
    )
    gops = num_ops / cost.mean / 1e9
    print("\tTime cost = {:g} sec/op, {:g} GOPS".format(cost.mean, gops))
Execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :                0
	uop_load_nbytes :              252
	out_store_nbytes:                0
	gemm_counter    :                0
	alu_counter     :            64512
	Time cost = 0.000132332 sec/op, 31.6953 GOPS
[08:33:08] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-09-25 08:33:08.365 INFO load_module /tmp/tmp8u11kql8/gemm.o

LoadInp Unit Test:

mock = env.mock
with vta.build_config():
    # Only the input load is real; every other stage is mocked out.
    cost = run_schedule(
        env.dma_copy, mock.dma_copy, mock.gemm, mock.alu, mock.dma_copy,
        print_ir=False,
        block=block
    )
    gops = (num_ops / cost.mean) / float(10**9)
    # Logical size of the input tensor (batch_size x channel elements of
    # INP_WIDTH bits) over runtime. This counts each element once, so it may
    # differ from the simulator's inp_load_nbytes when tiles are reloaded.
    bandwidth = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10**9)
    print(
        "\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits"
        % (cost.mean, gops, bandwidth)
    )
Execution statistics:
	inp_load_nbytes :           344064
	wgt_load_nbytes :                0
	acc_load_nbytes :                0
	uop_load_nbytes :                0
	out_store_nbytes:                0
	gemm_counter    :                0
	alu_counter     :                0
	Time cost = 2.45895e-06 sec/op, 1705.73 GOPS, bandwidth=53.3041 Gbits
[08:36:33] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-09-25 08:36:33.333 INFO load_module /tmp/tmp8u11kql8/gemm.o

LoadWgt Unit Test:

mock = env.mock
with vta.build_config():
    # Only the weight load is real; every other stage is mocked out.
    cost = run_schedule(
        mock.dma_copy, env.dma_copy, mock.gemm, mock.alu, mock.dma_copy,
        print_ir=False,
        block=block
    )
    gops = (num_ops / cost.mean) / float(10**9)
    # The weight tensor holds channel x channel elements of env.wgt_dtype, so
    # its size is channel * channel * WGT_WIDTH bits (not the input size).
    bandwidth = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10**9)
    print(
        "\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits"
        % (cost.mean, gops, bandwidth)
    )
Execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :           344064
	acc_load_nbytes :                0
	uop_load_nbytes :                0
	out_store_nbytes:                0
	gemm_counter    :                0
	alu_counter     :                0
	Time cost = 2.4185e-06 sec/op, 1734.26 GOPS, bandwidth=54.1956 Gbits
[08:37:20] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-09-25 08:37:20.333 INFO load_module /tmp/tmp8u11kql8/gemm.o

StoreOut Unit Test:

mock = env.mock
with vta.build_config():
    # Only the output store is real; every other stage is mocked out.
    cost = run_schedule(
        mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu, env.dma_copy,
        print_ir=False,
        block=block
    )
    gops = (num_ops / cost.mean) / float(10**9)
    # `res` has batch_size x channel elements and is cast to env.inp_dtype
    # before being stored, hence INP_WIDTH bits per element.
    bandwidth = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10**9)
    print(
        "\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits"
        % (cost.mean, gops, bandwidth)
    )
Execution statistics:
	inp_load_nbytes :                0
	wgt_load_nbytes :                0
	acc_load_nbytes :                0
	uop_load_nbytes :                0
	out_store_nbytes:           344064
	gemm_counter    :                0
	alu_counter     :                0
	Time cost = 2.62682e-05 sec/op, 159.672 GOPS, bandwidth=4.98975 Gbits
[08:38:14] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=64
2023-09-25 08:38:14.909 INFO load_module /tmp/tmp8u11kql8/gemm.o