# 定点乘法

定义一些用于推理的辅助函数：

import set_env
import tvm
from tvm import relay
import numpy as np

def update_lib(lib, lib_name="lib.so"):
    """Export ``lib`` to a temporary directory and load it back as a runtime module.

    Parameters
    ----------
    lib :
        A built module exposing ``export_library`` (in this notebook, the
        result of ``relay.build``).
    lib_name : str
        File name of the exported library inside the temporary directory.

    Returns
    -------
    tvm.runtime.Module
        The module re-loaded from disk with ``tvm.runtime.load_module``.
    """
    workdir = tvm.contrib.utils.tempdir()
    exported_path = workdir.relpath(lib_name)
    lib.export_library(exported_path, fcompile=False)
    reloaded = tvm.runtime.load_module(exported_path)
    return reloaded

def run_llvm(run_mod, params, input_dict, lib_name="lib.so"):
    """Compile ``run_mod`` for ``llvm``, run it on the CPU and return every output.

    The module is built at ``opt_level=3`` with the ``AlterOpLayout`` pass
    disabled, exported and re-loaded through ``update_lib``, then executed
    with the graph executor.

    Parameters
    ----------
    run_mod : tvm.IRModule
        Relay module to compile.
    params : dict
        Constant parameters bound at build time (may be empty).
    input_dict : dict
        Mapping from relay input name to the NumPy array fed to that input.
    lib_name : str
        File name used for the temporary exported library.

    Returns
    -------
    list of numpy.ndarray
        One array per graph output, in output order.
    """
    with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
        lib = relay.build(run_mod, target="llvm", params=params)
    lib = update_lib(lib, lib_name=lib_name)
    exe = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.cpu()))
    exe.run(**input_dict)
    # ``NDArray.numpy()`` is the replacement for the deprecated ``asnumpy()``.
    return [exe.get_output(k).numpy() for k in range(exe.get_num_outputs())]

## relay 定点乘法

# Each case is (multiplier, shift, float_value): ``multiplier`` is a Q31
# integer, and ``multiplier / 2**31 * 2**shift`` approximates ``float_value``
# (e.g. 1288490240 / 2**31 ~= 0.6 and 0.6 * 2**-2 = 0.15).
ishape = (1, 2)
dtype = "int32"
a = relay.var("a", relay.TensorType(ishape, dtype))
for multiplier, shift, float_value in (
    (1288490240, -2, 0.15),
    (1395864320, 1, 1.3),
    (1288490188, 0, 0.6),
):
    run_mod = tvm.IRModule.from_expr(
        relay.fixed_point_multiply(a, multiplier, shift)
    )
    run_mod.show()
    data_in = np.random.randint(0, 1000, size=ishape, dtype=dtype)
    inputs = dict(a=data_in)
    # Despite the name, this holds the TVM result; the float reference printed
    # next to it is ``np.round(data_in * float_value)``.
    expected_output = run_llvm(run_mod, {}, inputs, lib_name="lib.so")[0]
    reference = np.round(data_in * float_value)
    print(reference, expected_output)
def @main(%a: Tensor[(1, 2), int32]) {
  fixed_point_multiply(%a, multiplier=1288490240, shift=-2)
}
[[24.  9.]] [[24  9]]
[[120. 606.]] [[120 606]]
[[191. 480.]] [[191 480]]
def @main(%a: Tensor[(1, 2), int32]) {
  fixed_point_multiply(%a, multiplier=1395864320, shift=1)
}
def @main(%a: Tensor[(1, 2), int32]) {
  fixed_point_multiply(%a, multiplier=1288490188, shift=0)
}

## relay 逐通道定点乘法

定义一个构造 relay 逐通道定点乘法表达式的辅助函数：

from tvm.relay.op import _make
from tvm.relay.expr import Expr

def fixed_point_multiply_per_axis(
    x: Expr,
    y: Expr,
    lshift: Expr,
    rshift: Expr,
    is_lshift_required: int,
    is_rshift_required: int,
    axes,
):
    """Build a relay per-axis fixed-point multiplication of ``x``.

    The fixed-point constant is expressed as ``multiplier * 2^(-shift)``,
    where ``multiplier`` is a Q-number with 31 fractional bits; the shift is
    supplied as separate left and right components.

    Parameters
    ----------
    x : Expr
        Input argument.
    y : Expr
        Multiplier of the fixed-point number ``multiplier * 2^(-shift)``.
    lshift : Expr
        Left shifts of the fixed-point number.
    rshift : Expr
        Right shifts of the fixed-point number.
    is_lshift_required : int
        Whether the left shift is applied (used as a boolean flag).
    is_rshift_required : int
        Whether the right shift is applied (used as a boolean flag).
    axes :
        Presumably the indices of the axes of ``x`` along which ``y``,
        ``lshift`` and ``rshift`` vary (e.g. ``[1]`` for per-channel).
        NOTE(review): confirm against the TVM operator definition.

    Returns
    -------
    z : Expr
        The resulting relay call expression.
    """
    operands = (x, y, lshift, rshift)
    flags = (is_lshift_required, is_rshift_required)
    return _make.fixed_point_multiply_per_axis(*operands, *flags, axes)
# from tvm.relay.testing.temp_op_attr import TempOpAttr
from tvm import te
# Per-channel fixed-point multiply: every slice of ``X`` along axis 1
# (length 256 == b_shape[0]) gets its own multiplier, left shift and right shift.
a_shape = [2, 256, 16]
b_shape = [256]
shape = a_shape
channel_axis = 1
# Verify accuracy
x_np = (
    np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32")
)
# ``Y`` is a Q31 multiplier (31 fractional bits). Draw it from [2**30, 2**31)
# so it encodes a real factor in [0.5, 1). With |Y| < 1000, the exact products
# |x| * 2**9 * 1000 / 2**31 stay below 0.25, so every output would round to 0.
y_np = (
    np.random.randint(1 << 30, 1 << 31, size=np.prod(b_shape), dtype="int64")
    .reshape(b_shape)
    .astype("int32")
)
lsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
rsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
# One entry per relay input; each shift tensor is fed under its own name.
inputs = {"X": x_np, "Y": y_np, "l_shift": lsh_np, "r_shift": rsh_np}
shift_shape = [shape[channel_axis]]
x = relay.var("X", shape=shape, dtype="int32")
y = relay.var("Y", shape=shift_shape, dtype="int32")
l_shift = relay.var("l_shift", shape=shift_shape, dtype="int32")
r_shift = relay.var("r_shift", shape=shift_shape, dtype="int32")
# ``axes`` takes axis *indices* of ``X`` (not a shape); both shift flags are on.
out = fixed_point_multiply_per_axis(x, y, l_shift, r_shift, True, True, [channel_axis])
mod = tvm.IRModule.from_expr(out)
mod.show()

# Run the module built in this section (``run_mod`` belongs to the previous one).
expected_output = run_llvm(mod, {}, inputs, lib_name="lib.so")[0]
print(expected_output)

# Float reference: X * (Y / 2**31) * 2**(l_shift - r_shift), broadcast per channel.
bcast = [1] * len(shape)
bcast[channel_axis] = -1
reference = (
    x_np.astype("float64")
    * (y_np / 2.0**31).reshape(bcast)
    * np.exp2(lsh_np - rsh_np).reshape(bcast)
)
print("max |tvm - reference|:", np.abs(expected_output - reference).max())
def @main(%X: Tensor[(2, 256, 16), int32], %Y: Tensor[(256), int32], %l_shift: Tensor[(256), int32], %r_shift: Tensor[(256), int32]) {
  fixed_point_multiply_per_axis(%X, %Y, %l_shift, %r_shift, is_lshift_required=True, is_rshift_required=True, axes=[256])
}
[[53496115        0]]

也可以使用 QNN 的 `requantize` 算子实现逐通道定点乘法：

# Per-channel requantize: input scales cycle through ``in_scale_const`` along
# ``axis``; output scale is 1 and both zero points are 0.
axis = 1
ishape = [1, 128, 56, 56]
in_scale_const = (1.7, 0.6)
x = relay.var("data", shape=ishape, dtype="int32")
if isinstance(in_scale_const, tuple):
    in_scale = list(in_scale_const) * (ishape[axis] // len(in_scale_const))
else:
    in_scale = [in_scale_const] * ishape[axis]
# The scale pattern must tile the channel axis exactly.
assert len(in_scale) == ishape[axis]
iscale = relay.const(in_scale)
izero = relay.const(0)
oscale = relay.const(1.0)
ozero = relay.const(0)
op = relay.qnn.op.requantize(x, iscale, izero, oscale, ozero, axis=axis, out_dtype="int32")
mod = tvm.IRModule.from_expr(op)
mod = relay.transform.InferType()(mod)
mod.show()
x_np = (
    np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32")
)
inputs = {"data": x_np}
expected_output = run_llvm(mod, {}, inputs, lib_name="lib.so")[0]

# Float reference: requantize computes
# (in_scale / out_scale) * (data - in_zero) + out_zero, then rounds. With
# out_scale = 1 and zero points 0, that is data * in_scale[c] per channel c.
# The scales are compared in float32, matching the relay constant's dtype.
bcast = [1] * len(ishape)
bcast[axis] = -1
reference = x_np * np.asarray(in_scale, dtype="float32").reshape(bcast)
print("max |tvm - reference|:", np.abs(expected_output - reference).max())
def @main(%data: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
  qnn.requantize(%data, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
from tvm import te, topi, tir as T, relay
import tvm
from tvm.topi import tag
# from tvm.relay.op.op import register_compute, register_shape_func
# from tvm.relay.op.op import register_broadcast_schedule, register_injective_schedule
# from tvm.relay.op.op import register_pattern, OpPattern

@tvm.te.tag_scope(tag=tag.ELEMWISE)
def q_multiply_shift(
    x: te.Tensor,
    y: te.Tensor,
    q: int,
    left_shift: te.Tensor,
    right_shift: te.Tensor,
    is_left_shift_required: int,
):
    """Element-wise fixed-point multiply: ``x * (y / 2**q) * 2**(left_shift - right_shift)``.

    The product is formed in int64. It is rounded by adding
    ``2**(total_right_shift - 1)`` before shifting right by
    ``total_right_shift = right_shift + q``, then cast back to ``x.dtype``.

    Parameters
    ----------
    x : te.Tensor
        Input values. Each output element reads ``x`` at its own indices.
    y : te.Tensor
        int32 fixed-point multipliers with ``q`` fractional bits, indexed like ``x``.
    q : int
        Number of fractional bits of ``y`` (31 for a Q31 multiplier).
    left_shift : te.Tensor
        int32 left-shift amounts, applied to ``x`` before the multiply.
    right_shift : te.Tensor
        int32 right-shift amounts, added to ``q`` for the final shift.
        ``right_shift + q`` must be at least 1 for the rounding term to be
        well defined.
    is_left_shift_required : int
        If truthy, ``x`` is shifted left by ``left_shift``; otherwise that
        step is skipped.

    Returns
    -------
    te.Tensor
        Tensor with ``x.shape`` and ``x.dtype``.
    """
    # Only int32 types are supported (any number of lanes is allowed)
    hp_dtype = "int64"
    lp_dtype = "int32"
    assert y.dtype == lp_dtype
    assert left_shift.dtype == lp_dtype
    assert right_shift.dtype == lp_dtype
    one = T.const(1, hp_dtype)

    def _compute(*indices):
        # 0) Fetch this element's operands, widening the data to 64 bits.
        value = x(*indices).astype(hp_dtype)
        multiplier = y(*indices).astype(hp_dtype)
        ls = left_shift(*indices)
        rs = right_shift(*indices)

        # 1) Optional left shift. The flag is a plain Python value known at
        #    trace time, so decide here rather than emitting a runtime Select.
        if is_left_shift_required:
            value = value << ls

        # 2) Perform the multiplication in higher precision.
        value = value * multiplier

        # 3) Remove the q fractional bits of the multiplier plus the requested
        #    right shift. Using ``ls`` here would cancel step 1 and ignore
        #    ``right_shift`` entirely.
        total_right_shift = rs + q
        pos_rounding_value = one << (total_right_shift - 1)
        value = value + pos_rounding_value

        # 4) Shift right to get the rounded result.
        value = value >> total_right_shift
        # 5) The result is expected to fit in int32; cast back to the input dtype.
        return value.astype(x.dtype)

    return te.compute(x.shape, _compute)
shape = 1, 2
lp_dtype = "int32"
x = te.placeholder(shape, name="x", dtype="int32")
y = te.placeholder(shape, name="y", dtype=lp_dtype)
left_shift = te.placeholder(shape, name="left_shift", dtype=lp_dtype)
right_shift = te.placeholder(shape, name="right_shift", dtype=lp_dtype)
# ``y`` is a Q31 fixed-point multiplier (31 fractional bits), hence q = 31.
q = 31
z = q_multiply_shift(x, y, q, left_shift, right_shift, is_left_shift_required=1)
s = te.create_schedule(z.op)
f = tvm.build(s, [x, y, left_shift, right_shift, z], "llvm")
dev = tvm.cpu(0)
a_np = np.ones(shape).astype(x.dtype) * 125333333
multiplier_np = np.ones(shape).astype(lp_dtype) * 3650000
ls_np = np.ones(shape).astype(lp_dtype) * 1
# Shift amounts are non-negative; a right shift of 0 adds no extra scaling.
rs_np = np.zeros(shape, dtype=lp_dtype)
a = tvm.nd.array(a_np, dev)
multiplier = tvm.nd.array(multiplier_np, dev)
ls = tvm.nd.array(ls_np, dev)
rs = tvm.nd.array(rs_np, dev)
c = tvm.nd.array(np.zeros(shape, dtype=z.dtype), dev)
f(a, multiplier, ls, rs, c)
print(a, multiplier, c)
# Float reference: x * (y / 2**q) * 2**(left_shift - right_shift).
reference = a_np * (multiplier_np / 2.0**q) * np.exp2(ls_np - rs_np)
print(np.round(reference))
left_shift[i0, i1] + 31
[[125333333 125333333]] [[3650000 3650000]] [[213025 213025]]