模型翻译#
import set_env
定义简单模型:
import torch
class Conv2D1(torch.nn.Module):
    """Minimal single-layer model used as the translation example.

    It wraps exactly one 2-D convolution with 3 input channels, 6 output
    channels, a 7x7 kernel and a bias term.
    """

    def __init__(self):
        super(Conv2D1, self).__init__()
        conv_layer = torch.nn.Conv2d(
            in_channels=3, out_channels=6, kernel_size=7, bias=True
        )
        # The attribute name ``conv`` is kept: PyTorch uses it as the prefix of
        # the registered parameter names (``conv.weight`` / ``conv.bias``).
        self.conv = conv_layer

    def forward(self, data):
        """Apply the convolution to ``data`` and return the result."""
        output = self.conv(data)
        return output
MSCGraph 与 PyTorch 模型互转#
import numpy as np
import torch
from torch.nn import Module
import tvm.testing
from tvm.contrib.msc.framework.torch.frontend import translate
from tvm.contrib.msc.framework.torch import codegen
# One float32 input of shape (1, 3, 10, 10), given as (shape, dtype) pairs.
input_info = [([1, 3, 10, 10], "float32")]
# Example model instance that is translated in the sections below.
torch_model = Conv2D1()
torch 模型转换为 MSCGraph:
# Translate the torch model into an MSCGraph plus its weights.
# NOTE(review): via_relax=False presumably selects the Relay-based import
# path instead of the Relax one — confirm against translate.from_torch.
graph, weights = translate.from_torch(torch_model, input_info, via_relax=False)
MSCGraph 再转换回 torch 模型:
# Regenerate a torch model from the MSCGraph and its weights.
model = codegen.to_torch(graph, weights)
验证一致性:
# Feed identical random inputs to the original and the regenerated model.
torch_datas = [
    torch.from_numpy(np.random.rand(*shape).astype(dtype)) for shape, dtype in input_info
]
with torch.no_grad():
    golden = torch_model(*torch_datas)
    # A graph that declares no inputs yields a model whose forward takes none.
    result = model(*torch_datas) if graph.get_inputs() else model()

# Normalize single outputs to one-element lists so both sides zip uniformly.
golden = golden if isinstance(golden, (list, tuple)) else [golden]
result = result if isinstance(result, (list, tuple)) else [result]
assert len(golden) == len(result), f"golden {len(golden)} mismatch with result {len(result)}"

for ref_out, new_out in zip(golden, result):
    if not isinstance(ref_out, torch.Tensor):
        # Non-tensor outputs must compare equal exactly.
        assert ref_out == new_out
        continue
    tvm.testing.assert_allclose(
        ref_out.detach().numpy(), new_out.detach().numpy(), atol=1e-5, rtol=1e-5
    )
转换为 relay#
def _valid_target(target):
if not target:
return target
if target == "ignore":
return None
if target == "cuda" and not tvm.cuda().exist:
return None
if isinstance(target, str):
target = tvm.target.Target(target)
return target
def _run_relax(relax_mod, target, datas):
    """Legalize, build and run the ``main`` function of a Relax module.

    Args:
        relax_mod: the Relax ``IRModule`` to execute.
        target: the build target, a ``tvm.target.Target`` or a target string.
        datas: positional inputs forwarded to ``main``.

    Returns:
        A list of numpy arrays. A single tensor output is wrapped in a
        one-element list so callers can always iterate over the result.
    """
    relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)
    # Run on the device that matches the build target. Always using tvm.cpu()
    # cannot execute kernels built for a non-host target such as CUDA.
    if isinstance(target, tvm.target.Target):
        dev = tvm.device(target.kind.name)
    else:
        dev = tvm.cpu()
    with tvm.transform.PassContext(opt_level=3):
        relax_exec = tvm.relax.build(relax_mod, target)
    runnable = tvm.relax.VirtualMachine(relax_exec, dev)
    res = runnable["main"](*datas)
    # NDArray.asnumpy() is deprecated; numpy() is the supported accessor.
    if isinstance(res, tvm.runtime.NDArray):
        return [res.numpy()]
    return [e.numpy() for e in res]
from tvm.relax.frontend.torch import from_fx
from tvm.relay.frontend import from_pytorch
from torch import fx
from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
# Optional settings for the Relay-based round trip; all left at defaults here.
opt_config = None
codegen_config = None
build_target = None

# Reference module: import the torch model straight into Relax via torch.fx.
graph_model = fx.symbolic_trace(torch_model)
with torch.no_grad():
    expected = from_fx(graph_model, input_info)
expected = tvm.relax.transform.CanonicalizeBindings()(expected)

# graph from relay: trace with TorchScript, import into Relay, then translate
# the Relay module into an MSCGraph plus weights.
datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
torch_datas = [torch.from_numpy(i) for i in datas]
with torch.no_grad():
    scripted_model = torch.jit.trace(torch_model, tuple(torch_datas)).eval()  # type: ignore
shape_list = [(f"input{idx}", i) for idx, i in enumerate(input_info)]
relay_mod, params = from_pytorch(scripted_model, shape_list)
graph, weights = translate.from_relay(relay_mod, params, opt_config=opt_config)

# to relax: build a fresh config dict rather than mutating a provided one.
codegen_config = {**(codegen_config or {}), "explicit_name": False, "from_relay": True}
mod = tvm_codegen.to_relax(graph, weights, codegen_config)

if build_target:
    build_target = _valid_target(build_target)
    if build_target:
        tvm_datas = [tvm.nd.array(i) for i in datas]
        expected_res = _run_relax(expected, build_target, tvm_datas)
        # A graph without declared inputs is run without arguments.
        if not graph.get_inputs():
            tvm_datas = []
        res = _run_relax(mod, build_target, tvm_datas)
        # Guard against zip() silently truncating mismatched output counts.
        assert len(expected_res) == len(res), (
            f"expected {len(expected_res)} outputs, got {len(res)}"
        )
        for exp_r, new_r in zip(expected_res, res):
            tvm.testing.assert_allclose(exp_r, new_r, atol=1e-5, rtol=1e-5)
    else:
        # Target unavailable (e.g. "ignore", or "cuda" without a GPU). Skip the
        # numeric comparison instead of calling exit(), which would also abort
        # the remaining sections of this script.
        print("Skipping runtime comparison: build target is unavailable.")
else:
    tvm.ir.assert_structural_equal(mod, expected)
转换为 relax#
import tvm.testing
from tvm.relax.frontend.torch import from_fx
from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
def _verify_model(torch_model, input_info, opt_config=None):
    """Check that an MSC Relax round trip preserves a torch model's numerics.

    The model is traced with ``torch.fx`` and imported into Relax
    (``orig_mod``). That module is translated to an MSCGraph and regenerated
    as Relax (``rt_mod``). Both are built for ``llvm``, run on the CPU with
    the same random inputs, and their outputs must match.

    Args:
        torch_model: the ``torch.nn.Module`` to verify.
        input_info: list of ``(shape, dtype)`` pairs, one per model input.
        opt_config: optional config forwarded to ``translate.from_relax``.

    Raises:
        AssertionError: if the outputs differ in structure or value.
    """
    graph_model = fx.symbolic_trace(torch_model)
    with torch.no_grad():
        orig_mod = from_fx(graph_model, input_info)

    target = "llvm"
    dev = tvm.cpu()
    # One set of random inputs shared by both runs so outputs are comparable.
    args = [
        tvm.nd.array(np.random.random(size=shape).astype(dtype))
        for shape, dtype in input_info
    ]

    def _tvm_runtime_to_np(obj):
        """Recursively convert VM results to numpy, keeping list/tuple nesting."""
        if isinstance(obj, tvm.runtime.NDArray):
            return obj.numpy()
        if isinstance(obj, tvm.runtime.ShapeTuple):
            return np.array(obj, dtype="int64")
        if isinstance(obj, (list, tvm.ir.container.Array)):
            return [_tvm_runtime_to_np(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(_tvm_runtime_to_np(item) for item in obj)
        return obj

    def _run_relax(relax_mod):
        """Legalize, build for ``target`` and run ``main`` on ``args``."""
        relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)
        relax_exec = tvm.relax.build(relax_mod, target)
        vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)
        return _tvm_runtime_to_np(vm_runner["main"](*args))

    def _assert_outputs_close(expected, actual):
        """Compare possibly nested outputs leaf by leaf.

        ``assert_allclose`` converts its inputs with ``np.asanyarray``, which
        can fail for a list of differently-shaped arrays, so containers are
        walked element-wise instead.
        """
        if isinstance(expected, (list, tuple)):
            assert isinstance(actual, (list, tuple)), (
                f"expected a sequence output, got {type(actual).__name__}"
            )
            assert len(expected) == len(actual), (
                f"expected {len(expected)} outputs, got {len(actual)}"
            )
            for exp_item, act_item in zip(expected, actual):
                _assert_outputs_close(exp_item, act_item)
        else:
            tvm.testing.assert_allclose(expected, actual)

    # Round trip: Relax -> MSCGraph (+ weights) -> Relax.
    rt_mod = tvm_codegen.to_relax(
        *translate.from_relax(orig_mod, opt_config=opt_config),
        codegen_config={"explicit_name": False},
    )
    orig_output = _run_relax(orig_mod)
    rt_output = _run_relax(rt_mod)
    _assert_outputs_close(orig_output, rt_output)
# Round-trip the example model through MSC and compare the Relax outputs
# (opt_config is left at its default of None).
_verify_model(torch_model, input_info)