量化注解

量化注解

from torch import nn
import torch

class Model(nn.Module):
    """Small two-convolution network used as the quantization-annotation example.

    Data flow, where ``h`` is the raw (pre-activation) output of ``conv``::

        out = relu(h) + relu(conv2(h))

    ``conv2`` deliberately reads ``h`` rather than ``relu(h)``, so ``h`` has
    two consumers. Both convolutions use a 3x3 kernel with stride 1 and
    padding 1, so the spatial size of the input is preserved.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Construction order matters: each layer draws its random initial
        # weights from the global RNG in the order it is created.
        self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True)
        # A single parameter-free ReLU module, reused by both branches.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv(x)) + relu(conv2(conv(x)))``, evaluating ``conv(x)`` once."""
        h = self.conv(x)
        left = self.relu(h)
        right = self.relu(self.conv2(h))
        return left + right
# NOTE(review): set_env is presumably a local helper that configures the
# environment (e.g. sys.path for a TVM build) — confirm.
import set_env
import numpy as np
import tvm
from tvm import relay

# Input data
# A random float32 example input; it is only used to drive torch.jit.trace.
input_shape = (1, 3, 4, 4)
input_dtype = "float32"
data_np = np.random.rand(*input_shape).astype(input_dtype)
# Build the eval-mode float model and trace it to TorchScript (no autograd
# recording needed) so the Relay PyTorch frontend can import it.
with torch.no_grad():
    pt_model = Model().eval().float()
    traced_model = torch.jit.trace(pt_model, torch.from_numpy(data_np)).eval()
# Import the traced graph into Relay; the graph input is named "data".
# use_parser_friendly_name=True asks the frontend to produce names that the
# Relay text parser can read back (see the from_pytorch documentation).
mod, params = relay.frontend.from_pytorch(traced_model, [("data", input_shape)], 
                                          use_parser_friendly_name=True)
# Run the quantizer's prerequisite optimizations on the float graph, with
# `params` supplied alongside the module.
# NOTE(review): opt_level=3 presumably keeps level-3 passes used inside
# prerequisite_optimize from being skipped — confirm against the TVM docs.
with tvm.transform.PassContext(opt_level=3):
    mod = relay.quantize.prerequisite_optimize(mod, params)
# Show the optimized float graph (output reproduced below).
print(mod['main'])
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %4 = nn.relu(%2) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  add(%3, %4) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */
# Apply only the quantize `partition` pass. The output below shows the
# inserted annotation.cast_hint / annotation.stop_fusion markers.
# The result is not assigned, so `mod` itself is unchanged. A bare expression
# is only displayed in a notebook; in a plain script its value is discarded.
relay.quantize.partition()(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.cast_hint(%0, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = annotation.stop_fusion(%3) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %5 = nn.conv2d(%4, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %6 = add(%5, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %7 = nn.relu(%6) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %9 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %10 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %11 = add(%9, %10) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */;
  %12 = annotation.cast_hint(%11, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  annotation.stop_fusion(%12) /* ty=Tensor[(1, 16, 4, 4), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */
# Chain `partition` followed by `annotate`. In the output below, annotate adds
# relay.op.annotation.simulated_quantize calls, and their
# dom_scale/clip_min/clip_max inputs appear as extra float32 parameters of main.
passes = tvm.transform.Sequential([
    relay.quantize.partition(),
    relay.quantize.annotate()
])
# Applied to the prerequisite-optimized `mod`. The result is displayed, not
# stored; `mod` is not rebound here.
passes(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */, %dom_scale: float32 /* ty=float32 */, %clip_min: float32 /* ty=float32 */, %clip_max: float32 /* ty=float32 */, %dom_scale1: float32 /* ty=float32 */, %clip_min1: float32 /* ty=float32 */, %clip_max1: float32 /* ty=float32 */, %dom_scale2: float32 /* ty=float32 */, %clip_min2: float32 /* ty=float32 */, %clip_max2: float32 /* ty=float32 */, %dom_scale3: float32 /* ty=float32 */, %clip_min3: float32 /* ty=float32 */, %clip_max3: float32 /* ty=float32 */, %dom_scale4: float32 /* ty=float32 */, %clip_min4: float32 /* ty=float32 */, %clip_max4: float32 /* ty=float32 */, %dom_scale5: float32 /* ty=float32 */, %clip_min5: float32 /* ty=float32 */, %clip_max5: float32 /* ty=float32 */, %dom_scale6: float32 /* ty=float32 */, %clip_min6: float32 /* ty=float32 */, %clip_max6: float32 /* ty=float32 */) -> Tensor[(1, 16, 4, 4), float32] {
  %0 = relay.op.annotation.simulated_quantize(%data, %dom_scale, %clip_min, %clip_max, kind=1) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %1 = relay.op.annotation.simulated_quantize(meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, %dom_scale1, %clip_min1, %clip_max1, kind=2) /* ty=Tensor[(16, 3, 3, 3), float32] */;
  %2 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %3 = relay.op.annotation.simulated_quantize(%2, %dom_scale2, %clip_min2, %clip_max2, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %5 = annotation.cast_hint(%4, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = annotation.cast_hint(%3, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %7 = annotation.stop_fusion(%6) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %8 = relay.op.annotation.simulated_quantize(meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, %dom_scale3, %clip_min3, %clip_max3, kind=2) /* ty=Tensor[(16, 16, 3, 3), float32] */;
  %9 = nn.conv2d(%7, %8, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %10 = relay.op.annotation.simulated_quantize(meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */, %dom_scale4, %clip_min4, %clip_max4, kind=2) /* ty=Tensor[(16, 1, 1), float32] */;
  %11 = add(%9, %10) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %12 = nn.relu(%11) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %13 = relay.op.annotation.simulated_quantize(%12, %dom_scale5, %clip_min5, %clip_max5, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %14 = annotation.cast_hint(%13, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %15 = annotation.stop_fusion(%5) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %16 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %17 = add(%15, %16) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */;
  %18 = relay.op.annotation.simulated_quantize(%17, %dom_scale6, %clip_min6, %clip_max6, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %19 = annotation.cast_hint(%18, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  annotation.stop_fusion(%19) /* ty=Tensor[(1, 16, 4, 4), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32], float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32) -> Tensor[(1, 16, 4, 4), float32] */