翻译 Relay 代码

目录

翻译 Relay 代码

import set_env
import numpy as np

import torch
from torch import fx
from torch.nn import Module

import tvm.testing
from tvm.relax.frontend.torch import from_fx
from tvm.relay.frontend import from_pytorch
# from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.torch.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
def _valid_target(target):
    if not target:
        return target
    if target == "ignore":
        return None
    if target == "cuda" and not tvm.cuda().exist:
        return None
    if isinstance(target, str):
        target = tvm.target.Target(target)
    return target


def _run_relax(relax_mod, target, datas):
    """Legalize, build and execute the ``main`` function of a Relax module.

    The module is legalized with ``LegalizeOps``, built for ``target`` at
    ``opt_level=3`` and run on a Relax VirtualMachine with ``datas`` as the
    positional arguments. The result is always returned as a list of numpy
    arrays: a single NDArray output becomes a one-element list.
    """
    legalized = tvm.relax.transform.LegalizeOps()(relax_mod)
    with tvm.transform.PassContext(opt_level=3):
        executable = tvm.relax.build(legalized, target)
        # NOTE(review): the VM device is always the CPU regardless of `target`,
        # so non-CPU targets presumably cannot execute here — confirm.
        vm = tvm.relax.VirtualMachine(executable, tvm.cpu())
    outputs = vm["main"](*datas)
    if isinstance(outputs, tvm.runtime.NDArray):
        outputs = [outputs]
    return [out.asnumpy() for out in outputs]


def verify_model(torch_model, input_info, opt_config=None, codegen_config=None, build_target=None):
    """Compare the Relax module produced via Relay + MSC with the one from ``from_fx``.

    The expected module is imported directly with ``from_fx`` on an fx-traced
    model. The candidate module goes through ``torch.jit.trace`` ->
    ``from_pytorch`` (Relay) -> ``translate.from_relay`` (MSC graph) ->
    ``tvm_codegen.to_relax``.

    Parameters
    ----------
    torch_model : torch.nn.Module
        Model under test.
    input_info : list of (shape, dtype)
        One entry per input, e.g. ``[([1, 3, 10, 10], "float32")]``.
    opt_config : dict, optional
        Forwarded to ``translate.from_relay``.
    codegen_config : dict, optional
        Extra options for ``tvm_codegen.to_relax``. The caller's dict is not
        modified.
    build_target : str or tvm.target.Target, optional
        Without it, the two modules are compared structurally. With it, both
        are built and run on random inputs and their outputs compared
        numerically. Returns early without checking when ``_valid_target``
        rejects the target.
    """

    graph_model = fx.symbolic_trace(torch_model)
    with torch.no_grad():
        expected = from_fx(graph_model, input_info)
    expected = tvm.relax.transform.CanonicalizeBindings()(expected)

    # graph from relay
    datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
    torch_datas = [torch.from_numpy(i) for i in datas]
    with torch.no_grad():
        scripted_model = torch.jit.trace(torch_model, tuple(torch_datas)).eval()  # type: ignore
    shape_list = [("input" + str(idx), i) for idx, i in enumerate(input_info)]
    relay_mod, params = from_pytorch(scripted_model, shape_list)
    graph, weights = translate.from_relay(relay_mod, params, opt_config=opt_config)
    # to relax
    # Build a new dict so the caller's codegen_config is never mutated.
    codegen_config = {**(codegen_config or {}), "explicit_name": False, "from_relay": True}
    mod = tvm_codegen.to_relax(graph, weights, codegen_config)
    if build_target:
        build_target = _valid_target(build_target)
        if not build_target:
            return
        tvm_datas = [tvm.nd.array(i) for i in datas]
        expected_res = _run_relax(expected, build_target, tvm_datas)
        # The translated graph may have no inputs (e.g. everything folded
        # into constants), in which case main() takes no arguments.
        if not graph.get_inputs():
            tvm_datas = []
        res = _run_relax(mod, build_target, tvm_datas)
        # zip() would silently drop extra outputs; check counts first.
        assert len(expected_res) == len(res), "expected {} outputs mismatch with result {}".format(
            len(expected_res), len(res)
        )
        for exp_r, new_r in zip(expected_res, res):
            tvm.testing.assert_allclose(exp_r, new_r, atol=1e-5, rtol=1e-5)
    else:
        tvm.ir.assert_structural_equal(mod, expected)
def verify_model(torch_model, input_info, via_relax=True):
    """Compare a torch model's outputs with the torch model regenerated by MSC.

    The model is translated to an MSC graph with ``translate.from_torch``,
    turned back into a torch module with the MSC torch codegen, and both are
    run on the same random inputs. Tensor outputs are compared with
    ``atol=rtol=1e-5``; any other outputs must compare equal.

    Parameters
    ----------
    torch_model : torch.nn.Module
        Model under test.
    input_info : list of (shape, dtype)
        One entry per input, e.g. ``[([1, 3, 10, 10], "float32")]``.
    via_relax : bool
        Forwarded to ``translate.from_torch``.
    """
    # Imported here: the file-level imports only bring in the Relax codegen
    # (as ``tvm_codegen``), not the torch one, so a bare ``codegen`` is undefined.
    from tvm.contrib.msc.framework.torch import codegen as torch_codegen

    graph, weights = translate.from_torch(torch_model, input_info, via_relax=via_relax)
    model = torch_codegen.to_torch(graph, weights)
    torch_datas = [torch.from_numpy(np.random.rand(*i[0]).astype(i[1])) for i in input_info]
    with torch.no_grad():
        golden = torch_model(*torch_datas)
    with torch.no_grad():
        # The regenerated model takes no arguments when the graph has no inputs.
        if not graph.get_inputs():
            result = model()
        else:
            result = model(*torch_datas)
    if not isinstance(golden, (list, tuple)):
        golden = [golden]
    if not isinstance(result, (list, tuple)):
        result = [result]
    assert len(golden) == len(result), "golden {} mismatch with result {}".format(
        len(golden), len(result)
    )
    for gol_r, new_r in zip(golden, result):
        if isinstance(gol_r, torch.Tensor):
            tvm.testing.assert_allclose(
                gol_r.detach().numpy(), new_r.detach().numpy(), atol=1e-5, rtol=1e-5
            )
        else:
            assert gol_r == new_r

conv2d

import torch
from torch import fx
from torch.nn import Module
class Conv2D1(Module):
    """A single 3 -> 6 channel, 7x7 convolution that includes a bias term."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=6, kernel_size=7, bias=True)

    def forward(self, data):
        """Apply the biased convolution to ``data``."""
        output = self.conv(data)
        return output

class Conv2D2(Module):
    """A single 3 -> 6 channel, 7x7 convolution without a bias term."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=6, kernel_size=7, bias=False)

    def forward(self, data):
        """Apply the bias-free convolution to ``data``."""
        output = self.conv(data)
        return output

# One NCHW float32 input shared by both conv variants.
input_info = [([1, 3, 10, 10], "float32")]
# Instantiate and verify each model in turn (with bias, then without).
for model_cls in (Conv2D1, Conv2D2):
    verify_model(model_cls(), input_info)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[5], line 21
     18         return self.conv(data)
     20 input_info = [([1, 3, 10, 10], "float32")]
---> 21 verify_model(Conv2D1(), input_info)
     22 verify_model(Conv2D2(), input_info)

Cell In[4], line 5, in verify_model(torch_model, input_info, via_relax)
      2 """比较 torch 模型结果"""
      4 graph, weights = translate.from_torch(torch_model, input_info, via_relax=via_relax)
----> 5 model = codegen.to_torch(graph, weights)
      6 torch_datas = [torch.from_numpy(np.random.rand(*i[0]).astype(i[1])) for i in input_info]
      7 with torch.no_grad():

NameError: name 'codegen' is not defined