1. 概述
深度学习框架TensorFlow可用于构建各种神经网络,包括卷积神经网络(CNN)。在本文中,我们将使用TensorFlow实现一个简单的卷积神经网络来解决MNIST手写数字识别问题,并演示如何保存和加载模型。
2. 数据集
MNIST是一个常用的手写数字识别数据集,包含60,000个训练样本和10,000个测试样本。我们使用TensorFlow的内置函数加载MNIST数据集,并将其划分为训练集和测试集:
import tensorflow as tf
# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 数据预处理
x_train, x_test = x_train / 255.0, x_test / 255.0
说明:我们将像素值范围从0-255归一化为0-1以提高模型的训练效果。
3. 创建模型
现在,我们可以构建卷积神经网络。 在这个例子中,我们将使用一个包含两个卷积层和两个池化层的卷积神经网络,一个全连接层和一个输出层:
# 创建模型
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)),
tf.keras.layers.MaxPooling2D((2,2)),
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D((2,2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
说明:在输入层中,我们显式地指定了输入张量的形状,输出层使用softmax激活函数以生成10个类别的概率分布。
4. 编译模型
编译模型涉及选择优化器和损失函数。 我们将使用Adam优化器和分类交叉熵损失函数:
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
说明:在这里,我们使用sparse_categorical_crossentropy损失函数,因为我们的标签(类别)是整数而不是one-hot向量。
5. 训练模型
我们已经定义了模型并编译了它,现在可以训练它以适应MNIST数据集:
# 训练模型
model.fit(x_train[..., tf.newaxis], y_train, epochs=5, validation_data=(x_test[..., tf.newaxis], y_test))
说明:我们使用5个epoch进行训练,并将验证集的准确性作为评估指标。
6. 保存模型
在我们的模型训练完成后,我们可以将它保存到磁盘以供以后使用。 可以使用Model.save()方法保存模型,该方法将整个模型保存在一个HDF5文件中:
# 保存模型到文件
model.save('mnist_cnn.h5')
说明:我们已在当前文件夹中创建了名为mnist_cnn.h5的文件。
7. 加载模型
在某个时候,我们可能需要使用保存的模型,可以使用tf.keras.models.load_model()函数加载模型:
# 加载模型
model = tf.keras.models.load_model('mnist_cnn.h5')
8. 评估模型
接下来,我们可以使用测试集评估模型的准确性:
# 评估模型
model.evaluate(x_test[..., tf.newaxis], y_test)
说明:我们使用与训练时相同的预处理步骤对测试集进行预处理。
9. 推理样本
我们可以使用model.predict()方法对新样本进行预测,请注意,我们需要将样本reshape为(28, 28, 1)的三维张量:
import numpy as np
# 加载一张手写数字图片
sample = np.array(Image.open('sample.png').convert('L').resize((28, 28)))
sample = np.invert(sample)
sample = sample / 255.0
sample = sample[..., tf.newaxis]
# 使用模型进行推理
predictions = model.predict(np.array([sample]))
print(np.argmax(predictions))
说明:我们使用了Python Imaging Library(PIL)的Image模块加载手写数字图像,并将其灰度化、反转和归一化处理。
10. 结论
在本文中,我们使用TensorFlow实现了一个简单的卷积神经网络来解决MNIST手写数字识别问题,并演示了如何保存和加载模型。 该模型在测试集上的准确率约为99%,可以用于推理样本。