# 翻译 PyTorch 代码 (Translating PyTorch code)
import set_env
import numpy as np
import torch
from torch.nn import Module
import tvm.testing
from tvm.contrib.msc.framework.torch.frontend import translate
from tvm.contrib.msc.framework.torch import codegen
def verify_model(torch_model, input_info, via_relax=True):
    """Compare a torch model's outputs with those of its MSC round-trip.

    The model is translated into an MSC graph plus weights with
    ``translate.from_torch`` and regenerated as a torch model with
    ``codegen.to_torch``. Both models are run, without autograd, on the same
    randomly generated inputs, and their outputs are compared.

    Parameters
    ----------
    torch_model : torch.nn.Module
        Reference model whose outputs are treated as the golden values.
    input_info : list
        One entry per model input. Entry ``[0]`` is the shape passed to
        ``np.random.rand`` and entry ``[1]`` is the dtype passed to
        ``astype``.
    via_relax : bool
        Forwarded unchanged to ``translate.from_torch``.

    Raises
    ------
    AssertionError
        If the two models return a different number of outputs, or if a
        non-tensor output differs. Tensor outputs are checked by
        ``tvm.testing.assert_allclose`` with ``atol=1e-5, rtol=1e-5``.
    """
    graph, weights = translate.from_torch(torch_model, input_info, via_relax=via_relax)
    rebuilt = codegen.to_torch(graph, weights)

    # Build one random tensor per input spec. The same tensors feed both models.
    inputs = []
    for spec in input_info:
        inputs.append(torch.from_numpy(np.random.rand(*spec[0]).astype(spec[1])))

    with torch.no_grad():
        golden = torch_model(*inputs)
        # A regenerated graph that declares no inputs is called with no arguments.
        result = rebuilt(*inputs) if graph.get_inputs() else rebuilt()

    # Normalize single outputs to one-element lists so both sides zip uniformly.
    golden = golden if isinstance(golden, (list, tuple)) else [golden]
    result = result if isinstance(result, (list, tuple)) else [result]

    mismatch_msg = "golden {} mismatch with result {}".format(len(golden), len(result))
    assert len(golden) == len(result), mismatch_msg

    for expected, actual in zip(golden, result):
        if not isinstance(expected, torch.Tensor):
            assert expected == actual
            continue
        tvm.testing.assert_allclose(
            expected.detach().numpy(), actual.detach().numpy(), atol=1e-5, rtol=1e-5
        )
# conv1d
class Conv1D1(Module):
    """A single ``Conv1d`` layer (3 -> 6 channels, kernel size 7) with a bias."""

    def __init__(self):
        super(Conv1D1, self).__init__()
        # The attribute name ``conv`` fixes the parameter names
        # (``conv.weight`` / ``conv.bias``) exposed by the state dict.
        self.conv = torch.nn.Conv1d(
            in_channels=3, out_channels=6, kernel_size=7, bias=True
        )

    def forward(self, data):
        """Apply the convolution to ``data`` and return the result."""
        output = self.conv(data)
        return output
class Conv1D2(Module):
    """A single ``Conv1d`` layer (3 -> 6 channels, kernel size 7) without a bias."""

    def __init__(self):
        super(Conv1D2, self).__init__()
        # The attribute name ``conv`` fixes the parameter name (``conv.weight``)
        # exposed by the state dict.
        self.conv = torch.nn.Conv1d(
            in_channels=3, out_channels=6, kernel_size=7, bias=False
        )

    def forward(self, data):
        """Apply the convolution to ``data`` and return the result."""
        output = self.conv(data)
        return output
# A single float32 input of shape (1, 3, 10); each model is checked with the
# relax-based translation path and then the non-relax path.
input_info = [([1, 3, 10], "float32")]
for via_relax in (True, False):
    for model_cls in (Conv1D1, Conv1D2):
        verify_model(model_cls(), input_info, via_relax=via_relax)