Tensorflow之MNIST CNN实现并保存、加载模型

1. 概述

深度学习框架TensorFlow可用于构建各种神经网络,包括卷积神经网络(CNN)。在本文中,我们将使用TensorFlow实现一个简单的卷积神经网络来解决MNIST手写数字识别问题,并演示如何保存和加载模型。

2. 数据集

MNIST是一个常用的手写数字识别数据集,包含60,000个训练样本和10,000个测试样本。我们使用TensorFlow的内置函数加载MNIST数据集,并将其划分为训练集和测试集:

import tensorflow as tf

# 加载MNIST数据集

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 数据预处理

x_train, x_test = x_train / 255.0, x_test / 255.0

说明:我们将像素值范围从0-255归一化为0-1以提高模型的训练效果。

3. 创建模型

现在,我们可以构建卷积神经网络。 在这个例子中,我们将使用一个包含两个卷积层和两个池化层的卷积神经网络,一个全连接层和一个输出层:

# 创建模型

model = tf.keras.models.Sequential([

tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)),

tf.keras.layers.MaxPooling2D((2,2)),

tf.keras.layers.Conv2D(64, (3,3), activation='relu'),

tf.keras.layers.MaxPooling2D((2,2)),

tf.keras.layers.Flatten(),

tf.keras.layers.Dense(128, activation='relu'),

tf.keras.layers.Dense(10, activation='softmax')

])

说明:在输入层中,我们显式地指定了输入张量的形状,输出层使用softmax激活函数以生成10个类别的概率分布。

4. 编译模型

编译模型涉及选择优化器和损失函数。 我们将使用Adam优化器和分类交叉熵损失函数:

# 编译模型

model.compile(optimizer='adam',

loss='sparse_categorical_crossentropy',

metrics=['accuracy'])

说明:在这里,我们使用sparse_categorical_crossentropy损失函数,因为我们的标签(类别)是整数而不是one-hot向量。

5. 训练模型

我们已经定义了模型并编译了它,现在可以训练它以适应MNIST数据集:

# 训练模型

model.fit(x_train[..., tf.newaxis], y_train, epochs=5, validation_data=(x_test[..., tf.newaxis], y_test))

说明:我们使用5个epoch进行训练,并将验证集的准确性作为评估指标。

6. 保存模型

在我们的模型训练完成后,我们可以将它保存到磁盘以供以后使用。 可以使用Model.save()方法保存模型,该方法将整个模型保存在一个HDF5文件中:

# 保存模型到文件

model.save('mnist_cnn.h5')

说明:我们已在当前文件夹中创建了名为mnist_cnn.h5的文件。

7. 加载模型

在某个时候,我们可能需要使用保存的模型,可以使用tf.keras.models.load_model()函数加载模型:

# 加载模型

model = tf.keras.models.load_model('mnist_cnn.h5')

8. 评估模型

接下来,我们可以使用测试集评估模型的准确性:

# 评估模型

model.evaluate(x_test[..., tf.newaxis], y_test)

说明:我们使用与训练时相同的预处理步骤对测试集进行预处理。

9. 推理样本

我们可以使用model.predict()方法对新样本进行预测,请注意,我们需要将样本reshape为(28, 28, 1)的三维张量:

import numpy as np

# 加载一张手写数字图片

sample = np.array(Image.open('sample.png').convert('L').resize((28, 28)))

sample = np.invert(sample)

sample = sample / 255.0

sample = sample[..., tf.newaxis]

# 使用模型进行推理

predictions = model.predict(np.array([sample]))

print(np.argmax(predictions))

说明:我们使用了Python Imaging Library(PIL)的Image模块加载手写数字图像,并将其灰度化、反转和归一化处理。

10. 结论

在本文中,我们使用TensorFlow实现了一个简单的卷积神经网络来解决MNIST手写数字识别问题,并演示了如何保存和加载模型。 该模型在测试集上的准确率约为99%,可以用于推理样本。

免责声明:本文来自互联网,本站所有信息(包括但不限于文字、视频、音频、数据及图表),不保证该信息的准确性、真实性、完整性、有效性、及时性、原创性等,版权归属于原作者,如无意侵犯媒体或个人知识产权,请来电或致函告之,本站将在第一时间处理。猿码集站发布此文目的在于促进信息交流,此文观点与本站立场无关,不承担任何责任。

后端开发标签