模型翻译#

import set_env

定义简单模型:

import torch

class Conv2D1(torch.nn.Module):
    """Minimal demo model: one 2-D convolution mapping 3 channels to 6 with a 7x7 kernel."""

    def __init__(self):
        super().__init__()
        # Conv2d(in_channels=3, out_channels=6, kernel_size=7, bias=True)
        self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

    def forward(self, data):
        """Run the convolution over *data* and return its output tensor."""
        out = self.conv(data)
        return out

MSCGraph 与 PyTorch 模型互转#

import numpy as np

import torch
from torch.nn import Module

import tvm.testing
from tvm.contrib.msc.framework.torch.frontend import translate
from tvm.contrib.msc.framework.torch import codegen
# Instantiate the demo model and declare its single input:
# shape [1, 3, 10, 10] with dtype float32 (format: list of (shape, dtype) pairs).
torch_model = Conv2D1()
input_info = [([1, 3, 10, 10], "float32")]

torch 模型转换为 MSCGraph:

# Translate the PyTorch model into an MSCGraph plus its weights.
# NOTE(review): via_relax=False presumably selects the non-relax (jit/relay)
# tracing path — confirm against the MSC frontend documentation.
graph, weights = translate.from_torch(torch_model, input_info, via_relax=False)

MSCGraph 再转换回 torch 模型:

# Regenerate a torch module from the MSCGraph and its weights.
model = codegen.to_torch(graph, weights)

验证一致性:

# Draw random inputs matching input_info and run both models without autograd.
torch_datas = [torch.from_numpy(np.random.rand(*i[0]).astype(i[1])) for i in input_info]
with torch.no_grad():
    golden = torch_model(*torch_datas)
with torch.no_grad():
    # A graph that recorded no inputs is invoked without arguments.
    if not graph.get_inputs():
        result = model()
    else:
        result = model(*torch_datas)
# Normalize both outputs to lists so single tensors and tuples compare uniformly.
if not isinstance(golden, (list, tuple)):
    golden = [golden]
if not isinstance(result, (list, tuple)):
    result = [result]
assert len(golden) == len(result), f"golden {len(golden)} mismatch with result {len(result)}"
for gol_r, new_r in zip(golden, result):
    if isinstance(gol_r, torch.Tensor):
        # Tensor outputs must agree to 1e-5 absolute/relative tolerance.
        tvm.testing.assert_allclose(
            gol_r.detach().numpy(), new_r.detach().numpy(), atol=1e-5, rtol=1e-5
        )
    else:
        # Non-tensor outputs must match exactly.
        assert gol_r == new_r

转换为 relay#

def _valid_target(target):
    if not target:
        return target
    if target == "ignore":
        return None
    if target == "cuda" and not tvm.cuda().exist:
        return None
    if isinstance(target, str):
        target = tvm.target.Target(target)
    return target
def _run_relax(relax_mod, target, datas):
    """Legalize, build and execute *relax_mod* for *target* on CPU.

    Returns the outputs as a list of numpy arrays, whether the VM yields a
    single NDArray or a container of them.
    """
    legalized = tvm.relax.transform.LegalizeOps()(relax_mod)
    with tvm.transform.PassContext(opt_level=3):
        executable = tvm.relax.build(legalized, target)
        vm = tvm.relax.VirtualMachine(executable, tvm.cpu())
    outputs = vm["main"](*datas)
    if isinstance(outputs, tvm.runtime.NDArray):
        return [outputs.asnumpy()]
    return [item.asnumpy() for item in outputs]
from tvm.relax.frontend.torch import from_fx
from tvm.relay.frontend import from_pytorch
from torch import fx
from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
# Configuration for the relay round-trip below; all left at their defaults here.
opt_config = None
codegen_config = None 
build_target=None
# Reference path: trace with torch.fx and import through the relax frontend,
# then canonicalize bindings for a stable structural comparison.
graph_model = fx.symbolic_trace(torch_model)
with torch.no_grad():
    expected = from_fx(graph_model, input_info)
expected = tvm.relax.transform.CanonicalizeBindings()(expected)

# graph from relay
# Random inputs shared by the torch trace and the TVM runs.
datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
torch_datas = [torch.from_numpy(i) for i in datas]
with torch.no_grad():
    scripted_model = torch.jit.trace(torch_model, tuple(torch_datas)).eval()  # type: ignore
shape_list = [("input" + str(idx), i) for idx, i in enumerate(input_info)]
relay_mod, params = from_pytorch(scripted_model, shape_list)
# Translate the relay module into an MSCGraph and weights.
graph, weights = translate.from_relay(relay_mod, params, opt_config=opt_config)
# to relax
codegen_config = codegen_config or {}
codegen_config.update({"explicit_name": False, "from_relay": True})
mod = tvm_codegen.to_relax(graph, weights, codegen_config)
if build_target:
    build_target = _valid_target(build_target)
    if not build_target:
        # Target unavailable (e.g. no CUDA device) -- skip the runtime check.
        exit()
    tvm_datas = [tvm.nd.array(i) for i in datas]
    expected_res = _run_relax(expected, build_target, tvm_datas)
    # A graph with no recorded inputs is executed without arguments.
    if not graph.get_inputs():
        tvm_datas = []
    res = _run_relax(mod, build_target, tvm_datas)
    for exp_r, new_r in zip(expected_res, res):
        tvm.testing.assert_allclose(exp_r, new_r, atol=1e-5, rtol=1e-5)
else:
    # No build target: compare module structure instead of runtime outputs.
    tvm.ir.assert_structural_equal(mod, expected)

转换为 relax#

import tvm.testing
from tvm.relax.frontend.torch import from_fx
from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen


def _verify_model(torch_model, input_info, opt_config=None):
    """Round-trip *torch_model* through MSC (relax frontend -> relax codegen)
    and assert the regenerated module matches the original's outputs.

    *input_info* is a list of (shape, dtype) pairs; *opt_config* is forwarded
    to translate.from_relax unchanged.
    """
    traced = fx.symbolic_trace(torch_model)
    with torch.no_grad():
        orig_mod = from_fx(traced, input_info)

    target = "llvm"
    dev = tvm.cpu()
    args = [
        tvm.nd.array(np.random.random(size=shape).astype(dtype))
        for shape, dtype in input_info
    ]

    def _tvm_runtime_to_np(obj):
        # Recursively convert TVM runtime values/containers into numpy/python ones.
        if isinstance(obj, tvm.runtime.NDArray):
            return obj.numpy()
        if isinstance(obj, tvm.runtime.ShapeTuple):
            return np.array(obj, dtype="int64")
        if isinstance(obj, (list, tvm.ir.container.Array)):
            return [_tvm_runtime_to_np(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(_tvm_runtime_to_np(item) for item in obj)
        return obj

    def _run_relax(relax_mod):
        # Legalize, build for llvm and execute on CPU with the shared inputs.
        legalized = tvm.relax.transform.LegalizeOps()(relax_mod)
        executable = tvm.relax.build(legalized, target)
        vm = tvm.relax.VirtualMachine(executable, dev)
        return _tvm_runtime_to_np(vm["main"](*args))

    rt_mod = tvm_codegen.to_relax(
        *translate.from_relax(orig_mod, opt_config=opt_config),
        codegen_config={"explicit_name": False},
    )

    tvm.testing.assert_allclose(_run_relax(orig_mod), _run_relax(rt_mod))
# Run the relax round-trip check on the demo convolution model.
_verify_model(torch_model, input_info, opt_config=None)