TensorFlow2 Keras 推理

TensorFlow2 Keras 推理

参考:TensorFlow 官方迁移指南《Migrating model checkpoints》(migrating_checkpoints)。

下面以模型 resnet_v2_50 为例展示。

需要先克隆 TensorFlow 的 models 仓库(tensorflow/models),然后执行如下操作。

import os
m_gpu = -1  # -1 hides every CUDA device, i.e. the GPU is disabled
# Set the device mask before importing TensorFlow so it is already in place
# when TF initializes CUDA. CUDA_LAUNCH_BLOCKING is deliberately not set: it is
# a 0/1 debugging switch that forces synchronous kernel launches, not a way to
# disable the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = str(m_gpu)
import tensorflow as tf
try:
    tf1 = tf.compat.v1
except (ImportError, AttributeError):
    # Fall back for TensorFlow builds that do not expose the compat.v1 namespace.
    tf1 = tf
tf.get_logger().setLevel('ERROR')
2023-06-21 16:49:34.172559: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-21 16:49:34.247466: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-21 16:49:34.248487: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-21 16:49:35.798317: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT

切换到 models/research/slim 目录下:

# IPython magic: run from models/research/slim so that `from nets import resnet_v2` resolves.
%cd /media/pc/data/lxw/ai/tasks/models/research/slim
/media/pc/data/lxw/ai/tasks/models/research/slim

将 TF1(slim)模型封装为 TF2 Keras 模型:

from nets import resnet_v2
import tf_slim as slim

class ResnetV2_50(tf.keras.Model):
    """TF2 Keras wrapper that runs the TF1/slim ``resnet_v2_50`` graph.

    The slim network is built inside ``call``, which is decorated with
    ``track_tf1_style_variables`` so the variables slim creates there are
    tracked by this Keras model, and therefore by a ``tf.train.Checkpoint``
    that wraps it.
    """
    def __init__(self, trainable=False, *args, **kwargs):
        """Create the wrapper.

        Args:
            trainable: stored as the Keras ``trainable`` attribute and also
                passed to slim as ``is_training`` in ``call``.
                NOTE(review): this couples weight trainability with the
                train/inference mode of the slim layers; consider a separate
                flag if the two ever need to differ.
            *args, **kwargs: forwarded to ``tf.keras.Model``.
        """
        super().__init__(*args, **kwargs)
        self.trainable = trainable

    # The input_signature pins the input to float32 (1, 3, 299, 299) named
    # "data"; inputs that do not match this signature are rejected.
    @tf.function(input_signature=[tf.TensorSpec([1, 3, 299, 299], 
                                                 tf.float32, name="data")])
    @tf1.keras.utils.track_tf1_style_variables
    def call(self, x):
        """Run slim ResNet-v2-50 on an NCHW image batch.

        Args:
            x: float32 tensor of shape ``(1, 3, 299, 299)`` in NCHW layout.

        Returns:
            Softmax probabilities over the 1001 slim output classes,
            shape ``(1, 1001)``.
        """
        # x = tf.convert_to_tensor(x, tf.float32) # make sure the input is a tensor
        x = tf.transpose(x, perm=(0, 2, 3, 1)) # NCHW -> NHWC
        # Build the network under slim's ResNet arg scope; the scope name
        # "resnet_v2_50" must match the variable names in the checkpoint.
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, end_points = resnet_v2.resnet_v2_50(
                x, 
                num_classes=1001,
                global_pool=True,
                is_training=self.trainable,
                scope="resnet_v2_50"
            )
        # Only the logits are needed; drop the intermediate endpoints.
        del end_points
        return tf.nn.softmax(logits)

预处理:

from PIL import Image
import numpy as np
from nets import resnet_v2
from tvm_book.data.classification import ImageFolderDataset
import tf_slim as slim
import tensorflow as tf


@tf.function
def preprocessing(
    image,
    use_grayscale=False,
    central_fraction=0.875,
    central_crop=True,
    height=299,
    width=299,
    mean: tuple[float, ...] = (0.485, 0.456, 0.406),
    std: tuple[float, ...] = (1, 1, 1)
):
    """Turn one HWC image into a normalized float32 HWC tensor.

    Pipeline:

    1. cast to ``float32`` with ``tf.image.convert_image_dtype``, which also
       rescales integer inputs by the dtype's maximum (uint8 -> ``[0, 1]``);
    2. convert RGB to grayscale if ``use_grayscale`` is true;
    3. keep the central ``central_fraction`` of the image if both
       ``central_crop`` and ``central_fraction`` are truthy;
    4. bilinearly resize to ``height`` x ``width`` (aspect ratio not kept,
       no antialiasing) if both are truthy;
    5. normalize as ``(image - mean) / std``, broadcast over the last axis.

    Args:
        image: a single image tensor (or array-like), presumably ``(H, W, C)``.
        use_grayscale: apply ``tf.image.rgb_to_grayscale``.
        central_fraction: fraction kept by the central crop.
        central_crop: enable the central crop.
        height, width: target size; a falsy value skips resizing.
        mean: per-channel value subtracted after resizing.
        std: per-channel divisor applied after subtracting ``mean``.

    Returns:
        The preprocessed float32 tensor.
    """
    out = image
    if out.dtype != tf.float32:
        out = tf.image.convert_image_dtype(out, dtype=tf.float32)
    if use_grayscale:
        out = tf.image.rgb_to_grayscale(out)
    if central_crop and central_fraction:
        out = tf.image.central_crop(out, central_fraction=central_fraction)
    if height and width:
        # Resize a batch of one, then drop the temporary batch axis again.
        batch = tf.expand_dims(out, 0)
        batch = tf.image.resize(
            batch,
            [height, width],
            method='bilinear',
            preserve_aspect_ratio=False,
            antialias=False,
        )
        out = tf.squeeze(batch, [0])
    centered = tf.subtract(out, mean)
    return tf.divide(centered, std)


# Preprocessing: load one validation sample and prepare it for the model.
root = "/media/pc/data/lxw/home/data/datasets/ILSVRC/val"
valset = ImageFolderDataset(root)
image, label_id = valset[1001]
model_dir = 'temp/resnet_v2_50'
# remove_dir(model_dir)
processed_image = preprocessing(
    image,
    use_grayscale=False,
    central_fraction=0.875,
    central_crop=True,
    height=299,
    width=299,
    mean=(0.485, 0.456, 0.406),
    std=(1, 1, 1),
)
# Add a leading batch axis (HWC -> 1HWC), then reorder to NCHW, the layout
# expected by ResnetV2_50.call's input signature.
np_processed_images = processed_image.numpy()[np.newaxis, ...].transpose(0, 3, 1, 2)
2023-06-21 16:49:39.660093: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-06-21 16:49:39.660172: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: Alg
2023-06-21 16:49:39.660183: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: Alg
2023-06-21 16:49:39.660370: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: 530.30.2
2023-06-21 16:49:39.660427: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 530.30.2
2023-06-21 16:49:39.660443: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:309] kernel version seems to match DSO: 530.30.2

前向推理:

# Instantiate the wrapper and call it once on a dummy NCHW batch, presumably
# so the slim variables exist before the checkpoint is restored into them.
model = ResnetV2_50()
model(tf.ones(shape=(1, 3, 299, 299), dtype=tf.float32))

ckpt_path = "/media/pc/data/board/arria10/lxw/tests/npu_user_demos/models/resnet50_v2_tf/weight/resnet_v2_50.ckpt"
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(ckpt_path)  # load the slim checkpoint weights into the model

# Forward pass on the preprocessed image; keep the probabilities as NumPy.
outputs = model(np_processed_images).numpy()
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
  warnings.warn('`layer.apply` is deprecated and '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`layer.updates` will be removed in a future version. '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
# Presumably the layer table is empty because the slim layers are created
# inside `call` rather than as Keras sub-layers. "Trainable params: 0" is
# consistent with ResnetV2_50 being constructed with trainable=False.
model.summary()
Model: "resnet_v2_50"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 25,615,849
Trainable params: 0
Non-trainable params: 25,615,849
_________________________________________________________________

打印标签信息:

from tvm_book.data.imagenet.classification import ImageNet1kAttr

imagenet1k_attr = ImageNet1kAttr()
# Output indices sorted by descending probability.
sorted_inds = outputs[0].argsort()[::-1]
topk = 5
print(f"真实标签:{imagenet1k_attr.classes_long[label_id]}")
for sorted_ind in sorted_inds[:topk]:
    # The head has 1001 outputs (num_classes=1001): output index i maps to
    # ImageNet-1k label i - 1, and index 0 is the extra (background) class.
    # Guard index 0 explicitly: classes_long[-1] would otherwise silently
    # return the last ImageNet label (if classes_long is a Python sequence).
    class_id = sorted_ind - 1
    if class_id < 0:
        label = "background"
    else:
        label = imagenet1k_attr.classes_long[class_id]
    print(f"{class_id}: {label.ljust(38)}\t{outputs[0, sorted_ind]}")
真实标签:water ouzel, dipper
20: water ouzel, dipper                   	0.9207783937454224
143: oystercatcher, oyster catcher         	0.014078204520046711
141: redshank, Tringa totanus              	0.0032907347194850445
146: albatross, mollymawk                  	0.0032017454504966736
139: ruddy turnstone, Arenaria interpres   	0.002742304001003504

将模型及其参数保存为 SavedModel,再重新加载并推理:

# Alternative kept for reference only (not executed): wrap the model in a
# functional tf.keras.Model before saving.
# NOTE(review): the Input shape below, (224, 224, 3), does not match the
# (1, 3, 299, 299) NCHW input_signature of ResnetV2_50.call, so it would need
# updating before use.
# # model = ResnetV2_50()
# inputs = tf.keras.Input(shape=(224, 224, 3), dtype=tf.float32, name="data")
# outputs = model(inputs)
# model2 = tf.keras.Model(inputs=inputs, outputs=outputs, name="resnet_v2_50_model")

# model2.save(module_with_signature_path)
# Export the subclassed model as a SavedModel. The reload below uses its
# "serving_default" signature, presumably derived from the input_signature
# on ResnetV2_50.call.
module_with_signature_path = "/tmp/resnet_v2_50_keras"
model.save(module_with_signature_path)
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
  warnings.warn('`layer.apply` is deprecated and '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`layer.updates` will be removed in a future version. '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
# Reload the SavedModel and run its serving signature on the same input.
imported_with_signatures = tf.saved_model.load(module_with_signature_path)
infer = imported_with_signatures.signatures['serving_default']
labeling = infer(tf.constant(np_processed_images))
from tvm_book.data.imagenet.classification import ImageNet1kAttr

# NOTE(review): 'output_1' is the auto-generated output key of the serving
# signature; confirm it if the model definition changes.
outputs = labeling['output_1'].numpy()
imagenet1k_attr = ImageNet1kAttr()
# Output indices sorted by descending probability.
sorted_inds = outputs[0].argsort()[::-1]
topk = 5
print(f"真实标签:{imagenet1k_attr.classes_long[label_id]}")
for sorted_ind in sorted_inds[:topk]:
    # Output index i maps to ImageNet-1k label i - 1 (1001-way head with an
    # extra class at index 0). Guard index 0 so it is not silently mapped to
    # the last ImageNet label via negative indexing.
    class_id = sorted_ind - 1
    if class_id < 0:
        label = "background"
    else:
        label = imagenet1k_attr.classes_long[class_id]
    print(f"{class_id}: {label.ljust(38)}\t{outputs[0, sorted_ind]}")
真实标签:water ouzel, dipper
20: water ouzel, dipper                   	0.9207783937454224
143: oystercatcher, oyster catcher         	0.014078204520046711
141: redshank, Tringa totanus              	0.0032907347194850445
146: albatross, mollymawk                  	0.0032017454504966736
139: ruddy turnstone, Arenaria interpres   	0.002742304001003504

转换为 ONNX 模型

将 Keras 模型转换为 ONNX 模型:

import tf2onnx
import onnx

# Export with a dynamic batch dimension; the graph input keeps the name "data".
input_signature = [
    tf.TensorSpec([None, 3, 299, 299], tf.float32, name="data"),
]
onnx_model, external_tensor_storage = tf2onnx.convert.from_keras(
    model, input_signature=input_signature
)
# external_tensor_storage is not used here.
onnx.save(onnx_model, "/tmp/resnet_v2_50_tf.onnx")
2023-06-21 16:50:08.883734: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-06-21 16:50:08.883960: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-06-21 16:50:13.670424: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-06-21 16:50:13.671135: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session

将 ONNX 模型导入 TVM Relay:

import set_env
from tvm.relay.frontend import from_onnx

# Static shape for the graph input "data": NCHW with batch size 1.
shape_dict = {"data": [1, 3, 299, 299]}

# Parse the exported ONNX file and import it into Relay. With
# freeze_params=True the ONNX initializers are embedded as constants.
graph_def = onnx.load("/tmp/resnet_v2_50_tf.onnx")
mod, params = from_onnx(graph_def, shape=shape_dict, freeze_params=True)

编译并推理:

import tvm
from tvm import relay

# Import graph_executor explicitly instead of relying on `from tvm import relay`
# having loaded tvm.contrib.graph_executor as a side effect.
from tvm.contrib import graph_executor

# Compile the ONNX-derived Relay module for the host CPU.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, "llvm", params=params)

# The ONNX graph input is named "data" (the TensorSpec name used at export).
inputs_dict = {"data": np_processed_images}
mlib_proxy = graph_executor.GraphModule(lib["default"](tvm.cpu()))
# set_input(name, value) raises if `name` is not a graph input, whereas
# run(**inputs_dict) silently skips unknown names and would leave the input unset.
for input_name, input_value in inputs_dict.items():
    mlib_proxy.set_input(input_name, input_value)
mlib_proxy.run()
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.

验证一致性:

# The TVM (ONNX) result must match the reloaded TF SavedModel output.
np.testing.assert_allclose(
    labeling['output_1'].numpy(),
    mlib_proxy.get_output(0).numpy(),
    rtol=1e-07,
    atol=1e-5,
)

转换为 TFLite 模型

import tensorflow as tf

# Convert the exported SavedModel to a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model(module_with_signature_path)
tflite_model = converter.convert()

# Save the model. The relative "temp" directory is not guaranteed to exist,
# so create it first instead of failing with FileNotFoundError.
os.makedirs('temp', exist_ok=True)
with open('temp/resnet_v2_50.tflite', 'wb') as f:
    f.write(tflite_model)
2023-06-21 16:52:14.629868: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2023-06-21 16:52:14.630049: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2023-06-21 16:52:14.649710: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/resnet_v2_50_keras
2023-06-21 16:52:14.679437: I tensorflow/cc/saved_model/reader.cc:89] Reading meta graph with tags { serve }
2023-06-21 16:52:14.679522: I tensorflow/cc/saved_model/reader.cc:130] Reading SavedModel debug info (if present) from: /tmp/resnet_v2_50_keras
2023-06-21 16:52:14.765580: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:353] MLIR V1 optimization pass is not enabled
2023-06-21 16:52:14.785594: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2023-06-21 16:52:15.679311: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: /tmp/resnet_v2_50_keras
2023-06-21 16:52:15.933247: I tensorflow/cc/saved_model/loader.cc:314] SavedModel load for tags { serve }; Status: success: OK. Took 1283554 microseconds.
2023-06-21 16:52:16.844598: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-06-21 16:52:18.902510: I tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2116] Estimated count of arithmetic ops: 13.119 G  ops, equivalently 6.559 G  MACs

加载 TFLite 模型:

import tflite


# Read the serialized flatbuffer and parse it with the `tflite` package.
with open('temp/resnet_v2_50.tflite', "rb") as fp:
    tflite_model_buf = fp.read()
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

# NOTE(review): the TFLite input tensor is presumably named after the
# SavedModel signature (e.g. "serving_default_data:0"). If so, the "data"
# keys below do not match any input and the frontend ignores them; confirm.
mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict=shape_dict,
    dtype_dict={"data": "float32"},
)

# Target layouts for ConvertLayout: move conv/pool ops from NHWC to NCHW.
# 'image.resize2d' is intentionally left out.
desired_layouts = {
    op_name: ['NCHW', 'default']
    for op_name in ('nn.conv2d', 'nn.max_pool2d', 'nn.avg_pool2d')
}
# Remove unused functions, then convert the NHWC layout to NCHW.
seq = tvm.transform.Sequential(
    [
        relay.transform.RemoveUnusedFunctions(),
        relay.transform.ConvertLayout(desired_layouts),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)

验证结果一致性:

from tvm.contrib import graph_executor

# Compile the TFLite-derived Relay module for the host CPU.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, "llvm", params=params)

# The graph inputs are the Relay function parameters that are not weights.
# The TFLite frontend names them after the TFLite input tensors, which for a
# SavedModel-derived model are presumably not literally "data" (typically
# something like "serving_default_data:0"). Look the name up instead of
# hard-coding it.
graph_input_names = [
    var.name_hint for var in mod["main"].params if var.name_hint not in params
]
if len(graph_input_names) != 1:
    raise RuntimeError(f"expected exactly one graph input, found {graph_input_names}")
inputs_dict = {graph_input_names[0]: np_processed_images}
mlib_proxy = graph_executor.GraphModule(lib["default"](tvm.cpu()))
# set_input(name, value) raises if `name` is not a graph input, whereas
# run(**inputs_dict) silently skips unknown names and leaves the input unset.
for input_name, input_value in inputs_dict.items():
    mlib_proxy.set_input(input_name, input_value)
mlib_proxy.run()
np.testing.assert_allclose(
    labeling['output_1'].numpy(),
    mlib_proxy.get_output(0).numpy(),
    rtol=1e-07, atol=1e-5
)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[18], line 7
      5 mlib_proxy = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.cpu()))
      6 mlib_proxy.run(**inputs_dict)
----> 7 np.testing.assert_allclose(
      8     labeling['output_1'].numpy(), 
      9     mlib_proxy.get_output(0).numpy(),
     10     rtol=1e-07, atol=1e-5
     11 )

    [... skipping hidden 1 frame]

File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/numpy/testing/_private/utils.py:844, in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf)
    840         err_msg += '\n' + '\n'.join(remarks)
    841         msg = build_err_msg([ox, oy], err_msg,
    842                             verbose=verbose, header=header,
    843                             names=('x', 'y'), precision=precision)
--> 844         raise AssertionError(msg)
    845 except ValueError:
    846     import traceback

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-05

Mismatched elements: 978 / 1001 (97.7%)
Max absolute difference: 0.92047846
Max relative difference: 3068.9646
 x: array([[3.429268e-05, 1.693668e-05, 3.029113e-05, ..., 1.208637e-05,
        9.920573e-06, 2.882769e-05]], dtype=float32)
 y: array([[1.464797e-04, 1.271962e-04, 3.613982e-04, ..., 5.739641e-05,
        8.909408e-05, 1.503596e-03]], dtype=float32)

警告

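TFLite 转换出现了问题(TVM 输出与 TensorFlow 输出不一致,见上方断言错误),暂时搁置。一个待确认的可能原因:TFLite 前端按 TFLite 张量名为图输入命名(从 SavedModel 转换时通常形如 `serving_default_data:0`,而非 `data`),而 `GraphModule.run(**inputs_dict)` 会静默跳过不存在的输入名,导致实际输入未被设置。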