浅谈tensorflow模型保存为pb的各种姿势

1. 背景介绍

TensorFlow作为目前比较流行的深度学习框架之一,广泛应用于工业界和学术界。在TensorFlow中,模型保存为pb文件格式是非常常见且重要的步骤。在实际的项目中,我们经常需要使用TensorFlow训练好的模型文件进行推理或者转移到其他平台上使用。因此,学会如何保存TensorFlow模型是至关重要的。

2. 模型保存方式

2.1 使用tf.train.Saver保存checkpoint

在TensorFlow中,最基本的保存方式就是保存checkpoint,使用tf.train.Saver()可以将模型的参数以二进制的形式保存在checkpoint文件夹中。代码示例如下:

import tensorflow as tf

# 定义模型

x = tf.placeholder(tf.float32, [None, 784], name='x')

y = tf.placeholder(tf.float32, [None, 10], name='y')

W = tf.Variable(tf.zeros([784, 10]), name='W')

b = tf.Variable(tf.zeros([10]), name='b')

y_pred = tf.nn.softmax(tf.matmul(x, W) + b, name='y_pred')

# 定义损失函数和优化器

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# 创建Saver对象

saver = tf.train.Saver()

with tf.Session() as sess:

# 训练模型

sess.run(tf.global_variables_initializer())

for i in range(1000):

batch_xs, batch_ys = ... # 准备数据

sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})

# 保存checkpoint文件

saver.save(sess, 'model.ckpt')

这种方式保存的模型包含了模型的参数值,但是不包含模型的计算图结构,因此需要在恢复模型时重新创建计算图。

2.2 使用tf.saved_model保存pb文件

TensorFlow提供了一种更加通用的保存模型的方式——saved model。使用tf.saved_model.builder.SavedModelBuilder可以保存完整的模型计算图和参数值,并将其存储在一个pb文件夹中,以供后续使用。代码示例如下:

import tensorflow as tf

# 定义模型

x = tf.placeholder(tf.float32, [None, 784], name='x')

y = tf.placeholder(tf.float32, [None, 10], name='y')

W = tf.Variable(tf.zeros([784, 10]), name='W')

b = tf.Variable(tf.zeros([10]), name='b')

y_pred = tf.nn.softmax(tf.matmul(x, W) + b, name='y_pred')

# 定义损失函数和优化器

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# 创建SavedModelBuilder对象

builder = tf.saved_model.builder.SavedModelBuilder('saved_model')

with tf.Session() as sess:

# 训练模型

sess.run(tf.global_variables_initializer())

for i in range(1000):

batch_xs, batch_ys = ... # 准备数据

sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})

# 保存模型为saved model格式

# 定义输入和输出的tensor名字到signature map

inputs = {'x': tf.saved_model.utils.build_tensor_info(x)}

outputs = {'y_pred': tf.saved_model.utils.build_tensor_info(y_pred)}

signature = tf.saved_model.signature_def_utils.build_signature_def(

inputs=inputs,

outputs=outputs,

method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

builder.add_meta_graph_and_variables(

sess,

[tf.saved_model.tag_constants.SERVING],

signature_def_map={'predict': signature})

builder.save()

这种方式保存的模型不仅包含了模型的参数值,还包含了完整的模型计算图结构及其输入输出张量的相关信息,使得后续恢复模型更加方便。

2.3 使用freeze_graph.py保存pb文件

除了使用TensorFlow自带的保存方式,我们还可以使用外部工具来保存TensorFlow模型。其中,比较常见的工具是freeze_graph.py。该工具可以将ckpt文件和meta文件一起转换为包含计算图和参数的pb文件。freeze_graph.py的使用方法如下:

python freeze_graph.py \

--input_checkpoint=model.ckpt \

--output_graph=model.pb \

--output_node_names=y_pred \

--input_binary=true

其中,input_checkpoint指定ckpt模型文件的路径,output_graph指定输出的pb文件的路径,output_node_names指定用于推理的模型输出节点名称,input_binary指定是否将ckpt文件转化为二进制形式。

3. 模型恢复方式

3.1 恢复checkpoint文件

对于使用tf.train.Saver保存的checkpoint文件,在恢复模型时需要首先恢复模型参数并重新创建计算图。具体步骤如下:

import tensorflow as tf

# 创建Saver对象

saver = tf.train.Saver()

with tf.Session() as sess:

# 恢复模型参数

saver.restore(sess, 'model.ckpt')

# 创建计算图

x = tf.placeholder(tf.float32, [None, 784], name='x')

y = tf.placeholder(tf.float32, [None, 10], name='y')

W = tf.Variable(tf.zeros([784, 10]), name='W')

b = tf.Variable(tf.zeros([10]), name='b')

y_pred = tf.nn.softmax(tf.matmul(x, W) + b, name='y_pred')

# 调用模型进行推理

y_pred_value = sess.run(y_pred, feed_dict={x: batch_xs})

这种方式需要提前定义计算图结构,适用于使用TensorFlow自带的保存方式保存的模型。

3.2 恢复saved model格式的模型

使用tf.saved_model.loader.load函数可以轻松地恢复saved model格式的模型。具体步骤如下:

import tensorflow as tf

# 加载模型

with tf.Session() as sess:

meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], 'saved_model')

signature = meta_graph_def.signature_def['predict']

x_tensor_name = signature.inputs['x'].name

y_pred_tensor_name = signature.outputs['y_pred'].name

x = sess.graph.get_tensor_by_name(x_tensor_name)

y_pred = sess.graph.get_tensor_by_name(y_pred_tensor_name)

# 调用模型进行推理

y_pred_value = sess.run(y_pred, feed_dict={x: batch_xs})

这种方式可以直接获取模型的输入输出张量,不需要提前定义计算图结构,更加便捷。

3.3 恢复freeze_graph.py生成的pb文件

加载pb文件的方法与加载saved model格式的方法类似,不过需要使用tf.gfile.GFile读取文件内容。具体步骤如下:

import tensorflow as tf

# 读取pb文件

with tf.gfile.GFile('model.pb', 'rb') as f:

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

# 创建计算图,并导入pb文件

with tf.Graph().as_default() as graph:

tf.import_graph_def(graph_def, name='')

x = graph.get_tensor_by_name('x:0')

y_pred = graph.get_tensor_by_name('y_pred:0')

with tf.Session(graph=graph) as sess:

# 调用模型进行推理

y_pred_value = sess.run(y_pred, feed_dict={x: batch_xs})

这种方式也是适用于TensorFlow自带以外的保存方式的,比较灵活,但需要提前知道输入输出张量的名称。

4. 总结

TensorFlow提供了多种保存和恢复模型的方式,可以根据具体的需求选择不同的方式。在实际的项目中,为了保证模型的正确性和可复用性,建议使用saved model格式的方式来保存模型。在使用TensorFlow模型时,需要根据具体的模型类型和保存方式选择恰当的加载方法,并注意输入输出张量的名称和类型,以免出现错误。

免责声明:本文来自互联网,本站所有信息(包括但不限于文字、视频、音频、数据及图表),不保证该信息的准确性、真实性、完整性、有效性、及时性、原创性等,版权归属于原作者,如无意侵犯媒体或个人知识产权,请来电或致函告之,本站将在第一时间处理。猿码集站发布此文目的在于促进信息交流,此文观点与本站立场无关,不承担任何责任。

后端开发标签