1. 前言
在深度学习中,模型的保存与加载是非常重要的功能之一。当我们训练好一个模型后,通常需要将其保存下来以备后续使用。TensorFlow提供了一种保存模型的方式,即使用tf.train.Saver()
对象将模型的变量保存成checkpoint文件格式(或者保存为.pb文件)。
2. tensorflow模型的保存与加载
2.1 模型保存
要将训练好的模型保存下来,我们需要先创建一个tf.train.Saver()
对象,并调用其save()
方法传入相关的session和保存路径。
import tensorflow as tf
# 创建模型
# ...
# 创建Saver对象
saver = tf.train.Saver()
# 保存模型
save_path = saver.save(session, "model.ckpt")
print("Model saved in file: %s" % save_path)
在上述代码中,我们首先创建了一个Saver
对象,然后调用了save()
方法将模型保存在了model.ckpt
文件中。
2.2 模型加载
要加载之前保存的模型,我们同样需要创建一个Saver
对象,并调用其restore()
方法传入相关的session和保存路径。
import tensorflow as tf
# 创建模型
# ...
# 创建Saver对象
saver = tf.train.Saver()
# 加载模型
saver.restore(session, "model.ckpt")
print("Model restored.")
在上述代码中,我们首先创建了一个Saver
对象,然后调用了restore()
方法从model.ckpt
文件中恢复模型。
3. 获取ckpt模型中的变量名称与变量
在TensorFlow中,我们可以使用tf.train.NewCheckpointReader
来读取ckpt模型并获取其中的变量名称与变量。下面是一个示例:
import tensorflow as tf
reader = tf.train.NewCheckpointReader("model.ckpt")
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("Variable name: ", key)
print("Numpy array: ", reader.get_tensor(key))
上述代码中,我们首先创建了一个NewCheckpointReader
对象,并传入模型的路径"model.ckpt"
。然后,我们可以使用get_variable_to_shape_map()
方法来获取ckpt模型中的变量名称与形状。最后,我们可以通过get_tensor()
方法来获取各个变量的值。
4. 示例:打印ckpt模型保存下的变量名称与变量
下面通过一个完整的示例来演示如何打印ckpt模型保存下的变量名称与变量。
import tensorflow as tf
# 创建模型
x = tf.Variable(tf.random_normal([1]), name="x")
y = tf.Variable(tf.random_normal([1]), name="y")
z = tf.add(x, y, name="z")
# 创建Saver对象
saver = tf.train.Saver()
with tf.Session() as sess:
# 初始化变量
sess.run(tf.global_variables_initializer())
print("Variables before saving:")
for var in tf.global_variables():
print(var.name)
# 保存模型
save_path = saver.save(sess, "model.ckpt")
print("Model saved in file: %s" % save_path)
print("Variables after saving:")
reader = tf.train.NewCheckpointReader("model.ckpt")
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("Variable name: ", key)
print("Numpy array: ", reader.get_tensor(key))
在上述代码中,我们首先创建了三个变量x
、y
和z
,然后创建了一个Saver
对象。在with tf.Session() as sess:
中,我们初始化变量并打印保存前的变量名称,接着保存模型,并打印保存后的变量名称和变量值。
5. 小结
通过tf.train.Saver()
对象,我们可以方便地保存和加载TensorFlow模型。而通过tf.train.NewCheckpointReader
对象,我们可以读取ckpt模型中的变量名称和变量值。这些功能在模型的训练和部署过程中都非常实用。
本文基于TensorFlow 1.x版本进行讲解,相关代码可在TensorFlow官方文档中找到。