Deconstructing TVM Quantization

import logging
import set_env  # local helper module from the tvm-book repo (presumably sets up sys.path)
from d2py.utils.log_config import config_logging
from d2py.utils.file import mkdir
# Configure logging
temp_dir = ".temp"
logger_name = "parse"
mkdir(temp_dir)
config_logging(
    f"{temp_dir}/{logger_name}.log", logger_name, 
    filemode="w", filter_mod_names={"te_compiler"}
)
logger = logging.getLogger(logger_name)

Load the modules:

import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform as _transform
from tvm.relay import expr as _expr
from tvm.relay import Call, Constant, Function
from tvm.ir.op import Op
from tvm.relay import op as _op
from tvm_book.tvm_utils.llvm_utils import run_llvm_graph

Load and trace the frontend model (a pretrained resnet18 from torchvision):

def load_model(input_shape=(1, 3, 224, 224)):
    """Load the frontend (PyTorch) model and trace it with TorchScript"""
    import torch
    from torchvision.models import resnet18
    from torchvision.models.resnet import ResNet18_Weights
    model = resnet18(weights=ResNet18_Weights.DEFAULT)
    data = torch.randn(*input_shape)
    return torch.jit.trace(model.eval(), data)

size = 224, 224
input_shape = (1, 3, *size)
input_name = "data"
traced_model = load_model(input_shape).eval()
# Translate the frontend model into a Relay module
origin_mod, params = relay.frontend.from_pytorch(traced_model, [(input_name, input_shape)])
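
Before taking the model apart, it is worth a quick sanity check that the conversion preserves behavior. The following is a minimal parity sketch, not part of the original flow: it assumes an LLVM-enabled TVM build and uses loose float tolerances (the run_llvm_graph helper imported above presumably wraps a similar flow):

import torch
from tvm.contrib import graph_executor

data = np.random.randn(*input_shape).astype("float32")
with torch.no_grad():
    torch_out = traced_model(torch.from_numpy(data)).numpy()
# Build the Relay module for CPU and run it on the same input
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(origin_mod, target="llvm", params=params)
runtime = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
runtime.set_input(input_name, data)
runtime.run()
tvm_out = runtime.get_output(0).numpy()
np.testing.assert_allclose(torch_out, tvm_out, rtol=1e-4, atol=1e-4)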

First, deconstruct resnet18's first compute block (conv2d → batch_norm → relu). `extract_intermdeiate_expr` keeps everything up to the given expression ID; the misspelling is TVM's own API name:

mod = relay.analysis.extract_intermdeiate_expr(origin_mod, 3)
mod = _transform.InferType()(mod)

Define a helper that binds params into the function as constants:

def _bind_params(func, params):
    """Bind params to func as constants"""
    name_dict = {}
    for arg in func.params:
        name = arg.name_hint
        if name in name_dict:
            name_dict[name] = None  # duplicated name: mark as ambiguous
        else:
            name_dict[name] = arg
    bind_dict = {}
    for k, v in params.items():
        if k not in name_dict:
            continue
        arg = name_dict[k]
        if arg is None:
            raise ValueError(f"Multiple args in the function have name {k}")
        bind_dict[arg] = _expr.const(v)
    return _expr.bind(func, bind_dict)

print('Original model:')
mod.show()
# Bind params into mod
if params:
    mod["main"] = _bind_params(mod["main"], params)
print('Original model (params bound):')
mod.show()
# Simplify and fold constants
optimize = tvm.transform.Sequential([
    _transform.SimplifyInference(),  # decompose batch_norm etc. into simple ops
    _transform.FoldConstant(),
    _transform.FoldScaleAxis(),      # fold per-channel scales into conv weights
    _transform.CanonicalizeOps(),
    _transform.FoldConstant(),
])
with tvm.transform.PassContext(opt_level=3):
    run_mod = optimize(mod)
print('Original model (simplified):')
run_mod.show()
Original model:
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */, %aten::_convolution_0.weight: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] span=aten::_convolution_0.weight:0:0 */, %aten::batch_norm_0.weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.weight:0:0 */, %aten::batch_norm_0.bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.bias:0:0 */, %aten::batch_norm_0.running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_mean:0:0 */, %aten::batch_norm_0.running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=aten::batch_norm_0.running_var:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
  %0 = nn.conv2d(%data, %aten::_convolution_0.weight, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::_convolution_0:0:0 */;
  %1 = nn.batch_norm(%0, %aten::batch_norm_0.weight, %aten::batch_norm_0.bias, %aten::batch_norm_0.running_mean, %aten::batch_norm_0.running_var) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) span=aten::batch_norm_0:0:0 */;
  %2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::batch_norm_0:0:0 */;
  nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */
}

Original model (params bound):
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0], strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::_convolution_0:0:0 */;
  %1 = nn.batch_norm(%0, meta[relay.Constant][1], meta[relay.Constant][2], meta[relay.Constant][3], meta[relay.Constant][4]) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) span=aten::batch_norm_0:0:0 */;
  %2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::batch_norm_0:0:0 */;
  nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */
}

Original model (simplified):
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 112, 112), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  nn.relu(%1) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */
}

Compare the convolution weight before and after simplification. In the dumps above, SimplifyInference decomposes nn.batch_norm into a per-channel multiply and add, FoldScaleAxis folds the multiply into the conv2d weight, and FoldConstant evaluates the resulting constant expressions, leaving just conv2d → add → relu. An ExprMutator can fish out the weight of each nn.conv2d for comparison:

class _Transform(tvm.relay.ExprMutator):
    """Record the weight argument of every nn.conv2d call"""
    def __init__(self):
        super().__init__()
        self.binds = {}
        self.func_id = 0

    def visit_call(self, call):
        new_fn = self.visit(call.op)
        new_args = [self.visit(arg) for arg in call.args]
        call = Call(new_fn, new_args, call.attrs, call.type_args, call.span)
        if isinstance(new_fn, Op):
            if new_fn.name == "nn.conv2d":
                # stash the weight (second argument) under a unique key
                self.binds[f"{new_fn.name}_{self.func_id}"] = new_args[1]
                self.func_id += 1
        return call

transform = _Transform()
transform.visit(mod["main"])
weight_ori = transform.binds['nn.conv2d_0']  # weight before simplification
transform = _Transform()
transform.visit(run_mod["main"])
weight = transform.binds['nn.conv2d_0']      # weight after simplification
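
As an aside, since this pass only records weights and never rewrites the graph, a read-only visitor works just as well. A minimal sketch (the class name _ConvWeights is made up for illustration; tvm.relay.ExprVisitor is the read-only counterpart of ExprMutator):

class _ConvWeights(tvm.relay.ExprVisitor):
    """Collect the weight argument of every nn.conv2d call"""
    def __init__(self):
        super().__init__()
        self.weights = []

    def visit_call(self, call):
        if isinstance(call.op, Op) and call.op.name == "nn.conv2d":
            self.weights.append(call.args[1])
        super().visit_call(call)  # keep traversing into the arguments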
weight_ori.data.numpy()[0, 0, :5, :5]
array([[-0.01041935, -0.00613561, -0.00180978,  0.07484142,  0.05661485],
       [ 0.01108271,  0.00952757, -0.10992692, -0.28050068, -0.27123755],
       [-0.00694335,  0.05908897,  0.29548222,  0.587196  ,  0.5197189 ],
       [ 0.03050456, -0.06701802, -0.29841137, -0.4386757 , -0.27085286],
       [-0.02753477,  0.01604508,  0.07259498, -0.05410165, -0.33284944]],
      dtype=float32)
weight.data.numpy()[0, 0, :5, :5]
array([[-0.00242674, -0.00142902, -0.00042151,  0.01743106,  0.01318597],
       [ 0.00258124,  0.00221903, -0.0256027 , -0.06533046, -0.06317302],
       [-0.00161715,  0.01376221,  0.06881975,  0.13676181,  0.12104595],
       [ 0.00710471, -0.01560894, -0.06950197, -0.10217047, -0.06308342],
       [-0.00641303,  0.00373701,  0.01690785, -0.01260063, -0.07752283]],
      dtype=float32)
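
The two slices differ by a single per-channel factor (≈ 0.233 for output channel 0): SimplifyInference rewrites nn.batch_norm as a multiply/add with scale gamma / sqrt(running_var + eps), and FoldScaleAxis folds that scale into the conv2d weight. A quick numeric check, as a sketch that assumes PyTorch's default eps=1e-5 (the key names come from the parameter dump above):

gamma = params["aten::batch_norm_0.weight"].numpy()
var = params["aten::batch_norm_0.running_var"].numpy()
scale = gamma / np.sqrt(var + 1e-5)  # per-output-channel BN scale; eps=1e-5 is an assumption
np.testing.assert_allclose(
    weight.data.numpy(),
    weight_ori.data.numpy() * scale[:, None, None, None],
    rtol=1e-5, atol=1e-6,
)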