Upgrading TF1 to TF2#

import tensorflow as tf
try:
    tf1 = tf.compat.v1
except (ImportError, AttributeError):
    tf1 = tf

tf.get_logger().setLevel('ERROR')
temp_dir = ".temp" # cache directory
# log_dir = "/media/pc/data/lxw/ai/tasks/logs/tf1-tf2"
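
Before touching checkpoints, a quick sanity check of the runtime can help confirm that the compat alias and eager mode behave as expected; a minimal sketch:

print("TensorFlow version:", tf.__version__)
print("Eager execution enabled:", tf.executing_eagerly())  # True by default in TF2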

Checkpoint migration#

Define a helper function:

def print_checkpoint(save_path, print_value=True, tab_size=10):
    """打印检查点信息"""
    shape_size = max(tab_size, 20)
    dtype_size = max(tab_size, 10)
    reader = tf.train.load_checkpoint(save_path)
    shapes = reader.get_variable_to_shape_map()
    dtypes = reader.get_variable_to_dtype_map()
    print(f"检查点: {save_path}")
    tt = "key".ljust(tab_size)
    tt += "\tshape".ljust(shape_size)
    tt += "\tdtype".ljust(dtype_size)
    tt += "\tvalue".ljust(tab_size)
    print(tt)
    print("="*tab_size*7)
    for key in shapes:
        tt = f"{key}".ljust(tab_size)
        tt += f"\t{shapes[key]}".ljust(shape_size)
        tt += f"\t{dtypes[key].name}".ljust(dtype_size)
        if print_value:
            tt += f"\t{reader.get_tensor(key)}".ljust(max(tab_size, 10))
        print(tt)

Let's start with a TF1 example:

with tf.Graph().as_default() as g:
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                         initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.uint8, 
                         initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.uint8, 
                         initializer=tf1.zeros_initializer())
    with tf1.Session() as sess:
        saver = tf1.train.Saver()
        sess.run(a.assign(1))
        sess.run(b.assign(2))
        sess.run(c.assign(3))
        saver.save(sess, f'{temp_dir}/tf1-ckpt')
print_checkpoint(f'{temp_dir}/tf1-ckpt')
Checkpoint: .temp/tf1-ckpt
key       	shape              	dtype    	value    
======================================================================
scoped/c  	[]                 	uint8    	3        
b         	[]                 	uint8    	2        
a         	[]                 	float32  	1.0      

Now a TF2 example:

a = tf.Variable(5.0, name='a')
b = tf.Variable(6.0, name='b')
with tf.name_scope('scoped2'):
    c = tf.Variable(7.0, name='c')
ckpt = tf.train.Checkpoint(variables=[a, b, c])
save_path_v2 = ckpt.save(f'{temp_dir}/tf2-ckpt')
print_checkpoint(save_path_v2, tab_size=32)
Checkpoint: .temp/tf2-ckpt-1
key                             	shape                          	dtype                          	value                          
================================================================================================================================================================================================================================
variables/2/.ATTRIBUTES/VARIABLE_VALUE	[]                             	float32                        	7.0                            
variables/1/.ATTRIBUTES/VARIABLE_VALUE	[]                             	float32                        	6.0                            
variables/0/.ATTRIBUTES/VARIABLE_VALUE	[]                             	float32                        	5.0                            
save_counter/.ATTRIBUTES/VARIABLE_VALUE	[]                             	int64                          	1                              
_CHECKPOINTABLE_OBJECT_GRAPH    	[]                             	string                         	b"\n%\n\r\x08\x01\x12\tvariables\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n\x19\n\x05\x08\x03\x12\x010\n\x05\x08\x04\x12\x011\n\x05\x08\x05\x12\x012*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nA\x12;\n\x0eVARIABLE_VALUE\x12\x01a\x1a&variables/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nA\x12;\n\x0eVARIABLE_VALUE\x12\x01b\x1a&variables/1/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nI\x12C\n\x0eVARIABLE_VALUE\x12\tscoped2/c\x1a&variables/2/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01"

Note

The keys in a name-based checkpoint are the variable names. The keys in an object-based checkpoint point to the path from the root object to the variable.

Looking at the keys in tf2-ckpt, you can see that they all point to each variable's object path.
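
To compare the two key layouts without the full helper above, tf.train.list_variables can list just the keys and shapes; a minimal sketch, assuming the two checkpoints saved earlier still exist under temp_dir:

for path in [f'{temp_dir}/tf1-ckpt', save_path_v2]:
    print(path)
    for name, shape in tf.train.list_variables(path):
        print(f"  {name}  {shape}")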

Study the printed information below carefully:

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
root = ckpt = tf.train.Checkpoint(variables=[a, b, c])
print("root type =", type(root).__name__)
print("root.variables =", root.variables)
print("root.variables[0] =", root.variables[0])
root type = Checkpoint
root.variables = ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>])
root.variables[0] = <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>

Try the following snippet to see how the checkpoint keys change with the object structure:

module = tf.Module()
module.d = tf.Variable(0.)
test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b}, 
                                c=c,
                                module=module)
test_ckpt_path = test_ckpt.save(f'{temp_dir}/root-tf2-ckpt')
print_checkpoint(test_ckpt_path, tab_size=25)
Checkpoint: .temp/root-tf2-ckpt-1
key                      	shape                   	dtype                   	value                   
===============================================================================================================================================================================
v/b/.ATTRIBUTES/VARIABLE_VALUE	[]                      	float32                 	0.0                     
v/a/.ATTRIBUTES/VARIABLE_VALUE	[]                      	float32                 	0.0                     
module/d/.ATTRIBUTES/VARIABLE_VALUE	[]                      	float32                 	0.0                     
save_counter/.ATTRIBUTES/VARIABLE_VALUE	[]                      	int64                   	1                       
c/.ATTRIBUTES/VARIABLE_VALUE	[]                      	float32                 	0.0                     
_CHECKPOINTABLE_OBJECT_GRAPH	[]                      	string                  	b"\n0\n\x05\x08\x01\x12\x01c\n\n\x08\x02\x12\x06module\n\x05\x08\x03\x12\x01v\n\x10\x08\x04\x12\x0csave_counter*\x02\x08\x01\n>\x128\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1cc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n\x0b\n\x05\x08\x05\x12\x01d*\x02\x08\x01\n\x12\n\x05\x08\x06\x12\x01a\n\x05\x08\x07\x12\x01b*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nE\x12?\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a#module/d/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01"
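
Restoring an object-based checkpoint only requires rebuilding the same object structure under the same attribute names; a minimal sketch that restores root-tf2-ckpt into fresh variables (the new_module and restored_ckpt names are illustrative):

new_module = tf.Module()
new_module.d = tf.Variable(-1.)
restored_ckpt = tf.train.Checkpoint(v={'a': tf.Variable(-1.), 'b': tf.Variable(-1.)},
                                    c=tf.Variable(-1.),
                                    module=new_module)
restored_ckpt.restore(test_ckpt_path)
print(restored_ckpt.c.numpy(), new_module.d.numpy())  # both values come back as 0.0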

Below is an example of using the same checkpoint with different models.

  1. Save a TF1 checkpoint with tf1.train.Saver:

with tf.Graph().as_default() as g:
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                         initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.uint8, 
                         initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.uint8, 
                         initializer=tf1.zeros_initializer())
    with tf1.Session() as sess:
        saver = tf1.train.Saver()
        sess.run(a.assign(1))
        sess.run(b.assign(2))
        sess.run(c.assign(3))
        saver.save(sess, f'{temp_dir}/tf1-ckpt')

print_checkpoint(f'{temp_dir}/tf1-ckpt')
Checkpoint: .temp/tf1-ckpt
key       	shape              	dtype    	value    
======================================================================
scoped/c  	[]                 	uint8    	3        
b         	[]                 	uint8    	2        
a         	[]                 	float32  	1.0      
  2. Load the checkpoint in eager mode with tf.compat.v1.train.Saver:

a = tf.Variable(0, name="a", dtype=tf.float32)
b = tf.Variable(0, name="b", dtype=tf.uint8)
with tf.name_scope('scoped'):
    c = tf.Variable(0, name='c', dtype=tf.uint8)

# With collections removed in TF2, the list of variables must be passed to the Saver object explicitly:
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path=f'{temp_dir}/tf1-ckpt')
print(f"加载后的值 [a, b, c]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}]")
# Saving 也可以立即执行(sess 必须为 None)。
path = saver.save(sess=None, save_path=f'{temp_dir}/tf1-ckpt-saved-in-eager')
print_checkpoint(path)
Values after loading [a, b, c]:  [1.0, 2, 3]
Checkpoint: .temp/tf1-ckpt-saved-in-eager
key       	shape              	dtype    	value    
======================================================================
scoped/c  	[]                 	uint8    	3        
b         	[]                 	uint8    	2        
a         	[]                 	float32  	1.0      

  3. Load the checkpoint with the TF2 API tf.train.Checkpoint:

a = tf.Variable(0, name="a", dtype=tf.float32)
b = tf.Variable(0, name="b", dtype=tf.uint8)
with tf.name_scope('scoped'):
    c = tf.Variable(0, name='c', dtype=tf.uint8)

# Without the name_scope, name="scoped/c" works too:
c_2 = tf.Variable(0, name='scoped/c', dtype=tf.uint8)

print("变量名称: ")
print(f"\ta.name = {a.name}")
print(f"\tb.name = {b.name}")
print(f"\tc.name = {c.name}")
print(f"\tc_2.name = {c_2.name}")

# Restore the values with tf.train.Checkpoint
ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])
ckpt.restore(f'{temp_dir}/tf1-ckpt')
print(f"加载后的值 [a, b, c, c_2]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")
Variable names: 
	a.name = a:0
	b.name = b:0
	c.name = scoped/c:0
	c_2.name = scoped/c:0
Values after loading [a, b, c, c_2]:  [1.0, 2, 3, 3]
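
Once the TF1 values have been loaded into TF2 variables, they can be re-saved in the object-based format so that downstream code no longer depends on the old name-based layout; a minimal sketch (the converted-tf2-ckpt prefix is an illustrative name):

converted_path = ckpt.save(f'{temp_dir}/converted-tf2-ckpt')
print_checkpoint(converted_path, tab_size=32)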

SavedModel#

import shutil

def remove_dir(path):
    """Delete a directory tree, ignoring the error if the path does not exist."""
    try:
        shutil.rmtree(path)
    except OSError:
        pass

Define a simple computation:

def add_two(x):
    return x + 2

TensorFlow 1: Save and export a SavedModel#

In TensorFlow 1, you use the tf.compat.v1.saved_model.Builder, tf.compat.v1.saved_model.simple_save, and tf.estimator.Estimator.export_saved_model APIs to build, save, and export the TensorFlow graph and session. The first two are shown below, followed by a sketch of the Estimator path.

  1. Save the graph as a SavedModel with SavedModelBuilder:

model_dir = f"{temp_dir}/saved-model-builder"
remove_dir(model_dir)

with tf.Graph().as_default() as g:
    x = tf1.placeholder(tf.float32, shape=[])
    y = add_two(x)
    with tf1.Session() as sess:
        print(f"结果为 {sess.run(y, {x: 3.})}")

        # 使用 SavedModelBuilder 保持
        builder = tf1.saved_model.Builder(model_dir)
        sig_def = tf1.saved_model.predict_signature_def(
            inputs={'input': x},
            outputs={'output': y}
        )
        builder.add_meta_graph_and_variables(
            sess=sess,
            tags=[tf1.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
        })
        builder.save()
Result: 5.0
  2. Build a SavedModel for serving with simple_save:

remove_dir(f"{temp_dir}/simple-save")

with tf.Graph().as_default() as g:
    x = tf1.placeholder(tf.float32, shape=[])
    y = add_two(x)
    with tf1.Session() as sess:
        print(f"结果为 {sess.run(y, {x: 3.})}")
        tf1.saved_model.simple_save(
            sess, f"{temp_dir}/simple-save",
            inputs={'input': x},
            outputs={'output': y}
        )
Result: 5.0
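
The third API mentioned above, tf.estimator.Estimator.export_saved_model, exports through an Estimator instead of a raw session. The sketch below is only an outline, assuming a TF release where tf.estimator is still available (it is deprecated in recent versions); the model_fn, the dummy training step, and the estimator-model / estimator-export paths are all illustrative:

def model_fn(features, labels, mode):
    # Wrap add_two in an EstimatorSpec; loss and train_op are dummies, present only so one training step can run.
    output = add_two(features['input'])
    step = tf1.train.get_or_create_global_step()
    return tf.estimator.EstimatorSpec(
        mode,
        predictions=output,
        loss=tf.constant(0.),
        train_op=step.assign_add(1),
        export_outputs={
            tf.saved_model.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput({'output': output})})

est = tf.estimator.Estimator(model_fn, f"{temp_dir}/estimator-model")
# export_saved_model requires a checkpoint in model_dir, so run a single dummy training step first.
est.train(lambda: tf.data.Dataset.from_tensor_slices(({'input': [3.]}, [0.])), steps=1)

def serving_input_fn():
    x = tf1.placeholder(dtype=tf.float32, shape=[])
    return tf.estimator.export.ServingInputReceiver({'input': x}, {'input': x})

export_path = est.export_saved_model(f"{temp_dir}/estimator-export", serving_input_fn)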

TensorFlow 2: Save and export a SavedModel#

Save and export a SavedModel defined with tf.Module#

To export a model in TensorFlow 2, you must define a tf.Module or a tf.keras.Model to hold all of the model's variables and functions. Then you can call tf.saved_model.save to create a SavedModel.

class MyModel(tf.Module):
    @tf.function
    def __call__(self, x):
        return add_two(x)
    
model = MyModel()

@tf.function
def serving_default(x):
    return {"output": model(x)}

signature_function = serving_default.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.float32)
)
tf.saved_model.save(
    model, f"{temp_dir}/tf2-save",
    signatures={
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_function
    }
)
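
After saving, it can be worth confirming which signatures were actually exported; a minimal sketch:

reloaded = tf.saved_model.load(f"{temp_dir}/tf2-save")
print(list(reloaded.signatures.keys()))  # expected: ['serving_default']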

Save and export a SavedModel defined with Keras#

The Keras APIs for saving and exporting (Model.save and tf.keras.models.save_model) can export a SavedModel from a tf.keras.Model.

inp = tf.keras.Input(3)
out = add_two(inp)
model = tf.keras.Model(inputs=inp, outputs=out)

@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def serving_default(input):
    return {'output': model(input)}

model.save(
    f"{temp_dir}/keras-model", save_format='tf', 
    signatures={tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: serving_default}
)
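
The Keras export can also be read back with the Keras loader rather than the generic SavedModel loader; a minimal sketch, assuming a TF release whose Keras still reads the TF SavedModel format (note the functional model expects a batch of 3-element inputs):

reloaded_keras = tf.keras.models.load_model(f"{temp_dir}/keras-model")
print(reloaded_keras(tf.constant([[1., 2., 3.]])).numpy())  # each element shifted by 2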

Loading a SavedModel#

TensorFlow 1: Load a SavedModel with tf.saved_model.load#

In TensorFlow 1, you can use tf.saved_model.load to import a SavedModel directly into the current graph and session, and then call Session.run on the input and output tensor names:

def load_tf1(path, x):
    print(f"加载 {path}")
    with tf.Graph().as_default() as g:
        with tf1.Session() as sess:
            meta_graph = tf1.saved_model.load(sess, ["serve"], path)
            sig_def = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
            input_name = sig_def.inputs['input'].name
            output_name = sig_def.outputs['output'].name
            print(x, '=>', sess.run(output_name, feed_dict={input_name: x}))
load_tf1(f'{temp_dir}/saved-model-builder', 5.)
load_tf1(f'{temp_dir}/simple-save', 5.)
load_tf1(f'{temp_dir}/keras-model', 5.)
Loading .temp/saved-model-builder
5.0 => 7.0
Loading .temp/simple-save
5.0 => 7.0
Loading .temp/keras-model
5.0 => 7.0

TensorFlow 2: Load a model saved with tf.saved_model#

In TensorFlow 2, the SavedModel is loaded into a Python object that stores the variables and functions. This is compatible with models saved from TensorFlow 1.

def load_tf2(path, x):
    print(f"加载 {path}")
    loaded = tf.saved_model.load(path)
    sig_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    out = loaded.signatures[sig_key](tf.constant(x))['output']
    print(x, '=>', out)
load_tf2(f'{temp_dir}/saved-model-builder', 5.)
load_tf2(f'{temp_dir}/simple-save', 5.)
load_tf2(f'{temp_dir}/tf2-save', 5.)
load_tf2(f'{temp_dir}/keras-model', 5.)
Loading .temp/saved-model-builder
5.0 => tf.Tensor(7.0, shape=(), dtype=float32)
Loading .temp/simple-save
5.0 => tf.Tensor(7.0, shape=(), dtype=float32)
Loading .temp/tf2-save
5.0 => tf.Tensor(7.0, shape=(), dtype=float32)
Loading .temp/keras-model
5.0 => tf.Tensor(7.0, shape=(), dtype=float32)
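
Because tf2-save was created from a tf.Module whose __call__ is a tf.function, the object returned by tf.saved_model.load can also be called directly, without going through signatures, as long as a matching concrete function was traced at save time; a minimal sketch:

loaded_module = tf.saved_model.load(f"{temp_dir}/tf2-save")
print(loaded_module(tf.constant(5.)))  # restored MyModel.__call__, prints tf.Tensor(7.0, ...)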