Warm-up Phase

Before quantization, TVM runs a set of prerequisite optimization passes. Inspect the source of prerequisite_optimize:

from tvm.relay.quantize.quantize import prerequisite_optimize
prerequisite_optimize??
Signature: prerequisite_optimize(mod, params=None)
Source:   
def prerequisite_optimize(mod, params=None):
    """Prerequisite optimization passes for quantization. Perform
    "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
    "CanonicalizeOps" optimization before quantization."""
    optimize = tvm.transform.Sequential(
        [
            _transform.SimplifyInference(),
            _transform.FoldConstant(),
            _transform.FoldScaleAxis(),
            _transform.CanonicalizeOps(),
            _transform.FoldConstant(),
        ]
    )

    if params:
        mod["main"] = _bind_params(mod["main"], params)

    mod = optimize(mod)
    return mod
File:      /media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/relay/quantize/quantize.py
Type:      function
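
As a minimal, self-contained sketch (the module and the parameter names below are illustrative, not from the original), prerequisite_optimize can also be applied directly to a hand-built Relay module:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.quantize.quantize import prerequisite_optimize

x = relay.var("x", shape=(1, 3, 4, 4))
w = relay.var("w", shape=(8, 3, 3, 3))
body = relay.nn.relu(relay.nn.conv2d(x, w, padding=(1, 1), channels=8, kernel_size=(3, 3)))
mod = tvm.IRModule.from_expr(relay.Function([x, w], body))
params = {"w": tvm.nd.array(np.ones((8, 3, 3, 3), dtype="float32"))}
# bind w as a constant, then run the Sequential pipeline shown above
mod = prerequisite_optimize(mod, params)
print(mod["main"])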

The following PyTorch model is used to walk through these steps on a real network:

import numpy as np
import tvm
from tvm.runtime.vm import VirtualMachine
from tvm import relay
from torch import nn
import torch

class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(3, 16, 3, 1, 1, bias=True)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

def create_model(ishape = (1, 3, 4, 4)):
    pt_model = Model().eval().float()
    input_shapes = [("data", ishape)]
    # script_module = torch.jit.script(pt_model)
    # mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
    idata = torch.rand(ishape).type(torch.float32)
    traced_model = torch.jit.trace(pt_model, idata)
    # translate traced_model into a Relay module via the TVM PyTorch frontend
    mod, params = relay.frontend.from_pytorch(traced_model, input_shapes, 
                                              use_parser_friendly_name=True)
    return mod, params

mod, params = create_model(ishape=(1, 3, 4, 4))
print(mod["main"])
fn (%data: Tensor[(1, 3, 4, 4), float32] /* span=aten___convolution_0_data:0:0 */, %aten___convolution_0_weight: Tensor[(16, 3, 3, 3), float32] /* span=aten___convolution_0_weight:0:0 */, %aten___convolution_0_bias: Tensor[(16), float32] /* span=aten___convolution_0_bias:0:0 */, %aten__batch_norm_0_weight: Tensor[(16), float32] /* span=aten__batch_norm_0_weight:0:0 */, %aten__batch_norm_0_bias: Tensor[(16), float32] /* span=aten__batch_norm_0_bias:0:0 */, %aten__batch_norm_0_mean: Tensor[(16), float32] /* span=aten__batch_norm_0_mean:0:0 */, %aten__batch_norm_0_var: Tensor[(16), float32] /* span=aten__batch_norm_0_var:0:0 */) {
  %0 = nn.conv2d(%data, %aten___convolution_0_weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* span=aten___convolution_0:0:0 */;
  %1 = nn.bias_add(%0, %aten___convolution_0_bias) /* span=aten___convolution_0:0:0 */;
  %2 = nn.batch_norm(%1, %aten__batch_norm_0_weight, %aten__batch_norm_0_bias, %aten__batch_norm_0_mean, %aten__batch_norm_0_var) /* span=aten__batch_norm_0:0:0 */;
  %3 = %2.0 /* span=aten__batch_norm_0:0:0 */;
  nn.relu(%3) /* span=aten__relu_0:0:0 */
}
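
As a quick sanity check (a sketch added here, not part of the original flow; data and ref_out are names introduced for this check), the imported module can be executed with the Relay VM imported above, keeping the result as a reference for later comparison:

data = np.random.rand(1, 3, 4, 4).astype("float32")
with tvm.transform.PassContext(opt_level=3):
    # compile binds params into the executable; mod itself is left untouched
    vm_exec = relay.vm.compile(mod, target="llvm", params=params)
vm = VirtualMachine(vm_exec, tvm.cpu())
ref_out = vm.run(tvm.nd.array(data)).numpy()
print(ref_out.shape)  # (1, 16, 4, 4)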

Parameter Binding

from tvm.relay.quantize.quantize import _bind_params

print(f"绑定参数前:\n{mod['main']}")
print("="*50)
mod["main"] = _bind_params(mod["main"], params)
print(f"绑定参数后:\n{mod['main']}")
绑定参数前:
fn (%data: Tensor[(1, 3, 4, 4), float32] /* span=aten___convolution_0_data:0:0 */, %aten___convolution_0_weight: Tensor[(16, 3, 3, 3), float32] /* span=aten___convolution_0_weight:0:0 */, %aten___convolution_0_bias: Tensor[(16), float32] /* span=aten___convolution_0_bias:0:0 */, %aten__batch_norm_0_weight: Tensor[(16), float32] /* span=aten__batch_norm_0_weight:0:0 */, %aten__batch_norm_0_bias: Tensor[(16), float32] /* span=aten__batch_norm_0_bias:0:0 */, %aten__batch_norm_0_mean: Tensor[(16), float32] /* span=aten__batch_norm_0_mean:0:0 */, %aten__batch_norm_0_var: Tensor[(16), float32] /* span=aten__batch_norm_0_var:0:0 */) {
  %0 = nn.conv2d(%data, %aten___convolution_0_weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* span=aten___convolution_0:0:0 */;
  %1 = nn.bias_add(%0, %aten___convolution_0_bias) /* span=aten___convolution_0:0:0 */;
  %2 = nn.batch_norm(%1, %aten__batch_norm_0_weight, %aten__batch_norm_0_bias, %aten__batch_norm_0_mean, %aten__batch_norm_0_var) /* span=aten__batch_norm_0:0:0 */;
  %3 = %2.0 /* span=aten__batch_norm_0:0:0 */;
  nn.relu(%3) /* span=aten__relu_0:0:0 */
}
==================================================
After binding params:
fn (%data: Tensor[(1, 3, 4, 4), float32] /* span=aten___convolution_0_data:0:0 */) {
  %0 = nn.conv2d(%data, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* span=aten___convolution_0:0:0 */;
  %1 = nn.bias_add(%0, meta[relay.Constant][1]) /* span=aten___convolution_0:0:0 */;
  %2 = nn.batch_norm(%1, meta[relay.Constant][2], meta[relay.Constant][3], meta[relay.Constant][4], meta[relay.Constant][5]) /* span=aten__batch_norm_0:0:0 */;
  %3 = %2.0 /* span=aten__batch_norm_0:0:0 */;
  nn.relu(%3) /* span=aten__relu_0:0:0 */
}
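
After binding, the only free variable left in main should be the input %data. This can be confirmed with relay.analysis.free_vars (a one-line check, added here as a sketch):

from tvm.relay.analysis import free_vars
print([v.name_hint for v in free_vars(mod["main"])])  # expected: ['data']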

Model Simplification

optimize = tvm.transform.Sequential(
    [
        relay.transform.SimplifyInference(),
        relay.transform.FoldConstant(),
        relay.transform.FoldScaleAxis(),
        relay.transform.CanonicalizeOps(),
        relay.transform.FoldConstant(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = optimize(mod)
print(mod["main"])
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  nn.relu(%2) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */
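
With the parameters already bound, the simplified module compiles without a params argument. Assuming the passes are semantics-preserving (folding batch_norm into neighboring ops only perturbs floating point slightly), its output should match the reference computed earlier. This sketch reuses the data and ref_out names introduced above:

with tvm.transform.PassContext(opt_level=3):
    vm_exec = relay.vm.compile(mod, target="llvm")
vm = VirtualMachine(vm_exec, tvm.cpu())
out = vm.run(tvm.nd.array(data)).numpy()
np.testing.assert_allclose(out, ref_out, rtol=1e-5, atol=1e-5)

Note that parameter binding plus this Sequential pipeline is exactly what prerequisite_optimize(mod, params) performs in a single call.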