# PyTorch 量化 (PyTorch quantization)
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
import tvm
from tvm import relay
# Global torch configuration for this script:
# autograd is switched off process-wide (inference only), and the RNG is
# seeded with 0 so that random tensors are identical on every run.
torch.set_grad_enabled(False)
torch.manual_seed(0)
def list_ops(expr):
    """Return the distinct Relay operator nodes referenced by ``expr``.

    Walks ``expr`` with a ``tvm.relay.ExprVisitor`` subclass and records every
    operator node passed to ``visit_op``, in first-visit order, each once.

    Args:
        expr: the Relay expression (or function) to inspect.

    Returns:
        A list of operator nodes without duplicates.
    """

    class OpLister(tvm.relay.ExprVisitor):
        """ExprVisitor that collects each visited operator exactly once."""

        def visit_op(self, op):
            # Membership is tracked in a set (and must be updated here, or the
            # check never filters anything); the list keeps first-visit order.
            if op not in self.node_set:
                self.node_set.add(op)
                self.node_list.append(op)
            return super().visit_op(op)

        def list_nodes(self, expr):
            """Reset the collected state, visit ``expr`` and return the ops."""
            self.node_set = set()
            self.node_list = []
            self.visit(expr)
            return self.node_list

    return OpLister().list_nodes(expr)
class Demo(nn.Module):
    """Depthwise-style grouped 3x3 convolution (16 -> 64 channels) then ReLU.

    Stride 1 with padding 1 keeps the spatial size of the input unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        # groups == in_channels: each of the 16 input channels gets its own
        # 64 / 16 = 4 output filters; no bias term.
        self.conv = nn.Conv2d(
            in_channels=16,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=16,
            bias=False,
        )
        # An nn.PReLU(64) activation (applied after conv) was experimented
        # with here and is currently disabled.
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        return self.relu(self.conv(x))
class Add1(nn.Module):
    """Parameter-free module that adds the scalar 1 to its input."""

    def forward(self, x):
        incremented = x + 1
        return incremented
# A one-dimensional input with two elements, drawn uniformly from [0, 1)
# using the global torch RNG and cast to float32.
input_shape = [2]
random_values = torch.rand(input_shape)
input_data = random_values.float()
del random_values
input_data  # notebook-style display of the tensor (no effect in a plain script)
# Output: tensor([0.4963, 0.7682])
# Name of the single graph input; used for both the Relay input shape list
# and the keyword argument passed to the compiled executor.
input_name = "data"
compiled_input = {input_name: input_data.numpy()}
dev = tvm.cpu()
target = "llvm"
input_shapes = [(input_name, input_shape)]

# Trace the PyTorch module to TorchScript so the Relay frontend can import it.
model = Add1().float().eval()
trace_model = torch.jit.trace(model, [input_data.clone()])
trace_model = trace_model.float().eval()
mod, params = relay.frontend.from_pytorch(trace_model, input_shapes)

# Compile for the LLVM CPU target with the Relay VM at opt level 3, then run.
with tvm.transform.PassContext(opt_level=3):
    exe = relay.create_executor(
        "vm", mod=mod, params=params, device=dev, target=target
    ).evaluate()
result = exe(**compiled_input)

# Verify the TVM output against the eager PyTorch reference on the same input.
expected = model(input_data).numpy()
np.testing.assert_allclose(result.numpy(), expected, rtol=1e-5, atol=1e-5)
result  # notebook-style display of the TVM output
# Output:
# <tvm.nd.NDArray shape=(2,), cpu(0)>
# array([1.4962566, 1.7682219], dtype=float32)