Tensorflow(pb) 转 ONNX

TensorFlow(pb) 转 ONNX

参考: TVM Tensorflow 前端

下面以 mobilenet_v2（float_v2_1.4_224）为例，展示将 TensorFlow pb 模型转换为 ONNX 模型的过程：

import numpy as np
import tensorflow as tf
from tensorflow.core.framework.graph_pb2 import GraphDef
import set_env # 加载 TVM
import tvm.relay.testing.tf as tf_testing
class MobilenetV2(tf.keras.Model):
    """Keras wrapper that replays a frozen MobileNetV2 ``GraphDef``.

    The wrapper accepts an NCHW float32 batch, transposes it to NHWC, and
    feeds it to the ``input:0`` placeholder of the imported frozen graph.
    """

    # NOTE(review): the download arguments and the output tensor name were
    # missing from the extracted source; the values below are reconstructed
    # for the mobilenet_v2 float_v2_1.4_224 checkpoint -- confirm them
    # against the original notebook before relying on them.
    MODEL_URL = (
        "https://storage.googleapis.com/mobilenet_v2/checkpoints/"
        "mobilenet_v2_1.4_224.tgz"
    )
    MODEL_SUB_PATH = "mobilenet_v2_1.4_224_frozen.pb"
    # Other taps that the original author had tried:
    #   "MobilenetV2/Logits/AvgPool:0"
    #   "MobilenetV2/Logits/Conv2d_1c_1x1/BiasAdd:0"
    OUTPUT_TENSOR = "MobilenetV2/Predictions/Reshape_1:0"

    def __init__(self, *args, **kwargs):
        """Fetch the frozen model and keep its parsed ``GraphDef``."""
        super().__init__(*args, **kwargs)
        path_model = tf_testing.get_workload_official(
            self.MODEL_URL, self.MODEL_SUB_PATH
        )
        self.graph_def = self._read_graph_def(path_model)
        self.output_names = ['output']

    def _read_graph_def(self, frozen_path):
        """Parse the serialized ``GraphDef`` stored at ``frozen_path``."""
        graph_def = GraphDef()
        with open(frozen_path, 'rb') as f:
            # Without this call the returned GraphDef would be empty.
            graph_def.ParseFromString(f.read())
        return graph_def

    @tf.function(input_signature=[tf.TensorSpec([1, 3, 224, 224],
                                                 tf.float32, name="input")])
    def call(self, x):
        """Run the frozen graph on one NCHW batch.

        Returns the list produced by ``import_graph_def`` for the requested
        ``return_elements`` (a single tensor here).
        """
        x = tf.convert_to_tensor(x, tf.float32)  # make sure the input is a tensor
        x = tf.transpose(x, perm=(0, 2, 3, 1))  # NCHW -> NHWC
        # return_elements is required: without it import_graph_def returns None.
        x = tf.graph_util.import_graph_def(
            self.graph_def,
            input_map={'input:0': x},
            return_elements=[self.OUTPUT_TENSOR],
        )
        return x
import tf2onnx
import onnx

from pathlib import Path

# Signature handed to tf2onnx: a single float32 tensor of shape
# [1, 3, 224, 224] named "data".
input_signature = [tf.TensorSpec([1, 3, 224, 224], tf.float32, name="data")]
model = MobilenetV2()
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature)
# Create the output directory first so that saving does not fail with
# FileNotFoundError on a fresh checkout.
onnx_path = Path(".temp/mobilenet_v2_tf.onnx")
onnx_path.parent.mkdir(parents=True, exist_ok=True)
onnx.save(onnx_model, str(onnx_path))
from PIL import Image
image_size = 224
path = 'images/Giant_Panda_in_Beijing_Zoo_1.jpg'  # path of the image to classify

with Image.open(path) as im:
    # The network takes 3-channel input, so convert any other mode to RGB.
    if im.mode != "RGB":
        im = im.convert("RGB")
    im = im.resize((image_size, image_size))
    image = np.asarray(im)
# Scale the 0..255 pixel values by 1/128 and shift by -1.
image = image/128 - 1
images = np.expand_dims(image, 0)  # add the batch axis
images = images.transpose((0, 3, 1, 2))  # NHWC -> NCHW, the layout the wrapper expects
model = MobilenetV2()
tf_output = model(images)
import set_env
import tvm
from tvm import relay
from tvm.relay.frontend import from_onnx

# NOTE(review): several lines of this cell were lost during extraction (the
# from_onnx arguments, the executor run, and the comparison call whose
# tolerances survived); they are reconstructed here -- confirm against the
# original notebook.
shape_dict = {"data": [1, 3, 224, 224]}
mod, params = from_onnx(onnx_model, shape_dict)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, "llvm", params=params)
inputs_dict = {"data": images}
mlib_proxy = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.cpu()))
mlib_proxy.run(**inputs_dict)
tvm_output = mlib_proxy.get_output(0).numpy()
# The TensorFlow wrapper may return a one-element list; unwrap it if so.
tf_result = tf_output[0] if isinstance(tf_output, (list, tuple)) else tf_output
np.testing.assert_allclose(
    np.asarray(tf_result), tvm_output,
    rtol=1e-07, atol=1e-5
)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.