# Tests: canonicalizations
#
# Checks canonicalizations.create_integer_lookup_op against floating-point references.

from typing import Callable

import numpy as np
from tvm import relay
from tvm.relay.qnn.op import canonicalizations


def fake_identity_func_numpy(arr: np.ndarray):
    """Floating-point reference that returns ``arr`` unchanged, cast to float32.

    ``astype`` always produces a new array, so the caller's input is never aliased.
    """
    as_float32 = arr.astype("float32")
    return as_float32


def dequantize_numpy(np_arr, np_scale=1.0, np_zero_point=0):
    """Map quantized integers back to real values: ``(q - zero_point) * scale``.

    The input is widened to int32 first so subtracting the zero point cannot
    wrap around in an 8-bit dtype.
    """
    widened = np_arr.astype("int32")
    centered = widened - np_zero_point
    return centered * np_scale


def fake_identity_func_relay(
    floating_point_func: Callable[[np.ndarray], np.ndarray],
    input_arg=None,
    in_scale=relay.const(1.0, dtype="float32"),
    in_zero_point=relay.const(0, dtype="int32"),
    out_scale=relay.const(1.0, dtype="float32"),
    out_zero_point=relay.const(0, dtype="int32"),
    in_axis=-1,
    out_axis=-1,
    in_dtype="uint8",
    out_dtype="uint8",
):
    """Build an integer lookup-table op for ``floating_point_func``.

    All quantization parameters are forwarded unchanged to
    ``canonicalizations.create_integer_lookup_op``.

    Args:
        floating_point_func: float -> float function the table approximates.
        input_arg: relay expression holding the quantized inputs. When None,
            a constant holding all 256 byte values reinterpreted as
            ``in_dtype`` is used. It must expose ``.data.numpy()`` (presumably
            a relay Constant -- confirm for non-default callers).

    Returns:
        A tuple ``(lookup_expr, input_values)``: the relay expression returned
        by ``create_integer_lookup_op`` and the numpy values of ``input_arg``.
    """
    if input_arg is None:
        # Every 8-bit code point, reinterpreted (not converted) as in_dtype.
        every_byte = np.arange(0, 256, dtype="uint8")
        input_arg = relay.const(every_byte.view(in_dtype))

    lookup_expr = canonicalizations.create_integer_lookup_op(
        input_arg=input_arg,
        floating_point_func=floating_point_func,
        in_scale=in_scale,
        in_zero_point=in_zero_point,
        out_scale=out_scale,
        out_zero_point=out_zero_point,
        in_axis=in_axis,
        out_axis=out_axis,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
    )
    input_values = input_arg.data.numpy()
    return lookup_expr, input_values


def run_function_test(
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
    in_dtype: str,
    out_dtype: str,
    floating_point_func: Callable[[np.ndarray], np.ndarray],
    input_arg: relay.Expr = None,
    rtol=1e-7,
    atol=0,
):
    """Check a quantized lookup-table op against its floating-point reference.

    Builds the lookup op via ``fake_identity_func_relay`` and evaluates it
    with ``canonicalizations.run_const_expr``. The dequantized result is then
    compared with ``floating_point_func`` applied to the dequantized input.

    Args:
        in_scale, in_zero_point: quantization parameters of the input.
        out_scale, out_zero_point: quantization parameters of the output.
        in_dtype, out_dtype: integer dtypes of input and output (e.g. "int8").
        floating_point_func: the float reference the table approximates.
        input_arg: optional relay expression with the quantized inputs. When
            None, all 256 byte values are used.
        rtol, atol: tolerances for ``np.testing.assert_allclose``. rtol is
            applied relative to the floating-point reference.

    Raises:
        AssertionError: if any element differs by more than
            ``atol + rtol * |reference|``.
    """
    relay_lookup, input_values = fake_identity_func_relay(
        input_arg=input_arg,
        floating_point_func=floating_point_func,
        in_scale=relay.const(in_scale, "float32"),
        in_zero_point=relay.const(in_zero_point, "int32"),
        out_scale=relay.const(out_scale, "float32"),
        out_zero_point=relay.const(out_zero_point, "int32"),
        in_dtype=in_dtype,
        out_dtype=out_dtype,
    )
    result = canonicalizations.run_const_expr(relay_lookup)

    # Reference values: the float function applied to the dequantized inputs.
    expected = floating_point_func(
        dequantize_numpy(input_values, np_scale=in_scale, np_zero_point=in_zero_point)
    )
    # assert_allclose(actual, desired) scales rtol by |desired| and labels its
    # report ACTUAL/DESIRED, so the reference must be the second argument.
    np.testing.assert_allclose(
        dequantize_numpy(result, np_scale=out_scale, np_zero_point=out_zero_point),
        expected,
        atol=atol,
        rtol=rtol,
    )


# Import-time smoke check: an int8 -> int8 identity table with unit scale
# and zero offset on both sides.
run_function_test(
    floating_point_func=fake_identity_func_numpy,
    in_dtype="int8",
    out_dtype="int8",
    in_scale=1.0,
    in_zero_point=0,
    out_scale=1.0,
    out_zero_point=0,
)


def test_int8_to_int8(self=None):
    """int8 -> int8 identity lookup with unit scale and zero offset.

    ``self`` is optional because this test is a module-level function: pytest
    calls it with no arguments (a defaulted parameter is not resolved as a
    fixture). If an object is passed, its ``run_function_test`` and
    ``fake_identity_func_numpy`` attributes are used, matching the original
    method-style calls.
    """
    runner = getattr(self, "run_function_test", run_function_test)
    identity = getattr(self, "fake_identity_func_numpy", fake_identity_func_numpy)
    runner(
        in_scale=1.0,
        in_zero_point=0,
        out_scale=1.0,
        out_zero_point=0,
        in_dtype="int8",
        out_dtype="int8",
        floating_point_func=identity,
    )

def test_uint8_to_uint8(self=None):
    """uint8 -> uint8 identity lookup with unit scale and zero point 128 on both sides.

    ``self`` is optional because this test is a module-level function: pytest
    calls it with no arguments (a defaulted parameter is not resolved as a
    fixture). If an object is passed, its ``run_function_test`` and
    ``fake_identity_func_numpy`` attributes are used, matching the original
    method-style calls.
    """
    runner = getattr(self, "run_function_test", run_function_test)
    identity = getattr(self, "fake_identity_func_numpy", fake_identity_func_numpy)
    runner(
        in_scale=1.0,
        in_zero_point=128,
        out_scale=1.0,
        out_zero_point=128,
        in_dtype="uint8",
        out_dtype="uint8",
        floating_point_func=identity,
    )

def test_int8_to_uint8(self=None):
    """int8 (zero point 0) -> uint8 (zero point 128) identity lookup, unit scales.

    ``self`` is optional because this test is a module-level function: pytest
    calls it with no arguments (a defaulted parameter is not resolved as a
    fixture). If an object is passed, its ``run_function_test`` and
    ``fake_identity_func_numpy`` attributes are used, matching the original
    method-style calls.
    """
    runner = getattr(self, "run_function_test", run_function_test)
    identity = getattr(self, "fake_identity_func_numpy", fake_identity_func_numpy)
    runner(
        in_scale=1.0,
        in_zero_point=0,
        out_scale=1.0,
        out_zero_point=128,
        in_dtype="int8",
        out_dtype="uint8",
        floating_point_func=identity,
    )

def test_uint8_to_int8(self=None):
    """uint8 (zero point 128) -> int8 (zero point 0) identity lookup, unit scales.

    ``self`` is optional because this test is a module-level function: pytest
    calls it with no arguments (a defaulted parameter is not resolved as a
    fixture). If an object is passed, its ``run_function_test`` and
    ``fake_identity_func_numpy`` attributes are used, matching the original
    method-style calls.
    """
    runner = getattr(self, "run_function_test", run_function_test)
    identity = getattr(self, "fake_identity_func_numpy", fake_identity_func_numpy)
    runner(
        in_scale=1.0,
        in_zero_point=128,
        out_scale=1.0,
        out_zero_point=0,
        in_dtype="uint8",
        out_dtype="int8",
        floating_point_func=identity,
    )

"""Test different input shapes"""

def test_keep_input_shapes(self):
    # input in floating point ~[-2, 2], final output ~[0, 8]
    self.run_function_test(
        input_arg=relay.const(np.arange(-128, 128).astype("int8").reshape([2, 2, 8, 8])),
        in_scale=0.015,
        in_zero_point=0,
        out_scale=16 / 256,
        out_zero_point=0,
        in_dtype="int8",
        out_dtype="int8",
        floating_point_func=self.fake_identity_func_numpy,
        atol=0.03,
        rtol=0.01,
    )
    self.run_function_test(
        input_arg=relay.const(np.arange(-128, 128).astype("int8").reshape([2, 2, 64])),
        in_scale=0.015,
        in_zero_point=0,
        out_scale=16 / 256,
        out_zero_point=0,
        in_dtype="int8",
        out_dtype="int8",
        floating_point_func=self.fake_identity_func_numpy,
        atol=0.03,
        rtol=0.01,
    )
    self.run_function_test(
        input_arg=relay.const(np.arange(-128, 128).astype("int8").reshape([2, 128])),
        in_scale=0.015,
        in_zero_point=0,
        out_scale=16 / 256,
        out_zero_point=0,
        in_dtype="int8",
        out_dtype="int8",
        floating_point_func=self.fake_identity_func_numpy,
        atol=0.03,
        rtol=0.01,
    )