1. TFRecord介绍
在介绍tf.TFRecordReader()函数的用法前,先简单介绍一下TFRecord。TFRecord是一种用于存储大规模数据集的格式,特别适用于带有许多标注的数据集。TFRecord是一种二进制文件,由tf.train.Example协议缓冲区组成。每个tf.train.Example包含一个或多个特征,而每个特征可以是一个张量、一个张量列表或一个张量列表组成的字典。
为了方便阅读,下面给出一个TFRecord文件的读取代码:
import tensorflow as tf
# 创建一个TFRecordReader实例
reader = tf.TFRecordReader()
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(['file0.tfrecord', 'file1.tfrecord', 'file2.tfrecord'])
# read方法从文件中读取一个样例,key是文件名,value是Example
_, serialized_example = reader.read(filename_queue)
# 解析Example
features = tf.parse_single_example(
serialized_example,
features={
"feature0": tf.FixedLenFeature([], dtype=tf.int64),
"feature1": tf.FixedLenFeature([], dtype=tf.int64),
"feature2": tf.FixedLenFeature([], dtype=tf.int64),
"feature3": tf.FixedLenFeature([], dtype=tf.string),
})
# 将解析出来的样例组合成batch
example = tf.train.batch(
[features["feature0"], features["feature1"], features["feature2"], features["feature3"]],
batch_size=32)
# 启动Session
sess = tf.Session()
# 启动读取文件的线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 运行计算图
try:
while not coord.should_stop():
print(sess.run(example))
except tf.errors.OutOfRangeError:
print('Done!')
finally:
# 关闭线程
coord.request_stop()
coord.join(threads)
2. TFRecordReader函数
下面解释一下tf.TFRecordReader()函数的用法。TFRecord文件是一种二进制文件,需要tf.TFRecordReader()函数将其读取为可以被TensorFlow计算图处理的形式。这个函数的参数只有一个,即为文件名字符串的Tensor。它的返回值有两个,第一个返回值代表可以被TensorFlow计算图处理的键,一般是文件名;第二个返回值代表的是TFRecord文件的内容。
2.1 读取单个样例
以下代码读取TFRecord文件中的单个样例:
import tensorflow as tf
# 创建一个TFRecordReader实例
reader = tf.TFRecordReader()
# 读取一个TFRecord文件
filename_queue = tf.train.string_input_producer(['file.tfrecord'])
# 从文件中读取一个样例
_, serialized_example = reader.read(filename_queue)
# 解析Example
features = tf.parse_single_example(
serialized=serialized_example,
features={
'feature0': tf.FixedLenFeature([], tf.int64),
'feature1': tf.FixedLenFeature([], tf.int64),
'feature2': tf.FixedLenFeature([], tf.string),
})
# 启动Session
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# 开始输入队列的线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 显示解析出的结果
print(sess.run(features['feature0']))
print(sess.run(features['feature1']))
print(sess.run(features['feature2']))
# 关闭线程
coord.request_stop()
coord.join(threads)
其中,serialized_example是解析出的TFRecord文件的内容,它是一个字符串(bytes类型);features定义了解析出来的每个特征的名称和类型。
2.2 读取多个样例
以下代码读取TFRecord文件中的多个样例:
import tensorflow as tf
# 创建一个TFRecordReader实例
reader = tf.TFRecordReader()
# 读取多个TFRecord文件
filename_queue = tf.train.string_input_producer(['file0.tfrecord', 'file1.tfrecord'])
# 从文件中读取多个样例
_, serialized_example = reader.read_up_to(filename_queue, num_records=32)
# 解析Example
features = tf.parse_example(
serialized=serialized_example,
features={
'feature0': tf.FixedLenFeature([], tf.int64),
'feature1': tf.FixedLenFeature([], tf.int64),
'feature2': tf.FixedLenFeature([], tf.string),
})
# 启动Session
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# 开始输入队列的线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 显示解析出的结果
print(sess.run(features['feature0']))
print(sess.run(features['feature1']))
print(sess.run(features['feature2']))
# 关闭线程
coord.request_stop()
coord.join(threads)
其中, serialized_example是解析出的TFRecord文件的内容,它是一个字符串(bytes类型)列表;features定义了解析出来的每个特征的名称和类型。
3. 总结
tf.TFRecordReader()函数的主要作用是读取TFRecord文件,并将其解析为可以被TensorFlow计算图处理的形式。在使用该函数时,需要注意文件名的形式及其传入方式,同时也要注意特征的解析和类型定义。在读取TFRecord文件时,需要先建立输入队列,因此在使用该函数前,还需要了解TensorFlow中队列的使用方法。