# 自定义 TVM 自动量化

采用模板匹配的策略进行量化分区。

```{admonition} 显式定义融合规则的好处
1. 更加精细地控制算子融合的规则。
2. 更好地适配诸如 VTA 等后端。
```

In [1]:
import set_env

In [2]:
import numpy as np
from tvm import relay
import tvm
from tvm_book.tvm_utils.llvm_utils import run_llvm_graph

def load_model(input_shape=(1, 3, 224, 224)):
    """Load the front-end model: a pretrained ResNet-18 traced with TorchScript.

    Args:
        input_shape: Shape of the random example tensor fed to the tracer.
            Any sequence of ints works (it is unpacked into ``torch.randn``).
            The default is a tuple rather than a list so the default object
            cannot be mutated and shared between calls.

    Returns:
        The object returned by ``torch.jit.trace`` for the model in eval mode.
    """
    # Imports are function-local, so torch/torchvision are only imported when
    # this function actually runs.
    import torch
    from torchvision.models import resnet18
    from torchvision.models.resnet import ResNet18_Weights

    model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
    example = torch.randn(*input_shape)
    return torch.jit.trace(model, example)

input_name = "data"
size = (224, 224)
input_shape = (1, 3) + size
traced_model = load_model(input_shape).eval()
# Translate the front-end (TorchScript) model into a Relay module plus its parameters.
origin_mod, origin_params = relay.frontend.from_pytorch(
    traced_model,
    [(input_name, input_shape)],
)

先以子图 `mod` 为例，研究自定义量化的过程：

In [3]:
# Take a sub-expression of the full model as the working module.
# NOTE(review): `extract_intermdeiate_expr` looks misspelled but appears to be the
# name exposed by TVM itself — confirm against the installed TVM before "fixing" it.
# NOTE(review): 12 is presumably the ID of the intermediate expression to cut at — confirm.
mod = relay.analysis.extract_intermdeiate_expr(origin_mod, 12)

此时的 `mod` 存在 `nn.batch_norm` 算子以及常量表达式：

In [4]:
# Display the extracted sub-module (presumably prints its Relay IR — confirm).
mod.show()

运行如下代码便可将 `nn.batch_norm` 融合掉，同时用模型参数替换掉相应的常量表达式（还有一些其他操作，此处不再展开）：

In [5]:
# Pre-quantization preparation, run with both the pass context and the
# quantization config active.
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    run_mod = relay.quantize.prerequisite_optimize(mod, origin_params)

In [6]:
# Display the module produced by the pre-quantization preparation step.
run_mod.show()

## 定义融合规则

想要融合 `conv2d+add+relu` 结构，可以定义融合函数：

In [7]:
from tvm.relay.dataflow_pattern import is_op, wildcard

def make_conv_add_relu_pattern():
    """Build a dataflow pattern for conv2d followed by optional add and relu.

    Shape of the pattern (parenthesised stages are optional)::

         conv2d
            |
          (add)
            |
          (relu)

    Returns:
        The composed pattern, rooted at its last stage.
    """
    data, weight, bias = wildcard(), wildcard(), wildcard()
    conv = is_op("nn.conv2d")(data, weight)
    # The bias add is optional: fall back to the bare convolution.
    conv_maybe_bias = is_op("add")(conv, bias) | conv
    # The activation is optional as well.
    return is_op("nn.relu")(conv_maybe_bias) | conv_maybe_bias

上述结构模式可以用来匹配 `conv2d`、`conv2d+add`、`conv2d+add+relu`、`conv2d+relu` 四种模式。

执行融合：

In [8]:
compiler_name = "ccompiler"
# (composite name, pattern) pairs handed to MergeComposite.
pattern_table = [(f"{compiler_name}.conv_add_relu", make_conv_add_relu_pattern())]
# NOTE: AnnotateTarget([compiler_name]) and PartitionGraph() are left out of
# this pipeline; only type inference and composite merging run here.
merge_passes = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.MergeComposite(pattern_table),
    ]
)
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    run_mod_f = merge_passes(run_mod)
print(run_mod_f)

def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {
  %5 = fn (%FunctionVar_2_0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %FunctionVar_2_1: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] */, %FunctionVar_2_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_nn.relu_", Composite="ccompiler.conv_add_relu") -> Tensor[(1, 64, 112, 112), float32] {
    %3 = nn.conv2d(%FunctionVar_2_0, %FunctionVar_2_1, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
    %4 = add(%3, %FunctionVar_2_2) /* ty=Tensor[(1, 64, 112, 112), float32] */;
    nn.relu(%4) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */
  } /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(64, 3, 7, 7), float32], Tensor[(64

可以看出，剩余的 `nn.max_pool2d` 和残差 `add` 尚未被融合，因此可以继续添加规则：

In [9]:
def make_max_pool2d_pattern():
    """Return a pattern matching ``nn.max_pool2d`` applied to any input."""
    return is_op("nn.max_pool2d")(wildcard())

def make_add_pattern():
    """Return a pattern matching an ``add`` of two arbitrary operands."""
    lhs = wildcard()
    rhs = wildcard()
    # Spelled out explicitly instead of relying on the pattern `+` overload.
    return is_op("add")(lhs, rhs)

compiler_name = "ccompiler"
# Patterns are listed in the order in which matching is attempted.
pattern_table = [
    (f"{compiler_name}.{suffix}", pattern)
    for suffix, pattern in (
        ("conv_add_relu", make_conv_add_relu_pattern()),
        ("max_pool2d", make_max_pool2d_pattern()),
        ("add", make_add_pattern()),
    )
]
# NOTE: AnnotateTarget([compiler_name]) is left out of this pipeline.
merge_passes = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.MergeComposite(pattern_table),
        relay.transform.PartitionGraph(),
    ]
)
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    run_mod_f = merge_passes(run_mod)
print(run_mod_f)

def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {
  %5 = fn (%FunctionVar_2_0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %FunctionVar_2_1: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] */, %FunctionVar_2_2: Tensor[(64, 1, 1), float32] /* ty=Tensor[(64, 1, 1), float32] */, PartitionedFromPattern="nn.conv2d_add_nn.relu_", Composite="ccompiler.conv_add_relu") -> Tensor[(1, 64, 112, 112), float32] {
    %3 = nn.conv2d(%FunctionVar_2_0, %FunctionVar_2_1, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
    %4 = add(%3, %FunctionVar_2_2) /* ty=Tensor[(1, 64, 112, 112), float32] */;
    nn.relu(%4) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */
  } /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(64, 3, 7, 7), float32], Tensor[(64

符合期望结构。

## 为融合函数添加 `QPartitionExpr` 算子

In [10]:
from tvm.relay import Call
from tvm.relay.function import Function, FunctionWithFields
from tvm.relay.quantize._partition import QPartitionExpr

@tvm.relay.transform.function_pass(opt_level=1)
class MergeGraphTransform:
    """Function pass that wraps function bodies in a realized ``QPartitionExpr``.

    Each function reached by the mutator (the function being transformed and
    any function expression nested inside it) has its body replaced by
    ``QPartitionExpr(body).realize()`` when that body is a call whose callee
    is not itself a ``Function``. All other bodies are left as they are.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the list of functions recorded by earlier runs."""
        # Functions returned by ``visit_function``, in the order they were produced.
        self.nodes = []

    def transform_function(self, func, mod, ctx):
        """Rewrite ``func`` and return the result.

        ``mod`` and ``ctx`` are part of the pass signature but are not used.
        """
        recorder = self

        class Replace(tvm.relay.ExprMutator):
            def visit_function(self, fn):
                new_params = [self.visit(x) for x in fn.params]
                new_body = self.visit(fn.body)
                # Check the node kind before touching ``.op``: a body that is
                # not a call (e.g. a variable or tuple) has no ``op`` attribute
                # and is left unwrapped instead of raising AttributeError.
                # A body that calls a Function is skipped so the partition
                # marker is not added again around an already-wrapped callee.
                if isinstance(new_body, Call) and not isinstance(new_body.op, Function):
                    new_body = QPartitionExpr(new_body).realize()
                if new_params == list(fn.params) and new_body == fn.body:
                    # Nothing changed: reuse the original node.
                    new_fn = fn
                else:
                    new_fn = FunctionWithFields(fn, list(new_params), new_body)
                recorder.nodes.append(new_fn)
                return new_fn

        return Replace().visit(func)

In [11]:
transform = MergeGraphTransform()
# Apply the partition-marking pass to the merged module.
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    mod_sq = transform(run_mod_f)
mod_sq.show()

## 消除计算图中的函数表达式

由于 {class}`tvm.contrib.graph_executor.GraphModule` 不支持对 {class}`tvm.relay.function.Function` 进行推理，需要将其分解为原语函数，以支持后续的校准过程：

In [12]:
# Run DefuseOps over the marked module so no function expressions remain
# in the graph before calibration.
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    defuse_pass = relay.transform.DefuseOps()
    run_mod_sq = defuse_pass(mod_sq)
run_mod_sq.show()

## 注解计算图

In [13]:
# Annotate the defused graph using the active quantization config.
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    annotate_pass = relay.quantize.annotate()
    annotate_mod = annotate_pass(run_mod_sq)
annotate_mod.show()

## 模拟量化

In [14]:
from tvm.relay.quantize import calibrate

# 定义校准数据集
def data_iter(input_name, input_shape, num=1):
    """Generate ``num`` random calibration batches.

    Args:
        input_name: Key under which each sample is stored in the batch dict.
        input_shape: Shape passed as ``size`` to ``np.random.normal``.
        num: Number of batches to produce.

    Yields:
        A one-entry dict ``{input_name: array}`` per batch, where ``array``
        is drawn from ``np.random.normal`` with the requested shape.
    """
    yield from (
        {input_name: np.random.normal(size=input_shape)}
        for _ in range(num)
    )

# `data_iter` is a generator function, so `dataset` can be iterated only once.
dataset = data_iter(input_name, input_shape)

# Wrap the calibration callback as a module-level pass.
calibrate_pass = tvm.transform.module_pass(
    calibrate(dataset),
    opt_level=1,
    name="QuantizeCalibrate",
)
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    calibrate_mod = calibrate_pass(annotate_mod)
calibrate_mod.show()

## 量化实现

In [15]:
# Realize the calibrated (simulated-quantization) graph.
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    realize_pass = relay.quantize.realize()
    run_mod_r = realize_pass(calibrate_mod)
run_mod_r.show()

折叠常量：

In [16]:
# Clean up the realized module: constant folding first, then SimplifyInference.
with tvm.transform.PassContext(opt_level=3):
    for _cleanup in (relay.transform.FoldConstant(), relay.transform.SimplifyInference()):
        run_mod_r = _cleanup(run_mod_r)
run_mod_r.show()