TensorFlow 1 前端
下面以 mobilenet_v2 的 float_v2_1.4_224 模型为例，展示 TVM 的 TensorFlow 前端。
首先运行一个简单的测试：将同一份随机输入分别交给 TensorFlow 和 TVM 执行，并比较两者的输出。
import numpy as np
import tensorflow as tf
# Pick the namespace that exposes the TF1-style API (Graph/Session):
# prefer ``tf.compat.v1`` and fall back to the top-level ``tf`` module
# when that attribute cannot be resolved.
try:
    tf1 = tf.compat.v1
except (AttributeError, ImportError):
    tf1 = tf
import set_env # 加载 TVM
import tvm.relay.testing.tf as tf_testing
import tvm
from tvm import relay
from tvm.contrib import graph_executor
# One input image of shape (1, 224, 224, 3) -- presumably NHWC, the
# TensorFlow default; the layout pass below converts ops to NCHW.
shape = (1, 224, 224, 3)
# Uniform random values in [0, 1), stored as float32.
data = np.random.uniform(size=shape).astype(np.float32)

# Node names inside the frozen graph: the tensor to fetch and the input
# placeholder to feed.
output_name = "MobilenetV2/Predictions/Reshape_1"
input_name = "input"

# TensorFlow feeds are keyed by tensor name, i.e. "<node name>:<output index>".
input_dict = {input_name + ":0": data}
# Produce the TensorFlow reference result inside a fresh default graph.
with tf.Graph().as_default():
    # Obtain the frozen GraphDef "mobilenet_v2_1.4_224_frozen.pb" from the
    # MobileNetV2 checkpoint archive at this URL.
    graph_def = tf_testing.get_workload(
        "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
        "mobilenet_v2_1.4_224_frozen.pb",
    )
    # NOTE(review): the original comment says this call imports the graph
    # definition into the default graph -- confirm against
    # tvm.relay.testing.tf; it may only type-check/canonicalize the proto.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)

    with tf1.Session() as sess:
        # Add shape information to the GraphDef (presumably consumed by the
        # Relay importer further down).
        graph_def = tf_testing.AddShapesToGraphDef(sess, output_name)

        # Run the TensorFlow model once to get the reference output.
        out_tensor = sess.graph.get_tensor_by_name(f"{output_name}:0")
        tf_output = sess.run(out_tensor, feed_dict=input_dict)
# TVM compilation, step 1: translate the TensorFlow GraphDef into a Relay
# module plus its parameters, pinning the input placeholder to the concrete
# shape used above.
mod, params = relay.frontend.from_tensorflow(graph_def, shape={input_name: shape})
# Ops that ConvertLayout should rewrite to NCHW. The second entry (presumably
# the kernel layout for conv2d) is left as "default".
# NOTE: an 'image.resize2d': ['NCHW'] entry is intentionally not requested.
desired_layouts = {
    op_name: ["NCHW", "default"]
    for op_name in ("nn.conv2d", "nn.max_pool2d", "nn.avg_pool2d")
}

# First drop unused functions to clean up the module, then convert the
# layouts of the ops listed above to NCHW.
layout_passes = [
    relay.transform.RemoveUnusedFunctions(),
    relay.transform.ConvertLayout(desired_layouts),
]
seq = tvm.transform.Sequential(layout_passes)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
# Compile the layout-converted module for the host CPU, run it on the same
# input that was fed to TensorFlow, and check both results agree.
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)

# relay.build_config is deprecated in favor of tvm.transform.PassContext,
# which is also what the layout-conversion step uses.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

m = graph_executor.GraphModule(lib["default"](dev))
m.set_input(**{input_name: data})
m.run()
tvm_output = [m.get_output(i).numpy() for i in range(m.get_num_outputs())]

# Compare whole tensors: indexing tf_output[0] would keep only the first
# batch element and would only line up with the TVM output for batch size 1.
np.testing.assert_allclose(
    np.squeeze(tvm_output[0]), np.squeeze(tf_output), rtol=1e-5, atol=1e-5
)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.