1. 背景介绍
TensorFlow作为目前比较流行的深度学习框架之一,广泛应用于工业界和学术界。在TensorFlow中,模型保存为pb文件格式是非常常见且重要的步骤。在实际的项目中,我们经常需要使用TensorFlow训练好的模型文件进行推理或者转移到其他平台上使用。因此,学会如何保存TensorFlow模型是至关重要的。
2. 模型保存方式
2.1 使用tf.train.Saver保存checkpoint
在TensorFlow中,最基本的保存方式就是保存checkpoint,使用tf.train.Saver()可以将模型的参数以二进制的形式保存在checkpoint文件夹中。代码示例如下:
import tensorflow as tf
# 定义模型
x = tf.placeholder(tf.float32, [None, 784], name='x')
y = tf.placeholder(tf.float32, [None, 10], name='y')
W = tf.Variable(tf.zeros([784, 10]), name='W')
b = tf.Variable(tf.zeros([10]), name='b')
y_pred = tf.nn.softmax(tf.matmul(x, W) + b, name='y_pred')
# 定义损失函数和优化器
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 创建Saver对象
saver = tf.train.Saver()
with tf.Session() as sess:
# 训练模型
sess.run(tf.global_variables_initializer())
for i in range(1000):
batch_xs, batch_ys = ... # 准备数据
sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
# 保存checkpoint文件
saver.save(sess, 'model.ckpt')
这种方式保存的模型包含了模型的参数值,但是不包含模型的计算图结构,因此需要在恢复模型时重新创建计算图。
2.2 使用tf.saved_model保存pb文件
TensorFlow提供了一种更加通用的保存模型的方式——saved model。使用tf.saved_model.builder.SavedModelBuilder可以保存完整的模型计算图和参数值,并将其存储在一个pb文件夹中,以供后续使用。代码示例如下:
import tensorflow as tf
# 定义模型
x = tf.placeholder(tf.float32, [None, 784], name='x')
y = tf.placeholder(tf.float32, [None, 10], name='y')
W = tf.Variable(tf.zeros([784, 10]), name='W')
b = tf.Variable(tf.zeros([10]), name='b')
y_pred = tf.nn.softmax(tf.matmul(x, W) + b, name='y_pred')
# 定义损失函数和优化器
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 创建SavedModelBuilder对象
builder = tf.saved_model.builder.SavedModelBuilder('saved_model')
with tf.Session() as sess:
# 训练模型
sess.run(tf.global_variables_initializer())
for i in range(1000):
batch_xs, batch_ys = ... # 准备数据
sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
# 保存模型为saved model格式
# 定义输入和输出的tensor名字到signature map
inputs = {'x': tf.saved_model.utils.build_tensor_info(x)}
outputs = {'y_pred': tf.saved_model.utils.build_tensor_info(y_pred)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs=outputs,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
builder.add_meta_graph_and_variables(
sess,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={'predict': signature})
builder.save()
这种方式保存的模型不仅包含了模型的参数值,还包含了完整的模型计算图结构及其输入输出张量的相关信息,使得后续恢复模型更加方便。
2.3 使用freeze_graph.py保存pb文件
除了使用TensorFlow自带的保存方式,我们还可以使用外部工具来保存TensorFlow模型。其中,比较常见的工具是freeze_graph.py。该工具可以将ckpt文件和meta文件一起转换为包含计算图和参数的pb文件。freeze_graph.py的使用方法如下:
python freeze_graph.py \
--input_checkpoint=model.ckpt \
--output_graph=model.pb \
--output_node_names=y_pred \
--input_binary=true
其中,input_checkpoint指定ckpt模型文件的路径,output_graph指定输出的pb文件的路径,output_node_names指定用于推理的模型输出节点名称,input_binary指定是否将ckpt文件转化为二进制形式。
3. 模型恢复方式
3.1 恢复checkpoint文件
对于使用tf.train.Saver保存的checkpoint文件,在恢复模型时需要首先恢复模型参数并重新创建计算图。具体步骤如下:
import tensorflow as tf
# 创建Saver对象
saver = tf.train.Saver()
with tf.Session() as sess:
# 恢复模型参数
saver.restore(sess, 'model.ckpt')
# 创建计算图
x = tf.placeholder(tf.float32, [None, 784], name='x')
y = tf.placeholder(tf.float32, [None, 10], name='y')
W = tf.Variable(tf.zeros([784, 10]), name='W')
b = tf.Variable(tf.zeros([10]), name='b')
y_pred = tf.nn.softmax(tf.matmul(x, W) + b, name='y_pred')
# 调用模型进行推理
y_pred_value = sess.run(y_pred, feed_dict={x: batch_xs})
这种方式需要提前定义计算图结构,适用于使用TensorFlow自带的保存方式保存的模型。
3.2 恢复saved model格式的模型
使用tf.saved_model.loader.load函数可以轻松地恢复saved model格式的模型。具体步骤如下:
import tensorflow as tf
# 加载模型
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], 'saved_model')
signature = meta_graph_def.signature_def['predict']
x_tensor_name = signature.inputs['x'].name
y_pred_tensor_name = signature.outputs['y_pred'].name
x = sess.graph.get_tensor_by_name(x_tensor_name)
y_pred = sess.graph.get_tensor_by_name(y_pred_tensor_name)
# 调用模型进行推理
y_pred_value = sess.run(y_pred, feed_dict={x: batch_xs})
这种方式可以直接获取模型的输入输出张量,不需要提前定义计算图结构,更加便捷。
3.3 恢复freeze_graph.py生成的pb文件
加载pb文件的方法与加载saved model格式的方法类似,不过需要使用tf.gfile.GFile读取文件内容。具体步骤如下:
import tensorflow as tf
# 读取pb文件
with tf.gfile.GFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 创建计算图,并导入pb文件
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')
x = graph.get_tensor_by_name('x:0')
y_pred = graph.get_tensor_by_name('y_pred:0')
with tf.Session(graph=graph) as sess:
# 调用模型进行推理
y_pred_value = sess.run(y_pred, feed_dict={x: batch_xs})
这种方式也是适用于TensorFlow自带以外的保存方式的,比较灵活,但需要提前知道输入输出张量的名称。
4. 总结
TensorFlow提供了多种保存和恢复模型的方式,可以根据具体的需求选择不同的方式。在实际的项目中,为了保证模型的正确性和可复用性,建议使用saved model格式的方式来保存模型。在使用TensorFlow模型时,需要根据具体的模型类型和保存方式选择恰当的加载方法,并注意输入输出张量的名称和类型,以免出现错误。