tensorflow从ckpt和从.pb文件读取变量的值方式

1. 介绍

在使用TensorFlow进行深度学习模型训练和部署时,我们常常需要保存和加载模型的变量。TensorFlow提供了两种方式来保存模型的变量值,分别是从ckpt文件和从pb文件读取。ckpt文件保存了模型的所有变量值,而pb文件保存了模型的计算图和变量值。在本文中,我们将详细介绍如何从ckpt和pb文件读取变量的值,并对比两者的区别和优缺点。

2. 从ckpt文件读取变量的值

2.1 保存ckpt文件

在使用TensorFlow训练模型时,可以使用tf.train.Saver来保存模型的变量。以下是一个保存模型变量的示例代码:

import tensorflow as tf

# 定义模型

x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

# 定义损失函数和优化器

y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# 创建Saver对象

saver = tf.train.Saver()

# 训练模型

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

for i in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

saver.save(sess, "model.ckpt")

运行上述代码后,会生成一个ckpt文件,文件名为model.ckpt。

2.2 读取ckpt文件

要从ckpt文件中读取变量的值,可以使用tf.train.Saver的restore方法。以下是一个读取ckpt文件并获取变量值的示例代码:

import tensorflow as tf

# 定义模型

x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

# 创建Saver对象

saver = tf.train.Saver()

# 读取ckpt文件

with tf.Session() as sess:

saver.restore(sess, "model.ckpt")

W_value, b_value = sess.run([W, b])

# 打印变量值

print("W:", W_value)

print("b:", b_value)

运行上述代码后,会输出模型变量W和b的值。

3. 从pb文件读取变量的值

3.1 保存pb文件

TensorFlow提供了tf.summary.FileWriter来保存模型的计算图和变量值为pb文件。以下是一个保存pb文件的示例代码:

import tensorflow as tf

# 定义模型

x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

# 创建Saver对象

saver = tf.train.Saver()

# 训练模型

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

for i in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

saver.save(sess, "model.ckpt")

tf.train.write_graph(sess.graph.as_graph_def(), ".", "model.pb", as_text=False)

运行上述代码后,会生成一个pb文件,文件名为model.pb。

3.2 加载pb文件

要从pb文件中读取变量的值,首先需要加载计算图。以下是一个加载pb文件并获取变量值的示例代码:

import tensorflow as tf

# 加载计算图

def load_graph(pb_file):

with tf.gfile.GFile(pb_file, "rb") as f:

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

tf.import_graph_def(graph_def)

# 获取变量值

def get_variable_value(pb_file, variable_name):

with tf.Session() as sess:

load_graph(pb_file)

var = tf.get_default_graph().get_tensor_by_name(variable_name)

return sess.run(var)

# 读取pb文件并获取变量值

W_value = get_variable_value("model.pb", "Variable:0")

b_value = get_variable_value("model.pb", "Variable_1:0")

# 打印变量值

print("W:", W_value)

print("b:", b_value)

运行上述代码后,会输出模型变量W和b的值。

4. 对比和总结

从ckpt文件和pb文件读取变量的值,两者的方法和步骤略有不同,但都可以实现获取变量值的目的。以下是两者的对比和总结:

从ckpt文件读取变量的值:

保存和读取变量的过程比较简单,只需要使用tf.train.Saver类的save和restore方法即可。

ckpt文件保存了模型的所有变量值,文件较大,但加载速度比较快。

从pb文件读取变量的值:

保存和读取变量的过程相对复杂,需要先保存计算图,然后加载计算图,最后获取变量值。

pb文件保存了模型的计算图和变量值,文件较小,但加载速度比较慢。

综上所述,从ckpt文件读取变量的值相对简单且加载速度较快,适合单独使用;而从pb文件读取变量的值相对复杂且加载速度较慢,适合用于模型部署和推理。

在实际应用中,根据具体情况选择使用ckpt文件还是pb文件来保存和加载模型的变量值。

后端开发标签