Custom VTA Graph Pack#

from copy import deepcopy
import tvm
from tvm import relay
from vta_utils.pack_tool import graph_pack, WithVTAFunctionTransform

VTA Model Example#

from torch import nn
import torch

class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(3, 36, 3, 1, 1, bias=True)
        self.bn = nn.BatchNorm2d(36)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

pt_model = Model().eval().float()
ishape = (1, 3, 4, 4)
input_name = "data"
input_shapes = [(input_name, ishape)]
# script_module = torch.jit.script(pt_model)
# mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
idata = torch.randn(ishape).type(torch.float32)
traced_model = torch.jit.trace(pt_model, idata)
# Convert traced_model into a TVM Relay frontend module
mod, params = relay.frontend.from_pytorch(traced_model, input_shapes)
# Quantization
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(skip_conv_layers=[], weight_scale="max",):
        mod = relay.quantize.quantize(mod, params)
mod.show()
def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 36, 4, 4), float32] {
  %0 = multiply(%data, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(36, 3, 3, 3), int8] */, padding=[1, 1, 1, 1], channels=36, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 36, 4, 4), int32] */;
  %5 = add(%4, meta[relay.Constant][1] /* ty=Tensor[(36, 1, 1), int32] */) /* ty=Tensor[(1, 36, 4, 4), int32] */;
  %6 = fixed_point_multiply(%5, multiplier=0, shift=0) /* ty=Tensor[(1, 36, 4, 4), int32] */;
  %7 = cast(%6, dtype="int32") /* ty=Tensor[(1, 36, 4, 4), int32] */;
  %8 = add(%7, meta[relay.Constant][2] /* ty=Tensor[(36, 1, 1), int32] */) /* ty=Tensor[(1, 36, 4, 4), int32] */;
  %9 = nn.relu(%8) /* ty=Tensor[(1, 36, 4, 4), int32] */;
  %10 = cast(%9, dtype="int64") /* ty=Tensor[(1, 36, 4, 4), int64] */;
  %11 = fixed_point_multiply(%10, multiplier=0, shift=0) /* ty=Tensor[(1, 36, 4, 4), int64] */;
  %12 = clip(%11, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 36, 4, 4), int64] */;
  %13 = cast(%12, dtype="int32") /* ty=Tensor[(1, 36, 4, 4), int32] */;
  %14 = cast(%13, dtype="int8") /* ty=Tensor[(1, 36, 4, 4), int8] */;
  %15 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 36, 4, 4), int8] */;
  %16 = cast(%15, dtype="float32") /* ty=Tensor[(1, 36, 4, 4), float32] */;
  multiply(%16, 0.0625f /* ty=float32 */) /* ty=Tensor[(1, 36, 4, 4), float32] */
}
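As a quick sanity check (my own sketch, not part of the original flow), the quantized module can be built for a plain llvm target and compared against the float PyTorch model; the names below (dev, rt, quant_out, ref_out) are illustrative. A noticeable quantization error is expected since the model is untrained.

import numpy as np
from tvm.contrib import graph_executor

with tvm.transform.PassContext(opt_level=3):
    # the quantized constants are already embedded in mod, so no params are needed
    lib = relay.build(mod, target="llvm")
dev = tvm.cpu()
rt = graph_executor.GraphModule(lib["default"](dev))
rt.set_input(input_name, idata.numpy())
rt.run()
quant_out = rt.get_output(0).numpy()
ref_out = pt_model(idata).detach().numpy()
print("max abs diff:", np.abs(quant_out - ref_out).max())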

VTA Graph Pack#

from tvm.relay.function import Function
from tvm.relay.testing import run_opt_pass
import vta

env = vta.get_env()
bfactor = env.BATCH
cfactor = env.BLOCK_OUT
weight_bits = env.WGT_WIDTH

run_mod = deepcopy(mod)
new_fn = graph_pack(
    run_mod["main"], 
    bfactor, cfactor, weight_bits
)
tvm.IRModule.from_expr(new_fn).show()
def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 3, 4, 4, 1, 16), float32] {
  %0 = multiply(%data, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */;
  %4 = nn.pad(%3, 0 /* ty=int32 */, pad_width=[[0, 0], [0, 13], [0, 0], [0, 0]]) /* ty=Tensor[(1, 16, 4, 4), int8] */;
  %5 = reshape(%4, newshape=[1, 1, 1, 16, 4, 4]) /* ty=Tensor[(1, 1, 1, 16, 4, 4), int8] */;
  %6 = nn.pad(meta[relay.Constant][0] /* ty=Tensor[(36, 3, 3, 3), int8] */, 0 /* ty=int32 */, pad_width=[[0, 12], [0, 13], [0, 0], [0, 0]]) /* ty=Tensor[(48, 16, 3, 3), int8] */;
  %7 = reshape(%6, newshape=[3, 16, 1, 16, 3, 3]) /* ty=Tensor[(3, 16, 1, 16, 3, 3), int8] */;
  %8 = transpose(%5, axes=[0, 2, 4, 5, 1, 3]) /* ty=Tensor[(1, 1, 4, 4, 1, 16), int8] */;
  %9 = transpose(%7, axes=[0, 2, 4, 5, 1, 3]) /* ty=Tensor[(3, 1, 3, 3, 16, 16), int8] */;
  %10 = nn.pad(meta[relay.Constant][1] /* ty=Tensor[(36, 1, 1), int32] */, 0 /* ty=int32 */, pad_width=[[0, 12], [0, 0], [0, 0]]) /* ty=Tensor[(48, 1, 1), int32] */;
  %11 = reshape(%10, newshape=[3, 16, 1, 1, 1]) /* ty=Tensor[(3, 16, 1, 1, 1), int32] */;
  %12 = transpose(%11, axes=[0, 2, 3, 4, 1]) /* ty=Tensor[(3, 1, 1, 1, 16), int32] */;
  %13 = nn.conv2d(%8, %9, padding=[1, 1, 1, 1], channels=48, kernel_size=[3, 3], data_layout="NCHW1n16c", kernel_layout="OIHW16o16i", out_dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %14 = broadcast_to(%12, shape=[3, 1, 1, 1, 16]) /* ty=Tensor[(3, 1, 1, 1, 16), int32] */;
  %15 = add(%13, %14) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %16 = fixed_point_multiply(%15, multiplier=0, shift=0) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %17 = nn.pad(meta[relay.Constant][2] /* ty=Tensor[(36, 1, 1), int32] */, 0 /* ty=int32 */, pad_width=[[0, 12], [0, 0], [0, 0]]) /* ty=Tensor[(48, 1, 1), int32] */;
  %18 = reshape(%17, newshape=[3, 16, 1, 1, 1]) /* ty=Tensor[(3, 16, 1, 1, 1), int32] */;
  %19 = transpose(%18, axes=[0, 2, 3, 4, 1]) /* ty=Tensor[(3, 1, 1, 1, 16), int32] */;
  %20 = cast(%16, dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %21 = broadcast_to(%19, shape=[3, 1, 1, 1, 16]) /* ty=Tensor[(3, 1, 1, 1, 16), int32] */;
  %22 = add(%20, %21) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %23 = nn.relu(%22) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %24 = cast(%23, dtype="int64") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
  %25 = fixed_point_multiply(%24, multiplier=0, shift=0) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
  %26 = clip(%25, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
  %27 = cast(%26, dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %28 = cast(%27, dtype="int8") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */;
  %29 = annotation.stop_fusion(%28) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */;
  %30 = cast(%29, dtype="float32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */;
  multiply(%30, 0.0625f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */
}
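The packing that graph_pack inserts can be read directly from the IR above: channels are zero-padded up to a multiple of BLOCK_OUT, then reshaped and transposed from NCHW into the blocked NCHW1n16c layout expected by the packed nn.conv2d. A pure-NumPy illustration of the same output-side transformation (36 channels padded to 48, my own sketch):

import numpy as np

x = np.arange(1 * 36 * 4 * 4).reshape(1, 36, 4, 4)
# pad channels 36 -> 48, the next multiple of BLOCK_OUT = 16
x_pad = np.pad(x, [(0, 0), (0, 48 - 36), (0, 0), (0, 0)])
# (1, 48, 4, 4) -> (1, 1, 3, 16, 4, 4) -> (1, 3, 4, 4, 1, 16), i.e. NCHW1n16c
x_pack = x_pad.reshape(1, 1, 3, 16, 4, 4).transpose(0, 2, 4, 5, 1, 3)
print(x_pack.shape)  # (1, 3, 4, 4, 1, 16)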

Operator Fusion for the VTA Model#

Create the fusion patterns:

from vta_utils.vta_pattern import (
    preprocessing_pattern,
    pad_reshape_transpose_pattern,
    conv_add_activate_pattern,
    output_pattern,
)
pattern_table = [
    ("vta_preprocessing", preprocessing_pattern()),
    ("vta_reshape_transpose", pad_reshape_transpose_pattern()),
    ("vta_conv2d", conv_add_activate_pattern()),
    ("vta_output", output_pattern()),
]
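These patterns come from the helper module vta_utils.vta_pattern. As a reference point, conv_add_activate_pattern has to cover the operator chain recorded in the PartitionedFromPattern attribute of the fused vta_conv2d function shown further below; a minimal sketch of such a pattern with the Relay dataflow-pattern API could look like this (the actual implementation in vta_utils may differ):

from tvm.relay.dataflow_pattern import is_op, wildcard

def conv_add_activate_pattern_sketch():
    # nn.conv2d -> add -> fixed_point_multiply -> cast -> add -> nn.relu
    # -> cast -> fixed_point_multiply -> clip -> cast -> cast -> annotation.stop_fusion
    conv = is_op("nn.conv2d")(wildcard(), wildcard())
    bias = is_op("add")(conv, wildcard())
    scale = is_op("cast")(is_op("fixed_point_multiply")(bias))
    act = is_op("nn.relu")(is_op("add")(scale, wildcard()))
    requant = is_op("fixed_point_multiply")(is_op("cast")(act))
    narrowed = is_op("cast")(is_op("cast")(is_op("clip")(requant)))
    return is_op("annotation.stop_fusion")(narrowed)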

Apply the operator fusion:

import vta

env = vta.get_env()
bfactor = env.BATCH
cfactor = env.BLOCK_OUT
weight_bits = env.WGT_WIDTH

prepare_transform = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.FoldConstant(), # fold constant parameters
    relay.transform.MergeComposite(pattern_table), # fuse operators into composite functions
    WithVTAFunctionTransform(), # attach the ConvAttrs attribute to the fused vta_conv2d functions
    relay.transform.InferType(),
])

run_mod = deepcopy(mod)

with tvm.transform.PassContext(opt_level=3):
    new_fn = graph_pack(run_mod["main"], bfactor, cfactor, weight_bits)
    run_mod = tvm.IRModule.from_expr(new_fn)
    run_mod = prepare_transform(run_mod)
run_mod.show()
def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 3, 4, 4, 1, 16), float32] {
  %0 = @vta_preprocessing__0(%data) /* ty=Tensor[(1, 3, 4, 4), int8] */;
  %1 = @vta_reshape_transpose__1(%0, 0 /* ty=int32 */) /* ty=Tensor[(1, 1, 4, 4, 1, 16), int8] */;
  %2 = @vta_conv2d__2(%1, meta[relay.Constant][0] /* ty=Tensor[(3, 1, 3, 3, 16, 16), int8] */, meta[relay.Constant][1] /* ty=Tensor[(3, 1, 1, 1, 16), int32] */, meta[relay.Constant][2] /* ty=Tensor[(3, 1, 1, 1, 16), int32] */) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */;
  @vta_output__3(%2) /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */
}

def @vta_conv2d__2(%FunctionVar_0_0: Tensor[(1, 1, 4, 4, 1, 16), int8] /* ty=Tensor[(1, 1, 4, 4, 1, 16), int8] */, %FunctionVar_0_1: Tensor[(3, 1, 3, 3, 16, 16), int8] /* ty=Tensor[(3, 1, 3, 3, 16, 16), int8] */, %FunctionVar_0_2: Tensor[(3, 1, 1, 1, 16), int32] /* ty=Tensor[(3, 1, 1, 1, 16), int32] */, %FunctionVar_0_3: Tensor[(3, 1, 1, 1, 16), int32] /* ty=Tensor[(3, 1, 1, 1, 16), int32] */, PartitionedFromPattern="nn.conv2d_add_fixed_point_multiply_cast_add_nn.relu_cast_fixed_point_multiply_clip_cast_cast_annotation.stop_fusion_", Composite="vta_conv2d", ConvAttrs={padding=[1, 1, 1, 1], channels=48, kernel_size=[3, 3], data_layout="NCHW1n16c", kernel_layout="OIHW16o16i", out_dtype="int32"}) -> Tensor[(1, 3, 4, 4, 1, 16), int8] {
  %3 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1], channels=48, kernel_size=[3, 3], data_layout="NCHW1n16c", kernel_layout="OIHW16o16i", out_dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %4 = add(%3, %FunctionVar_0_2) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %5 = fixed_point_multiply(%4, multiplier=0, shift=0) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %6 = cast(%5, dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %7 = add(%6, %FunctionVar_0_3) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %8 = nn.relu(%7) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %9 = cast(%8, dtype="int64") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
  %10 = fixed_point_multiply(%9, multiplier=0, shift=0) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
  %11 = clip(%10, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int64] */;
  %12 = cast(%11, dtype="int32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int32] */;
  %13 = cast(%12, dtype="int8") /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */;
  annotation.stop_fusion(%13) /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */
}

def @vta_output__3(%FunctionVar_0_01: Tensor[(1, 3, 4, 4, 1, 16), int8] /* ty=Tensor[(1, 3, 4, 4, 1, 16), int8] */, PartitionedFromPattern="cast_multiply_", Composite="vta_output") -> Tensor[(1, 3, 4, 4, 1, 16), float32] {
  %14 = cast(%FunctionVar_0_01, dtype="float32") /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */;
  multiply(%14, 0.0625f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4, 1, 16), float32] */
}

def @vta_preprocessing__0(%FunctionVar_0_02: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] */, PartitionedFromPattern="multiply_round_clip_cast_", Composite="vta_preprocessing") -> Tensor[(1, 3, 4, 4), int8] {
  %15 = multiply(%FunctionVar_0_02, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %16 = round(%15) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %17 = clip(%16, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  cast(%17, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */
}

def @vta_reshape_transpose__1(%FunctionVar_0_03: Tensor[(1, 3, 4, 4), int8] /* ty=Tensor[(1, 3, 4, 4), int8] */, %FunctionVar_0_11: int32 /* ty=int32 */, PartitionedFromPattern="nn.pad_reshape_transpose_", Composite="vta_reshape_transpose") -> Tensor[(1, 1, 4, 4, 1, 16), int8] {
  %18 = nn.pad(%FunctionVar_0_03, %FunctionVar_0_11, pad_width=[[0, 0], [0, 13], [0, 0], [0, 0]]) /* ty=Tensor[(1, 16, 4, 4), int8] */;
  %19 = reshape(%18, newshape=[1, 1, 1, 16, 4, 4]) /* ty=Tensor[(1, 1, 1, 16, 4, 4), int8] */;
  transpose(%19, axes=[0, 2, 4, 5, 1, 3]) /* ty=Tensor[(1, 1, 4, 4, 1, 16), int8] */
}
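After the transform, each fused vta_conv2d function carries the original convolution attributes in its ConvAttrs attribute, as seen in the printout above. A small hedged sketch for reading them back from the module (useful when later lowering the composite functions onto VTA):

for gvar, func in run_mod.functions.items():
    attrs = func.attrs
    if attrs is not None and "Composite" in attrs and attrs["Composite"] == "vta_conv2d":
        print(gvar.name_hint, attrs["ConvAttrs"])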