Tensorflow中批量读取数据的案列分析及TFRecord文件

Tensorflow中批量读取数据的案列分析及TFRecord文件

1. Tensorflow中批量读取数据的意义

在深度学习的过程中,数据的读取和预处理通常是非常耗时占用计算资源的,因此如何高效地读取大量的数据是非常有意义的。Tensorflow作为目前应用较为广泛的深度学习框架之一,也提供了多种方式对数据进行高效读取,其中,批量读取数据是比较常用的一种方式。

2. Tensorflow中的批量读取数据

Tensorflow中有多种方式批量读取数据,比如使用Dataset API等,不过这里我们介绍一种比较常用的方式,即使用Tensorflow的队列进行批量读取数据。

使用队列读取数据的流程如下:

使用tf.train.slice_input_producer函数读入数据并生成一个队列

使用tf.train.batch或tf.train.shuffle_batch函数,从队列中批量读取数据

具体来说,首先需要先读入数据,以MNIST数据集为例:

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

接着,使用tf.train.slice_input_producer函数读入数据并生成一个队列:

images_train, labels_train = mnist.train.images, mnist.train.labels

images_test, labels_test = mnist.test.images, mnist.test.labels

train_queue = tf.train.slice_input_producer([images_train, labels_train], shuffle=True)

这里我们使用tf.train.slice_input_producer函数将images_train和labels_train放入一个队列中,并且设置shuffle参数为True,表示队列中的数据顺序是随机的。

有了队列,我们就可以通过队列来批量读取数据了。使用tf.train.batch或tf.train.shuffle_batch函数进行批量读取数据:

batch_size = 100

images_batch, labels_batch = tf.train.shuffle_batch(train_queue, batch_size=batch_size, capacity=batch_size*64, num_threads=32, min_after_dequeue=batch_size*32)

这里我们使用tf.train.shuffle_batch函数从train_queue队列中读取一个batch_size大小的数据,其中,capacity表示队列的缓存容量,num_threads表示使用的线程数,min_after_dequeue表示队列除了缓存的部分外,最少保留的数据数目。这些参数需要根据实际情况进行设置。

3. 使用TFRecord文件进行批量读取数据

除了上述方式,我们还可以使用TFRecord文件进行批量读取数据,TFRecord是一种二进制数据格式,可以极大地提高数据读取的速度。下面以MNIST数据集为例,介绍如何将MNIST数据集转换为TFRecord格式,并且如何使用TFRecord文件进行批量读取数据。

3.1 将MNIST数据集转换为TFRecord格式

将MNIST数据集转换为TFRecord格式,可以使用以下代码:

def convert_to_tfrecord(images, labels, filename):

with tf.python_io.TFRecordWriter(filename) as writer:

num_examples = images.shape[0]

for i in range(num_examples):

image_raw = images[i].tostring()

example = tf.train.Example(features=tf.train.Features(feature={

'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]])),

'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),

}))

writer.write(example.SerializeToString())

filename_train = 'mnist_train.tfrecords'

filename_test = 'mnist_test.tfrecords'

convert_to_tfrecord(images_train, labels_train.argmax(axis=1), filename_train)

convert_to_tfrecord(images_test, labels_test.argmax(axis=1), filename_test)

该代码可以将images和labels保存为TFRecord格式的文件,其中labels采用one-hot编码。通过上述代码,会保存两个文件:'mnist_train.tfrecords'和'mnist_test.tfrecords'。

3.2 使用TFRecord文件进行批量读取数据

有了TFRecord文件,我们就可以使用tf.data.TFRecordDataset函数进行批量读取数据了,具体来说,代码如下:

batch_size = 100

dataset_train = tf.data.TFRecordDataset(['mnist_train.tfrecords'])

dataset_train = dataset_train.map(lambda x: tf.parse_single_example(x, features={

'label': tf.FixedLenFeature([], tf.int64),

'image': tf.FixedLenFeature([], tf.string),

})).map(lambda x: (tf.decode_raw(x['image'], tf.float32), tf.cast(x['label'], tf.int32)))

dataset_train = dataset_train.shuffle(buffer_size=10000).batch(batch_size).repeat()

iterator_train = dataset_train.make_initializable_iterator()

x_train, y_train = iterator_train.get_next()

在上述代码中,通过使用tf.data.TFRecordDataset函数,将'mnist_train.tfrecords'文件转换为数据集,同时,使用map函数将tf.string类型的'image'解码为tf.float32类型的张量,将tf.int64类型的'label'转换为tf.int32类型的张量。之后,对数据集进行shuffle、batch和repeat操作,最后,通过make_initializable_iterator函数创建数据集迭代器,并且调用get_next函数获取下一个batch的数据。

总结

以上介绍了Tensorflow中批量读取数据的方案,其中,使用队列进行批量读取数据和使用TFRecord文件进行批量读取数据,都是常用的方式。具体使用哪种方式,需要根据实际情况来进行选择。

后端开发标签