测试 ImageNet 分类精度

测试 ImageNet 分类精度

import tensorflow as tf
try:
    tf1 = tf.compat.v1
except (ImportError, AttributeError):
    tf1 = tf
%cd /media/pc/data/lxw/ai/tasks/models/research/slim
from nets import resnet_v2
import tf_slim as slim
import numpy as np
from tvm_book.metric.classification import Accuracy, TopKAccuracy
from tvm_book.data.classification import ImageFolderDataset

def preprocessing(
    image,
    use_grayscale=False,
    central_fraction=0.875,
    central_crop=True,
    height=299,
    width=299,
    mean: tuple[float, ...] = (0.485, 0.456, 0.406),
    std: tuple[float, ...] = (1, 1, 1)
):
    """Prepare a single HWC image for the ResNet-v2 classifier.

    Pipeline, applied in this order:

    1. Convert ``image`` to a tensor. Any dtype other than float32 goes
       through ``tf.image.convert_image_dtype`` (integer inputs are rescaled
       to [0, 1] by that op; other float dtypes are only cast).
    2. If ``use_grayscale``, collapse RGB to one grayscale channel.
    3. If ``central_crop`` and ``central_fraction`` are both truthy, keep the
       central ``central_fraction`` of the image.
    4. If ``height`` and ``width`` are both truthy, resize bilinearly (no
       antialiasing, aspect ratio not preserved) to ``height`` x ``width``.
    5. Normalize per channel: ``(image - mean) / std``.

    Args:
        image: A single rank-3, channels-last image (array or tensor). The
            resize step adds one batch axis and so requires rank 3 here.
        use_grayscale: Convert to grayscale before cropping/resizing.
            NOTE(review): the 1-channel result is broadcast against the
            3-element ``mean``/``std``, so the output still has 3 channels.
        central_fraction: Fraction of each spatial side kept by the crop.
        central_crop: Enable the central crop.
        height: Output height (resize skipped if falsy).
        width: Output width (resize skipped if falsy).
        mean: Per-channel value subtracted after resizing.
        std: Per-channel divisor applied after mean subtraction.

    Returns:
        A float32 channels-last tensor.
    """
    img = tf.convert_to_tensor(image)
    if img.dtype != tf.float32:
        img = tf.image.convert_image_dtype(img, dtype=tf.float32)

    if use_grayscale:
        img = tf.image.rgb_to_grayscale(img)

    if central_crop and central_fraction:
        img = tf.image.central_crop(img, central_fraction=central_fraction)

    if height and width:
        # Resize a batch of one, then drop the batch axis again.
        batch = img[tf.newaxis, ...]
        batch = tf.image.resize(
            batch,
            [height, width],
            method='bilinear',
            preserve_aspect_ratio=False,
            antialias=False,
        )
        img = batch[0]

    mean_t = tf.constant(mean, dtype=tf.float32)
    std_t = tf.constant(std, dtype=tf.float32)
    return (img - mean_t) / std_t

class ResnetV2_50(tf.keras.Model):
    """Keras wrapper that builds TF-Slim's ``resnet_v2_50`` inside ``call``.

    Slim creates its variables TF1-style. The ``track_tf1_style_variables``
    decorator on ``call`` lets this Keras model track them, so they can be
    populated by restoring the slim checkpoint through ``tf.train.Checkpoint``.
    """

    def __init__(self, trainable=False, *args, **kwargs):
        """Create the wrapper.

        Args:
            trainable: Stored as the Keras ``trainable`` flag and also passed
                to slim as ``is_training`` in ``call``. NOTE(review): this
                couples weight trainability with slim's training mode
                (presumably batch-norm/dropout behaviour); keep False for
                evaluation. Because it precedes ``*args``, the first
                positional argument binds to it.
            *args: Forwarded to ``tf.keras.Model.__init__``.
            **kwargs: Forwarded to ``tf.keras.Model.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.trainable = trainable

    @tf.function(input_signature=[tf.TensorSpec([1, 3, 299, 299], 
                                                 tf.float32, name="data")])
    @tf1.keras.utils.track_tf1_style_variables
    def call(self, x):
        """Run ResNet-v2-50 on one NCHW image and return softmax probabilities.

        The ``tf.function`` input signature pins the input to a float32 tensor
        of shape (1, 3, 299, 299) named "data". Inputs of any other shape
        (including other batch sizes) are rejected.

        Args:
            x: float32 tensor of shape (1, 3, 299, 299), NCHW layout.

        Returns:
            ``tf.nn.softmax`` of the slim logits (``num_classes=1001``);
            expected shape (1, 1001) — TODO confirm slim's spatial squeeze.
            NOTE(review): index 0 is presumably a background class, which is
            why the evaluation loop shifts ImageNet labels by +1.
        """
        # x = tf.convert_to_tensor(x) # ensure the input is a tensor
        x = tf.transpose(x, perm=(0, 2, 3, 1)) # NCHW -> NHWC
        # resnet_arg_scope supplies slim's default layer arguments for ResNet
        # (presumably weight decay / batch-norm settings — see slim source).
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            # scope="resnet_v2_50" fixes the variable-name prefix; the
            # name-based checkpoint restore matches variables by these names.
            logits, end_points = resnet_v2.resnet_v2_50(
                x, 
                num_classes=1001,
                global_pool=True,
                is_training=self.trainable,
                scope="resnet_v2_50"
            )
        # end_points (presumably intermediate activations) is not used.
        del end_points
        return tf.nn.softmax(logits)
/media/pc/data/lxw/ai/tasks/models/research/slim
WARNING:tensorflow:From /tmp/ipykernel_2381233/575585929.py:50: The name tf.keras.utils.track_tf1_style_variables is deprecated. Please use tf.compat.v1.keras.utils.track_tf1_style_variables instead.
2023-06-21 16:03:37.608487: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-21 16:03:37.815194: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-21 16:03:37.817549: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-21 16:03:40.697419: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
# Instantiate the wrapper and call it once on a dummy NCHW input. That first
# call builds the slim network and thereby creates its variables; it runs
# before the checkpoint restore below.
model = ResnetV2_50()
model(tf.ones([1, 3, 299, 299], tf.float32))

# Load the pretrained slim weights. This is a name-based (tf.train.Saver)
# checkpoint, so variables are matched by name. NOTE(review): the load status
# returned by restore() is not checked.
ckpt_path = "/media/pc/data/board/arria10/lxw/tests/npu_user_demos/models/resnet50_v2_tf/weight/resnet_v2_50.ckpt"
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(ckpt_path)

# ILSVRC validation images, read via ImageFolderDataset.
root = "/media/pc/data/lxw/home/data/datasets/ILSVRC/val"
valset = ImageFolderDataset(root)
WARNING:tensorflow:From /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/checkpoint/checkpoint.py:1426: NameBasedSaverStatus.__init__ (from tensorflow.python.checkpoint.checkpoint) is deprecated and will be removed in a future version.
Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
2023-06-21 16:03:52.183666: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
  warnings.warn('`layer.apply` is deprecated and '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`layer.updates` will be removed in a future version. '
/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
from tqdm import tqdm

# Top-1 and top-5 accuracy, accumulated over the whole validation set.
metric = Accuracy()
top5_metric = TopKAccuracy(top_k=5)
# Wrap the dataset itself (not enumerate(...)) so tqdm can take len(valset)
# and show a real total/ETA; an enumerate object has no __len__.
for k, (image, label_id) in enumerate(tqdm(valset)):
    processed_image = preprocessing(
        image,
        use_grayscale=False,
        central_fraction=0.875,
        central_crop=True,
        height=299,
        width=299,
        mean=(0.485, 0.456, 0.406),
        std=(1, 1, 1)
    )
    # HWC tensor -> (1, C, H, W) numpy batch, matching the model's NCHW input.
    np_processed_images = np.expand_dims(processed_image.numpy(), axis=0)
    np_processed_images = np_processed_images.transpose(0, 3, 1, 2)
    outputs = model(np_processed_images).numpy()
    # The model emits 1001 classes; the +1 shift presumably skips a
    # background class at index 0.
    labels = np.array([label_id + 1])
    metric.update(labels=labels, preds=outputs)
    top5_metric.update(labels=labels, preds=outputs)
    if k % 5000 == 0:
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"{k+1}: {metric} {top5_metric}")
1: Accuracy: {'Accuracy': 0.0} TopKAccuracy: {'top_5_accuracy': 1.0}
1001: Accuracy: {'Accuracy': 0.8861138861138861} TopKAccuracy: {'top_5_accuracy': 0.967032967032967}
2001: Accuracy: {'Accuracy': 0.8275862068965517} TopKAccuracy: {'top_5_accuracy': 0.9545227386306847}
3001: Accuracy: {'Accuracy': 0.7944018660446518} TopKAccuracy: {'top_5_accuracy': 0.9456847717427525}
4001: Accuracy: {'Accuracy': 0.773056735816046} TopKAccuracy: {'top_5_accuracy': 0.9380154961259685}
5001: Accuracy: {'Accuracy': 0.7966406718656269} TopKAccuracy: {'top_5_accuracy': 0.9444111177764447}
6001: Accuracy: {'Accuracy': 0.7998666888851858} TopKAccuracy: {'top_5_accuracy': 0.9418430261623063}
7001: Accuracy: {'Accuracy': 0.8070275674903585} TopKAccuracy: {'top_5_accuracy': 0.9448650192829596}
8001: Accuracy: {'Accuracy': 0.8115235595550556} TopKAccuracy: {'top_5_accuracy': 0.9472565929258843}
9001: Accuracy: {'Accuracy': 0.8029107876902566} TopKAccuracy: {'top_5_accuracy': 0.9463392956338185}
10001: Accuracy: {'Accuracy': 0.797920207979202} TopKAccuracy: {'top_5_accuracy': 0.9464053594640536}
11001: Accuracy: {'Accuracy': 0.7976547586583038} TopKAccuracy: {'top_5_accuracy': 0.9478229251886192}
12001: Accuracy: {'Accuracy': 0.7937671860678277} TopKAccuracy: {'top_5_accuracy': 0.9466711107407716}
13001: Accuracy: {'Accuracy': 0.7934774248134759} TopKAccuracy: {'top_5_accuracy': 0.9490808399353896}
14001: Accuracy: {'Accuracy': 0.7920862795514606} TopKAccuracy: {'top_5_accuracy': 0.9500035711734877}
15001: Accuracy: {'Accuracy': 0.7920805279648023} TopKAccuracy: {'top_5_accuracy': 0.9502699820011999}
16001: Accuracy: {'Accuracy': 0.79063808511968} TopKAccuracy: {'top_5_accuracy': 0.9493781638647585}
17001: Accuracy: {'Accuracy': 0.7963649197106053} TopKAccuracy: {'top_5_accuracy': 0.9503558614199165}
18001: Accuracy: {'Accuracy': 0.7944558635631354} TopKAccuracy: {'top_5_accuracy': 0.9509471696016888}
19001: Accuracy: {'Accuracy': 0.7964843955581286} TopKAccuracy: {'top_5_accuracy': 0.9505289195305511}
20001: Accuracy: {'Accuracy': 0.7956102194890255} TopKAccuracy: {'top_5_accuracy': 0.9500524973751312}
21001: Accuracy: {'Accuracy': 0.7917718203895052} TopKAccuracy: {'top_5_accuracy': 0.9486691109947145}
22001: Accuracy: {'Accuracy': 0.7905095222944412} TopKAccuracy: {'top_5_accuracy': 0.9483659833643925}
23001: Accuracy: {'Accuracy': 0.7878353115081953} TopKAccuracy: {'top_5_accuracy': 0.9470023042476414}
24001: Accuracy: {'Accuracy': 0.783092371151202} TopKAccuracy: {'top_5_accuracy': 0.9447106370567893}
25001: Accuracy: {'Accuracy': 0.7784488620455182} TopKAccuracy: {'top_5_accuracy': 0.9422023119075237}
26001: Accuracy: {'Accuracy': 0.7746240529210415} TopKAccuracy: {'top_5_accuracy': 0.9405792084919811}
27001: Accuracy: {'Accuracy': 0.7722306581237732} TopKAccuracy: {'top_5_accuracy': 0.939705936817155}
28001: Accuracy: {'Accuracy': 0.7715795864433413} TopKAccuracy: {'top_5_accuracy': 0.93853790936038}
29001: Accuracy: {'Accuracy': 0.7730767904555015} TopKAccuracy: {'top_5_accuracy': 0.9387952139581394}
30001: Accuracy: {'Accuracy': 0.7701743275224159} TopKAccuracy: {'top_5_accuracy': 0.9371687610412986}
31001: Accuracy: {'Accuracy': 0.7701687042353472} TopKAccuracy: {'top_5_accuracy': 0.9364536627850715}
32001: Accuracy: {'Accuracy': 0.7649135964501109} TopKAccuracy: {'top_5_accuracy': 0.9345332958345052}
33001: Accuracy: {'Accuracy': 0.763219296384958} TopKAccuracy: {'top_5_accuracy': 0.9330020302415079}
34001: Accuracy: {'Accuracy': 0.7620364106938031} TopKAccuracy: {'top_5_accuracy': 0.9323549307373312}
35001: Accuracy: {'Accuracy': 0.7606925516413817} TopKAccuracy: {'top_5_accuracy': 0.931830519128025}
36001: Accuracy: {'Accuracy': 0.7600344434876809} TopKAccuracy: {'top_5_accuracy': 0.9317241187744785}
37001: Accuracy: {'Accuracy': 0.7596821707521418} TopKAccuracy: {'top_5_accuracy': 0.9308667333315316}
38001: Accuracy: {'Accuracy': 0.7573484908291888} TopKAccuracy: {'top_5_accuracy': 0.9299492118628457}
39001: Accuracy: {'Accuracy': 0.7569036691366888} TopKAccuracy: {'top_5_accuracy': 0.9290274608343376}
40001: Accuracy: {'Accuracy': 0.755331116722082} TopKAccuracy: {'top_5_accuracy': 0.9278518037049074}
41001: Accuracy: {'Accuracy': 0.7542742859930246} TopKAccuracy: {'top_5_accuracy': 0.9273432355308407}
42001: Accuracy: {'Accuracy': 0.7525058927168401} TopKAccuracy: {'top_5_accuracy': 0.926287469345968}
43001: Accuracy: {'Accuracy': 0.7511918327480757} TopKAccuracy: {'top_5_accuracy': 0.9253970837887491}
44001: Accuracy: {'Accuracy': 0.7500738619576828} TopKAccuracy: {'top_5_accuracy': 0.9252971523374469}
45001: Accuracy: {'Accuracy': 0.7486500299993334} TopKAccuracy: {'top_5_accuracy': 0.9244016799626675}
46001: Accuracy: {'Accuracy': 0.7479402621682137} TopKAccuracy: {'top_5_accuracy': 0.9242190387165496}
47001: Accuracy: {'Accuracy': 0.7476436671560179} TopKAccuracy: {'top_5_accuracy': 0.9243420352758452}
48001: Accuracy: {'Accuracy': 0.7482135788837733} TopKAccuracy: {'top_5_accuracy': 0.924876565071561}
49001: Accuracy: {'Accuracy': 0.7459643680741209} TopKAccuracy: {'top_5_accuracy': 0.9240627742290973}
3it [00:00,  7.35it/s]1003it [01:04, 15.83it/s]2003it [02:08, 15.98it/s]3003it [03:16, 14.74it/s]4003it [04:24, 14.77it/s]5003it [05:40, 14.19it/s]6002it [06:54, 14.20it/s]7002it [08:07, 13.61it/s]8003it [09:18, 14.59it/s]9003it [10:35, 10.82it/s]10002it [11:50, 13.20it/s]11002it [13:06, 12.81it/s]12002it [14:33,  9.03it/s]13003it [16:00, 11.30it/s]14003it [17:32,  9.97it/s]15003it [19:26,  9.92it/s]16002it [21:14,  8.43it/s]17002it [23:00,  9.31it/s]18002it [24:42,  9.14it/s]19002it [26:33,  5.32it/s]20003it [28:27, 13.36it/s]21003it [29:46, 12.64it/s]22003it [31:06, 10.94it/s]23002it [32:29, 11.73it/s]24002it [34:01, 10.85it/s]25004it [35:21, 12.91it/s]26004it [36:39, 12.72it/s]27002it [38:02, 12.00it/s]28002it [39:24, 11.25it/s]29002it [40:43, 12.29it/s]30002it [42:03, 12.84it/s]31002it [43:28, 12.28it/s]32002it [44:46, 12.47it/s]33002it [46:19, 11.31it/s]34002it [48:07, 10.74it/s]35001it [49:39, 12.81it/s]36002it [51:19,  6.74it/s]37003it [52:46, 11.11it/s]38002it [54:22,  8.04it/s]39002it [55:55, 11.22it/s]40002it [57:19, 11.78it/s]41002it [58:41, 11.34it/s]42002it [1:00:10, 11.12it/s]43003it [1:02:12,  8.84it/s]44002it [1:04:22,  6.99it/s]45002it [1:06:24,  8.11it/s]46002it [1:08:32,  8.41it/s]47002it [1:10:45,  8.86it/s]48002it [1:12:56,  7.35it/s]49002it [1:14:58,  8.05it/s]50000it [1:17:01, 10.82it/s]
# Final top-1 / top-5 accuracy, accumulated over every validation sample.
print(f"{metric} {top5_metric}")
Accuracy: {'Accuracy': 0.74754} TopKAccuracy: {'top_5_accuracy': 0.92474}