Test fixed-point multiplication#

import numpy as np
import tvm
from tvm import relay, te
from tvm.relay.backend import Executor

# def get_hexagon_target(cpu_ver: str, **kwargs) -> tvm.target.Target:
#     """Creates a Hexagon target"""
#     target = tvm.target.hexagon(cpu_ver, **kwargs)
#     return tvm.target.Target(target, host=target)

def build_module(relay_mod, target):
    params = {}
    executor = Executor("aot", {"link-params": True})
    lowered = tvm.relay.build(
        relay_mod,
        tvm.target.Target(target, host=target),
        executor=executor,
        params=params,
    )
    return lowered

def run_module(mod, inputs):
    mod.set_input(**inputs)
    mod.run()
    output = mod.get_output(0).numpy()
    return output

Test relay.fixed_point_multiply#

ishape = (6, 32)
a = relay.var("a", relay.TensorType(ishape, "int32"))
for multiplier, shift in [
    (1288490240, -2),  # 0.15
    (1395864320, 1),  # 1.3
    (1288490188, 0),  # 0.6
]:
    fpm = relay.fixed_point_multiply(a, multiplier, shift)
    relay_mod = tvm.IRModule.from_expr(fpm)
    relay_mod.show()
    with tvm.transform.PassContext(opt_level=3):
        # Compile for LLVM...
        llvm_lowered = build_module(relay_mod, tvm.target.Target("llvm"))

    data_in = np.arange(-96, 96).reshape(ishape)
    inputs = {"a": data_in}

    # Run llvm...
    llvm_mod = tvm.runtime.executor.AotModule(llvm_lowered["default"](tvm.cpu(0)))
    expected_output = run_module(llvm_mod, inputs)
    # print(expected_output)
def @main(%a: Tensor[(6, 32), int32]) {
  fixed_point_multiply(%a, multiplier=1288490240, shift=-2)
}
def @main(%a: Tensor[(6, 32), int32]) {
  fixed_point_multiply(%a, multiplier=1395864320, shift=1)
}
def @main(%a: Tensor[(6, 32), int32]) {
  fixed_point_multiply(%a, multiplier=1288490188, shift=0)
}
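
`relay.fixed_point_multiply(a, multiplier, shift)` scales each element by the Q31 constant `multiplier / 2**31` and then shifts by `shift` bits, i.e. it computes roughly `round(a * multiplier * 2**(shift - 31))`; that is why the `(multiplier, shift)` pairs above correspond to the scales 0.15, 1.3 and 0.6. A minimal NumPy cross-check against the last iteration's output (one unit of tolerance absorbs rounding-mode differences):

# Floating-point reference for the last (multiplier, shift) pair (~0.6):
# fixed_point_multiply(a, m, s) ~= round(a * m * 2**(s - 31)).
ref = np.round(data_in * multiplier * 2.0 ** (shift - 31))
np.testing.assert_allclose(expected_output, ref, atol=1)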

Per-channel fixed-point multiplication#

scales = (
    (1.3, 30.0),
    (1.37, 1.0),
    (0.6, 1.0),
    ((1.7, 0.6), 1.0),
    ((0.007, 1.9), 1.0),
)
ishape = [1, 128, 56, 56]
axis = 1
a = relay.var("a", shape=ishape, dtype="int32")
for in_scale_const, out_scale_const in scales:
    # Make list of input scales from in_scale_const parameter.
    if isinstance(in_scale_const, tuple):
        in_scale = list(in_scale_const) * (ishape[axis] // len(in_scale_const))
    else:
        in_scale = [in_scale_const] * ishape[axis]
    assert len(in_scale) == ishape[axis]
    # qnn.requantize is lowered to fixed_point_multiply if zp == 0 and in_dtype == out_dtype.
    iscale = relay.const(in_scale)
    izero = relay.const(0)
    oscale = relay.const(out_scale_const)
    ozero = relay.const(0)
    op = relay.qnn.op.requantize(a, iscale, izero, oscale, ozero, axis=axis, out_dtype="int32")
    mod = tvm.IRModule.from_expr(op)
    mod = relay.transform.InferType()(mod)
    mod.show()
    with tvm.transform.PassContext(opt_level=3):
        # Compile for LLVM...
        llvm_lowered = build_module(mod, tvm.target.Target("llvm"))

    a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape)
    inputs = {"a": a_np}

    # Run llvm...
    llvm_mod = tvm.runtime.executor.AotModule(llvm_lowered["default"](tvm.cpu(0)))
    expected_output = run_module(llvm_mod, inputs)
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
  qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 30f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
  qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
  qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
  qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
  qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
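
With zero points of 0 and matching input/output dtypes, `qnn.requantize` is lowered to a per-channel fixed-point multiply by `in_scale[c] / out_scale` along `axis`. A minimal NumPy cross-check for the last iteration, again with one unit of rounding tolerance:

# Per-channel reference: scale each channel along axis=1 and round.
per_channel_scale = np.array(in_scale).reshape(1, -1, 1, 1) / out_scale_const
ref = np.round(a_np * per_channel_scale)
np.testing.assert_allclose(expected_output, ref, atol=1)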

Fixed-point multiply with vectorization#

Here the vectorization size exceeds the hardware vector length.

ishape = [2, 256, 16]

def q_mul_shift(shape):
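    # q_multiply_shift(x, y, q, s) multiplies x by the Q-number y with q
    # fractional bits and then shifts by s; here y = 1395864320 ~= 0.65 * 2**31
    # and s = 1, so every element is scaled by roughly 1.3.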
    x = te.placeholder(shape, name="X", dtype="int32")
    out = te.compute(
        shape,
        lambda i, j, k: tvm.tir.q_multiply_shift(
            x[i, j, k],
            tvm.tir.const(1395864320, "int32"),
            tvm.tir.const(31, "int32"),
            tvm.tir.const(1, "int32"),
        ),
        name="compute",
    )
    return te.create_prim_func([x, out])

for vector_size in (32, 64, 128, 256):
    mod = q_mul_shift(ishape)
    # Schedule with vectorization
    sch = tvm.tir.Schedule(mod)
    b00 = sch.get_block(name="compute", func_name="main")
    fused = sch.fuse(*sch.get_loops(block=b00))
    _, v = sch.split(loop=fused, factors=[None, vector_size])
    sch.vectorize(v)

    with tvm.transform.PassContext(opt_level=3):
        host_lib = tvm.build(sch.mod, target=tvm.target.Target("llvm"))

    # Verify accuracy
    a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32")
    b_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32")

    host_args = [tvm.runtime.ndarray.array(arg) for arg in [a_np, b_np]]
    host_lib(*host_args)
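
The loop above only compiles and runs the vectorized schedule, so the "verify accuracy" step has nothing to compare against yet. A minimal sketch of such a check: rebuild the unscheduled PrimFunc as a reference and require bitwise-identical results for the last iteration.

# Reference build without the vectorization schedule; the integer arithmetic is
# deterministic, so the outputs must match exactly.
ref_lib = tvm.build(q_mul_shift(ishape), target=tvm.target.Target("llvm"))
ref_args = [tvm.runtime.ndarray.array(a_np), tvm.runtime.ndarray.array(b_np)]
ref_lib(*ref_args)
np.testing.assert_array_equal(host_args[1].numpy(), ref_args[1].numpy())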
a_shape = [2, 256, 16]
b_shape = [256]

def q_mul_shift(shape):
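    # Per-channel variant: channel j of the input is multiplied by the Q31
    # multiplier y[j] with per-channel shifts l_shift[j] / r_shift[j];
    # the two boolean flags passed below enable only the left shift.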
    shift_shape = [shape[1]]
    x = te.placeholder(shape, name="X", dtype="int32")
    y = te.placeholder(shift_shape, name="Y", dtype="int32")
    l_shift = te.placeholder(shift_shape, name="L_shift", dtype="int32")
    r_shift = te.placeholder(shift_shape, name="R_shift", dtype="int32")

    out = te.compute(
        shape,
        lambda i, j, k: tvm.tir.q_multiply_shift_per_axis(
            x[i, j, k],
            y[j],
            l_shift[j],
            r_shift[j],
            tvm.tir.const(31, "int32"),
            tvm.tir.const(1, "bool"),
            tvm.tir.const(0, "bool"),
        ),
        name="compute",
    )
    return te.create_prim_func([x, y, l_shift, r_shift, out])

for vector_size in (32, 64, 128, 256):
    mod = q_mul_shift(a_shape)
    # Schedule with vectorization
    sch = tvm.tir.Schedule(mod)
    b00 = sch.get_block(name="compute", func_name="main")
    fused = sch.fuse(*sch.get_loops(block=b00))
    _, v = sch.split(loop=fused, factors=[None, vector_size])
    sch.vectorize(v)

    with tvm.transform.PassContext(opt_level=3):
        host_lib = tvm.build(sch.mod, target=tvm.target.Target("llvm"))

    # Verify accuracy
    x_np = (
        np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32")
    )
    y_np = (
        np.random.randint(-1000, 1000, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
    )
    lsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
    rsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
    b_np = (
        np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32")
    )
    np_args = [x_np, y_np, lsh_np, rsh_np, b_np]
    host_args = [tvm.runtime.ndarray.array(arg) for arg in np_args]
    host_lib(*host_args)
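
The same cross-check applies to the per-axis variant; a minimal sketch comparing the last iteration against an unscheduled reference build:

# Reference build of the per-axis PrimFunc without vectorization.
ref_lib = tvm.build(q_mul_shift(a_shape), target=tvm.target.Target("llvm"))
ref_args = [tvm.runtime.ndarray.array(arg) for arg in np_args]
ref_lib(*ref_args)
np.testing.assert_array_equal(host_args[-1].numpy(), ref_args[-1].numpy())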
relay.fixed_point_multiply
<function tvm.relay.op.tensor.fixed_point_multiply(data, multiplier, shift)>