Python MNIST手写体识别详解与试练

1. 前言

手写数字识别是深度学习中最基本的例子之一,它可以帮助我们介绍很多深度学习中常用的概念,例如卷积神经网络和反向传播算法等。在本文中,我们将使用Python实现手写数字识别模型,并使用MNIST数据集进行训练和测试。在本文中,我们将解决以下问题:

什么是MNIST数据集?

如何使用Python加载MNIST数据集?

如何使用Python实现手写数字识别模型?

如何使用Python计算模型在测试集上的准确率?

2. MNIST数据集介绍

2.1 什么是MNIST数据集?

MNIST数据集是一个手写数字图像数据集,它是深度学习中最常用的数据集之一。该数据集包含60000张训练图像和10000张测试图像,每张图像都是28x28的灰度图像。每个图像都会被标记为一个0到9之间的数字。

2.2 数据集下载

在开始使用MNIST数据集之前,我们需要下载它。首先,我们需要安装Python的wget库:

!pip install wget

然后,我们可以使用以下代码下载MNIST数据集:

import wget

url_train = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'

url_train_labels = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'

url_test = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'

url_test_labels = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'

wget.download(url_train, 'train-images.gz')

wget.download(url_train_labels, 'train-labels.gz')

wget.download(url_test, 'test-images.gz')

wget.download(url_test_labels, 'test-labels.gz')

3. 实现手写数字识别模型

3.1 加载MNIST数据集

我们需要使用Python的gzip库来解压缩MNIST数据集文件,然后将其加载到NumPy数组中。以下代码用于将MNIST数据集加载到NumPy数组中:

import gzip

import numpy as np

def load_data():

with gzip.open('train-images.gz', 'rb') as f:

train_images_raw = f.read()

with gzip.open('train-labels.gz', 'rb') as f:

train_labels_raw = f.read()

with gzip.open('test-images.gz', 'rb') as f:

test_images_raw = f.read()

with gzip.open('test-labels.gz', 'rb') as f:

test_labels_raw = f.read()

# 解码

train_images = np.frombuffer(train_images_raw[16:], dtype=np.uint8).reshape(-1, 28 * 28)

train_labels = np.frombuffer(train_labels_raw[8:], dtype=np.uint8)

test_images = np.frombuffer(test_images_raw[16:], dtype=np.uint8).reshape(-1, 28 * 28)

test_labels = np.frombuffer(test_labels_raw[8:], dtype=np.uint8)

return (train_images, train_labels), (test_images, test_labels)

(train_images, train_labels), (test_images, test_labels) = load_data()

现在我们已经成功地将MNIST数据集加载到了NumPy数组中。每个训练图像都被表示为一个长度为784(28 * 28)的一维数组,每个测试图像都被表示为一个长度为784的一维数组。每个标签都是一个整数,标识该图像所代表的数字。

3.2 构建模型

在本文中,我们将使用Keras框架构建卷积神经网络模型(CNN)来对手写数字进行分类。CNN是一种在图像,视频和音频等数据上表现出色的深度学习算法。

以下是构建CNN模型的代码:

from keras.models import Sequential

from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout

model = Sequential()

model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))

model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))

model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Dropout(0.25))

model.add(Flatten())

model.add(Dense(128, activation='relu'))

model.add(Dropout(0.5))

model.add(Dense(10, activation='softmax'))

我们的CNN模型由两个嵌套层组成。第一个卷积层具有32个3x3卷积核,激活函数为ReLU。第二个层是一个最大池化层,它将特征图的大小减小到原来的一半。第三和第四层分别是64个3x3卷积核的卷积层和最大池化层。第五层是一个Dropout层,它有助于减少过度拟合。第六层是一个全连接层(128个隐藏单元),激活函数为ReLU。第七层是另一个Dropout层。最后一层是另一个全连接层(10个输出单元),激活函数为softmax。

3.3 训练模型

我们使用以下代码来编译和训练我们的模型:

from keras.optimizers import Adam

model.compile(loss='categorical_crossentropy',

optimizer=Adam(lr=0.001, beta_1=0.9, beta_2=0.999),

metrics=['accuracy'])

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)

test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)

train_labels_categorical = keras.utils.to_categorical(train_labels, 10)

test_labels_categorical = keras.utils.to_categorical(test_labels, 10)

model.fit(train_images, train_labels_categorical,

batch_size=128,

epochs=8,

verbose=1,

validation_data=(test_images, test_labels_categorical))

我们使用Adam优化器和交叉熵损失函数来编译我们的模型。我们还将学习率设置为0.001,beta_1设置为0.9,beta_2设置为0.999。我们将每个训练批次大小设置为128个图像,并针对训练集对模型进行8个时期的训练。

3.4 测试和评估模型

我们可以使用以下代码来计算我们的模型在测试集上的准确率:

from sklearn.metrics import accuracy_score

test_predictions = model.predict(test_images)

test_predictions_labels = np.argmax(test_predictions, axis=1)

accuracy = accuracy_score(test_labels, test_predictions_labels)

print('Accuracy:', accuracy)

使用sklearn.metrics库中的accuracy_score函数,我们可以计算出我们的模型在测试集上的准确率。我们使用argmax函数来计算每个测试图像的预测标签,并将其与实际标签进行比较。

4. 总结

在本文中,我们成功地使用Python和Keras框架实现了一个手写数字识别模型,并使用MNIST数据集进行训练和测试。我们进一步介绍了MNIST数据集和卷积神经网络。如果您对深度学习和卷积神经网络感兴趣,我们推荐您深入阅读Keras文档和例子,以深入掌握深度学习的各个方面。

后端开发标签