Tensorflow中批量读取数据的案列分析及TFRecord文件
1. Tensorflow中批量读取数据的意义
在深度学习的过程中,数据的读取和预处理通常是非常耗时占用计算资源的,因此如何高效地读取大量的数据是非常有意义的。Tensorflow作为目前应用较为广泛的深度学习框架之一,也提供了多种方式对数据进行高效读取,其中,批量读取数据是比较常用的一种方式。
2. Tensorflow中的批量读取数据
Tensorflow中有多种方式批量读取数据,比如使用Dataset API等,不过这里我们介绍一种比较常用的方式,即使用Tensorflow的队列进行批量读取数据。
使用队列读取数据的流程如下:
使用tf.train.slice_input_producer函数读入数据并生成一个队列
使用tf.train.batch或tf.train.shuffle_batch函数,从队列中批量读取数据
具体来说,首先需要先读入数据,以MNIST数据集为例:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
接着,使用tf.train.slice_input_producer函数读入数据并生成一个队列:
images_train, labels_train = mnist.train.images, mnist.train.labels
images_test, labels_test = mnist.test.images, mnist.test.labels
train_queue = tf.train.slice_input_producer([images_train, labels_train], shuffle=True)
这里我们使用tf.train.slice_input_producer函数将images_train和labels_train放入一个队列中,并且设置shuffle参数为True,表示队列中的数据顺序是随机的。
有了队列,我们就可以通过队列来批量读取数据了。使用tf.train.batch或tf.train.shuffle_batch函数进行批量读取数据:
batch_size = 100
images_batch, labels_batch = tf.train.shuffle_batch(train_queue, batch_size=batch_size, capacity=batch_size*64, num_threads=32, min_after_dequeue=batch_size*32)
这里我们使用tf.train.shuffle_batch函数从train_queue队列中读取一个batch_size大小的数据,其中,capacity表示队列的缓存容量,num_threads表示使用的线程数,min_after_dequeue表示队列除了缓存的部分外,最少保留的数据数目。这些参数需要根据实际情况进行设置。
3. 使用TFRecord文件进行批量读取数据
除了上述方式,我们还可以使用TFRecord文件进行批量读取数据,TFRecord是一种二进制数据格式,可以极大地提高数据读取的速度。下面以MNIST数据集为例,介绍如何将MNIST数据集转换为TFRecord格式,并且如何使用TFRecord文件进行批量读取数据。
3.1 将MNIST数据集转换为TFRecord格式
将MNIST数据集转换为TFRecord格式,可以使用以下代码:
def convert_to_tfrecord(images, labels, filename):
with tf.python_io.TFRecordWriter(filename) as writer:
num_examples = images.shape[0]
for i in range(num_examples):
image_raw = images[i].tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
}))
writer.write(example.SerializeToString())
filename_train = 'mnist_train.tfrecords'
filename_test = 'mnist_test.tfrecords'
convert_to_tfrecord(images_train, labels_train.argmax(axis=1), filename_train)
convert_to_tfrecord(images_test, labels_test.argmax(axis=1), filename_test)
该代码可以将images和labels保存为TFRecord格式的文件,其中labels采用one-hot编码。通过上述代码,会保存两个文件:'mnist_train.tfrecords'和'mnist_test.tfrecords'。
3.2 使用TFRecord文件进行批量读取数据
有了TFRecord文件,我们就可以使用tf.data.TFRecordDataset函数进行批量读取数据了,具体来说,代码如下:
batch_size = 100
dataset_train = tf.data.TFRecordDataset(['mnist_train.tfrecords'])
dataset_train = dataset_train.map(lambda x: tf.parse_single_example(x, features={
'label': tf.FixedLenFeature([], tf.int64),
'image': tf.FixedLenFeature([], tf.string),
})).map(lambda x: (tf.decode_raw(x['image'], tf.float32), tf.cast(x['label'], tf.int32)))
dataset_train = dataset_train.shuffle(buffer_size=10000).batch(batch_size).repeat()
iterator_train = dataset_train.make_initializable_iterator()
x_train, y_train = iterator_train.get_next()
在上述代码中,通过使用tf.data.TFRecordDataset函数,将'mnist_train.tfrecords'文件转换为数据集,同时,使用map函数将tf.string类型的'image'解码为tf.float32类型的张量,将tf.int64类型的'label'转换为tf.int32类型的张量。之后,对数据集进行shuffle、batch和repeat操作,最后,通过make_initializable_iterator函数创建数据集迭代器,并且调用get_next函数获取下一个batch的数据。
总结
以上介绍了Tensorflow中批量读取数据的方案,其中,使用队列进行批量读取数据和使用TFRecord文件进行批量读取数据,都是常用的方式。具体使用哪种方式,需要根据实际情况来进行选择。