partition_conversions

partition_conversions

`tvm.relay.quantize._partition_conversions.partition_conversions()` 将模块划分为三个函数：输入量化（`quantize_inputs`）、核心量化推理（`quantized_main`）和输出反量化（`dequantize_outputs`），并由 `main` 依次调用它们。

import numpy as np
import tvm
from tvm.runtime.vm import VirtualMachine
from tvm import relay
from torch import nn
import torch

class Model(nn.Module):
    """Tiny Conv2d -> BatchNorm2d -> ReLU network used for the quantization demo.

    The convolution maps 3 input channels to 16 output channels with a 3x3
    kernel, stride 1 and padding 1, so the spatial size of the input is kept.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Submodules are created in this fixed order (conv, bn, relu); the
        # attribute names also appear in the traced/imported graph.
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )
        self.bn = nn.BatchNorm2d(num_features=16)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))

def create_model(ishape=(1, 3, 4, 4)):
    """Trace ``Model`` with TorchScript and import it into Relay.

    Args:
        ishape: Shape of the single float32 input, which is named ``"data"``.

    Returns:
        The ``(mod, params)`` pair returned by ``relay.frontend.from_pytorch``.
    """
    network = Model().eval().float()
    # Example input that is only used to drive ``torch.jit.trace``.
    example = torch.rand(ishape).type(torch.float32)
    traced = torch.jit.trace(network, example)
    # ``torch.jit.script(network)`` could be passed to ``from_pytorch``
    # instead of the traced module.
    return relay.frontend.from_pytorch(
        traced,
        [("data", ishape)],
        use_parser_friendly_name=True,
    )
# Show the quantization config that is active outside any qconfig scope.
print(f"修改前量化配置:\n{relay.quantize.current_qconfig()}")
mod, params = create_model(ishape=(1, 3, 4, 4))
# Simulated quantization: tensors stay float32 and simulated_quantize
# annotations are inserted into the graph (as seen in the printed IR).
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    skip_conv_layers=[],
    do_simulation=True,
):
    print(f"当前量化配置:\n{relay.quantize.current_qconfig()}\n")
    qmod = relay.quantize.quantize(mod, params)
print(qmod)
修改前量化配置:
qconfig(nbit_input=8, nbit_weight=8, nbit_activation=32, calibrate_mode=global_scale, global_scale=8, weight_scale=power2, skip_conv_layers==(nullptr), skip_dense_layer==1, do_simulation==0, round_for_shift==1, debug_enabled_ops==(nullptr), rounding==UPWARD, partition_conversions==disabled)
当前量化配置:
qconfig(nbit_input=8, nbit_weight=8, nbit_activation=32, calibrate_mode=global_scale, global_scale=8, weight_scale=power2, skip_conv_layers==[], skip_dense_layer==1, do_simulation==1, round_for_shift==1, debug_enabled_ops==(nullptr), rounding==UPWARD, partition_conversions==disabled)

def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {
  %0 = relay.op.annotation.simulated_quantize(%data, 0.0625f /* ty=float32 */, -127f /* ty=float32 */, 127f /* ty=float32 */, kind=1) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %1 = nn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %5 = relay.op.annotation.simulated_quantize(%4, 0.0625f /* ty=float32 */, -127f /* ty=float32 */, 127f /* ty=float32 */, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = annotation.cast_hint(%5, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  annotation.stop_fusion(%6) /* ty=Tensor[(1, 16, 4, 4), float32] */
}
# Build a fresh float model for the second quantization run.
mod, params = create_model(ishape=(1, 3, 4, 4))
# Real (non-simulated) quantization with partition_conversions enabled; the
# resulting module is printed in the next cell.
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    skip_conv_layers=[],
    partition_conversions="enabled",
    do_simulation=False,
):
    print(f"当前量化配置:\n{relay.quantize.current_qconfig()}\n")
    qmod = relay.quantize.quantize(mod, params)
当前量化配置:
qconfig(nbit_input=8, nbit_weight=8, nbit_activation=32, calibrate_mode=global_scale, global_scale=8, weight_scale=power2, skip_conv_layers==[], skip_dense_layer==1, do_simulation==0, round_for_shift==1, debug_enabled_ops==(nullptr), rounding==UPWARD, partition_conversions==enabled)
# Dump the partitioned, quantized module as Relay text.
print(str(qmod))
def @dequantize_outputs(%input: Tensor[(1, 16, 4, 4), int8] /* ty=Tensor[(1, 16, 4, 4), int8] */) -> Tensor[(1, 16, 4, 4), float32] {
  %0 = cast(%input, dtype="float32") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  multiply(%0, 0.0625f /* ty=float32 */) /* ty=Tensor[(1, 16, 4, 4), float32] */
}

def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] */) -> Tensor[(1, 16, 4, 4), float32] {
  let %quantized_inputs: (Tensor[(1, 3, 4, 4), int8],) /* ty=(Tensor[(1, 3, 4, 4), int8],) */ = @quantize_inputs(%data) /* ty=(Tensor[(1, 3, 4, 4), int8],) */;
  %1 = %quantized_inputs.0 /* ty=Tensor[(1, 3, 4, 4), int8] */;
  let %quantized_outputs: Tensor[(1, 16, 4, 4), int8] /* ty=Tensor[(1, 16, 4, 4), int8] */ = @quantized_main(%1) /* ty=Tensor[(1, 16, 4, 4), int8] */;
  let %dequantized_outputs: Tensor[(1, 16, 4, 4), float32] /* ty=Tensor[(1, 16, 4, 4), float32] */ = @dequantize_outputs(%quantized_outputs) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %dequantized_outputs
}

def @quantize_inputs(%data1: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> (Tensor[(1, 3, 4, 4), int8],) {
  %2 = multiply(%data1, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %3 = round(%2) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %4 = clip(%3, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  let %data2: Tensor[(1, 3, 4, 4), int8] /* ty=Tensor[(1, 3, 4, 4), int8] */ = cast(%4, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */;
  (%data2,) /* ty=(Tensor[(1, 3, 4, 4), int8],) */
}

def @quantized_main(%data3: Tensor[(1, 3, 4, 4), int8] /* ty=Tensor[(1, 3, 4, 4), int8] */) -> Tensor[(1, 16, 4, 4), int8] {
  %5 = nn.conv2d(%data3, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), int8] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 16, 4, 4), int32] */;
  %6 = add(%5, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), int32] */) /* ty=Tensor[(1, 16, 4, 4), int32] */;
  %7 = add(%6, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), int32] */) /* ty=Tensor[(1, 16, 4, 4), int32] */;
  %8 = nn.relu(%7) /* ty=Tensor[(1, 16, 4, 4), int32] */;
  %9 = add(%8, 256 /* ty=int32 */) /* ty=Tensor[(1, 16, 4, 4), int32] */;
  %10 = right_shift(%9, 9 /* ty=int32 */) /* ty=Tensor[(1, 16, 4, 4), int32] */;
  %11 = clip(%10, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 4, 4), int32] */;
  %12 = cast(%11, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), int8] */;
  annotation.stop_fusion(%12) /* ty=Tensor[(1, 16, 4, 4), int8] */
}
dev = tvm.cpu()
# Fixed-seed generator so the example input, and therefore the result, is
# reproducible between runs.
rng = np.random.default_rng(0)
data_np = rng.uniform(low=-1, high=1, size=[1, 3, 4, 4]).astype("float32")
input_dict = {"data": data_np}

# ``qmod`` already carries the weights as constants (``meta[relay.Constant]``
# in the printed IR), and its ``main`` takes only ``%data``. The float
# ``params`` from ``create_model`` match no parameter of ``main``, so they
# are not passed to the compiler.
with tvm.transform.PassContext(opt_level=3):
    qvm_exec = relay.vm.compile(qmod, target="llvm")
qvm = VirtualMachine(qvm_exec, dev)
# ``main`` accepts the float32 input; quantize/dequantize happen inside it.
qvm.set_input("main", **input_dict)
tvm_qres = qvm.run()
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.