pytorch实现mnist手写彩色数字识别

1. 简介

Mnist是一个非常经典的手写数字数据集,它包含了0-9的数字图片,每张图片大小是28x28像素,我们需要对这些图片进行分类。本文将使用Pytorch框架实现手写彩色数字识别,其中输入为彩色图片。

2. 数据集介绍

2.1 Mnist数据集

Mnist数据集包含了60000个训练集样本和10000个测试集样本。每张图片的大小为28x28,其中黑色表示像素值为0,白色表示像素值为255。我们可以将每个像素点的值除以255,得到的值即为0-1之间的值,这样可以使模型的学习更加稳定和快速。

2.2 加入彩色维度的Mnist数据集

为了加入彩色维度,我们需要对Mnist数据集中的黑白图片进行转换。转换的方式是将黑白图片复制三份,形成三个通道。即可得到大小为28x28x3的彩色图片。

3. 数据处理

在进行数据处理之前,我们需要先安装必要的库文件。本文将使用Pytorch框架完成手写数字图片的彩色分类。

# 安装必要的库文件

!pip install torch

!pip install torchvision

!pip install matplotlib

我们需要准备好Mnist数据集,我们可以通过torchvision库中提供的接口直接下载数据集。

# 加载Mnist数据集

import torchvision

import torchvision.transforms as transforms

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())

test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

我们可以通过Matplotlib库将Mnist数据集可视化,然后我们可以看到Mnist数据集中的数字均在28x28像素内,占据了图片大部分的区域。

# 可视化Mnist数据集

import matplotlib.pyplot as plt

def show_image(img):

img = img.numpy()

plt.imshow(img.transpose(1, 2, 0), interpolation='nearest')

plt.show()

image, label = train_dataset[10]

show_image(image)

4. 构建模型

我们将使用以下卷积神经网络进行手写数字彩色图片的分类。该模型包含两个卷积层,一个全连接层和一个输出层,其中最后一层输出为10,表示我们需要对图片进行10类的分类。

# 构建彩色手写数字图片分类模型

import torch.nn as nn

import torch.nn.functional as F

class Net(nn.Module):

def __init__(self):

super().__init__()

self.conv1 = nn.Conv2d(3, 16, 3, padding=1)

self.bn1 = nn.BatchNorm2d(16)

self.conv2 = nn.Conv2d(16, 32, 3, padding=1)

self.bn2 = nn.BatchNorm2d(32)

self.fc1 = nn.Linear(32 * 7 * 7, 100)

self.fc2 = nn.Linear(100, 10)

def forward(self, x):

x = F.relu(self.bn1(self.conv1(x)))

x = F.max_pool2d(x, 2)

x = F.relu(self.bn2(self.conv2(x)))

x = F.max_pool2d(x, 2)

x = x.view(-1, 32 * 7 * 7)

x = F.relu(self.fc1(x))

x = self.fc2(x)

return x

net = Net()

5. 训练模型

下面我们开始训练模型。我们使用的是交叉熵损失函数和Adam优化器。每一轮的训练,我们都需要对训练数据进行打乱处理,以避免模型只学习某个特定序列的数据。

# 训练网络

import torch.optim as optim

NUM_EPOCHS = 5

BATCH_SIZE = 128

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(net.parameters())

def train(epoch):

running_loss = 0.0

for i, data in enumerate(train_loader, 0):

inputs, labels = data

optimizer.zero_grad()

outputs = net(inputs)

loss = criterion(outputs, labels)

loss.backward()

optimizer.step()

running_loss += loss.item()

print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(train_loader)))

def test():

correct = 0

loss = 0.0

total = 0

with torch.no_grad():

for data in test_loader:

images, labels = data

outputs = net(images)

loss += criterion(outputs, labels).item()

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += (predicted == labels).sum().item()

print('Accuracy on the test images: %d %%' % (100 * correct / total))

for epoch in range(NUM_EPOCHS):

train(epoch)

test()

我们可以看到每一轮训练的结果,以及训练完成后在测试集上的准确率。

6. 结果可视化

下面我们随机抽取一张彩色手写数字图片进行测试,将测试结果可视化出来。我们可以看到,模型正确地识别了手写数字“5”。

# 可视化结果

import numpy as np

import torchvision.transforms.functional as F

image, label = test_dataset[777]

img = F.to_pil_image(image)

plt.imshow(np.asarray(img))

plt.show()

net.eval()

with torch.no_grad():

output = net(image.unsqueeze(0))

pred = output.argmax(dim=1)

print("Predicted digit: ", pred.item())

在本次实验中,我们采用了Pytorch框架,通过卷积神经网络建立了一个手写数字的彩色图片分类模型,并在Mnist数据集上进行了训练。我们可以看到,在训练5轮后,模型可以在测试集上达到98%的分类准确率。

后端开发标签