pytorch实现MNIST手写体识别

1. 简介

手写数字识别一直是计算机视觉领域的一个基础问题。在这篇文章中,我们将介绍如何使用pytorch实现MNIST手写数字识别。MNIST是一个手写数字的数据集,包括60000个训练样本和10000个测试样本。每个样本都是28x28的灰度图像,标记为0到9中的一个数字。

2. 数据准备

2.1 数据下载

我们使用pytorch内置函数torchvision.datasets.MNIST下载MNIST。以下代码片段演示如何下载和准备数据:

import torch

import torchvision

import torchvision.transforms as transforms

transform=transforms.Compose([transforms.ToTensor(),

transforms.Normalize((0.1307,), (0.3081,))])

trainset=torchvision.datasets.MNIST(root='./data',

train=True,

download=True,

transform=transform)

trainloader=torch.utils.data.DataLoader(trainset,

batch_size=4,

shuffle=True,

num_workers=2)

testset=torchvision.datasets.MNIST(root='./data',

train=False,

download=True,

transform=transform)

testloader=torch.utils.data.DataLoader(testset,

batch_size=4,

shuffle=False,

num_workers=2)

我们将数据分成两个部分:训练数据和测试数据。对于训练数据,我们使用torchvision.datasets.MNIST函数从网站中下载MNIST数据集。对于测试数据,我们设置train=False,代表我们要下载测试数据。

我们使用torchvision.transforms.Compose函数来定义需要进行的数据转换。在上述代码中,我们只使用了两个转换:将图像转换为张量和使用正态分布的均值和标准差来标准化图像。最后,我们使用torch.utils.data.DataLoader函数来生成可以迭代的数据集。

2.2 可视化数据

我们可以使用一些Python库来可视化手写数字。以下是使用Matplotlib可视化MNIST数据集中的前几张图像的代码:

import matplotlib.pyplot as plt

import numpy as np

# 定义一个可迭代的数据加载器

dataiter=iter(trainloader)

images, labels=dataiter.next()

# 显示四个图像和它们的标签

fig=plt.figure(figsize=(8, 8))

for i in range(4):

ax=fig.add_subplot(2, 2, i+1)

ax.imshow(np.squeeze(images[i].numpy()), cmap='gray')

ax.set_title(str(labels[i].item()))

我们使用iter(trainloader)将数据加载到一个可迭代的数据加载器中。接下来,我们使用dataiter.next()从可迭代的数据加载器中取出4个样本。最后,我们使用matplotlib.pyplot库将这些图像可视化。

3. 构建模型

对于MNIST手写数字识别问题,我们可以使用一个全连接神经网络来解决。以下是在pytorch中定义这个模型的代码:

import torch.nn as nn

import torch.nn.functional as F

class Net(nn.Module):

def __init__(self):

super(Net, self).__init__()

self.fc1=nn.Linear(28*28, 512)

self.fc2=nn.Linear(512, 256)

self.fc3=nn.Linear(256, 128)

self.fc4=nn.Linear(128, 10)

self.dropout=nn.Dropout(0.2)

def forward(self, x):

x=x.view(-1, 28*28)

x=F.relu(self.fc1(x))

x=self.dropout(x)

x=F.relu(self.fc2(x))

x=self.dropout(x)

x=F.relu(self.fc3(x))

x=self.dropout(x)

x=self.fc4(x)

return x

net=Net()

在这个模型中,我们使用4个全连接层,并添加了一个dropout层以减少过拟合。在前向传递期间,我们将输入张量展平为(batch_size,28 * 28)的形状,并调用ReLU函数来计算每个层的激活函数。最后,我们将输出传递给最后一层,其大小为10,代表我们希望对每个数字预测一个概率。

4. 训练模型

我们现在已准备就绪,并且已构建了神经网络。接下来,我们需要使用训练数据对模型进行训练。在我们开始之前,我们需要定义模型的损失函数和优化函数。以下是在pytorch中定义这些内容的代码:

import torch.optim as optim

criterion=nn.CrossEntropyLoss()

optimizer=optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

我们使用交叉熵损失和随机梯度下降(SGD)优化算法。在每个迭代中,我们将使用梯度下降来更新神经网络中的所有权重,并尽可能少地预测错误的标签。以下是一个代表训练循环的代码片段:

for epoch in range(2):

running_loss=0.0

for i, data in enumerate(trainloader, 0):

inputs, labels=data

optimizer.zero_grad()

outputs=net(inputs)

loss=criterion(outputs, labels)

loss.backward()

optimizer.step()

running_loss+=loss.item()

if i % 2000==1999:

print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss/2000))

running_loss=0.0

print('Finished Training')

在上述代码中,我们使用一个简单的循环来遍历训练数据的所有迭代。对于每个小批量样本,我们使用.zero_grad()函数将梯度设置为0,使用.backward()函数计算损失函数的梯度,使用.step()函数更新所有权重,并使用损失函数的值更新running_loss变量。

5. 测试模型

现在,我们已经在模型上进行了训练并得到了一些预测结果。我们可以使用测试数据集来评估模型在新数据上的性能。以下是在测试数据上评估模型性能的代码:

correct=0

total=0

with torch.no_grad():

for data in testloader:

images, labels=data

outputs=net(images)

_, predicted=torch.max(outputs.data, 1)

total+=labels.size(0)

correct+=(predicted==labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

使用测试数据集,我们可以计算模型的准确度。我们首先将预测的标签与实际的标签进行比较,计算正确分类的测试样本数,并将其除以测试的总样本数。

6. 结论

在这篇文章中,我们对MNIST手写数字识别进行了实现,使用pytorch实现一个全连接神经网络来识别手写数字。我们还介绍了如何下载、准备、可视化数据,如何训练和测试我们的模型。我们得到了85%以上的准确度,可作为一个好的开始。尝试更改超参数和模型体系结构,以获得更好的结果。

后端开发标签