PyTorch实现MNIST数据集手写数字识别详情

1. 简介

MNIST是一个手写数字数据集,被广泛应用于机器学习领域。本文将使用PyTorch框架,训练一个神经网络模型,用于识别MNIST数据集中的手写数字。

2. 数据集介绍

MNIST数据集包含了60000张训练集图片和10000张测试集图片。每张图片都是28*28的像素,表示手写数字0-9中的一种。数据集已经被标记,因此可以使用它训练和测试神经网络模型。

在使用PyTorch加载MNIST数据集时,需要将图片数据转换为torch.Tensor类型,并且进行归一化操作,将像素值转换为0-1范围内的小数值。

import torch

from torchvision import datasets, transforms

# 定义数据预处理操作

data_transforms = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize((0.1307,), (0.3081,))

])

# 加载训练集和测试集数据

trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=data_transforms)

testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=data_transforms)

# 每次从数据集中加载batch_size张图片进行训练或测试

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

3. 神经网络模型

3.1 模型构建

本文将使用一个包含两个卷积层和两个全连接层的神经网络模型。卷积层将提取图像中的特征,全连接层将将特征进行分类。下面是模型的具体结构:

import torch.nn as nn

import torch.nn.functional as F

class Net(nn.Module):

def __init__(self):

super(Net, self).__init__()

self.conv1 = nn.Conv2d(1, 32, 3, 1)

self.conv2 = nn.Conv2d(32, 64, 3, 1)

self.dropout1 = nn.Dropout2d(0.25)

self.dropout2 = nn.Dropout2d(0.5)

self.fc1 = nn.Linear(9216, 128)

self.fc2 = nn.Linear(128, 10)

def forward(self, x):

x = self.conv1(x)

x = F.relu(x)

x = self.conv2(x)

x = F.relu(x)

x = F.max_pool2d(x, 2)

x = self.dropout1(x)

x = torch.flatten(x, 1)

x = self.fc1(x)

x = F.relu(x)

x = self.dropout2(x)

x = self.fc2(x)

return x

model = Net()

3.2 模型参数

在训练神经网络模型时,需要定义优化器和损失函数,并且对模型参数进行初始化。

优化器使用随机梯度下降(SGD)算法,学习率为0.01。

import torch.optim as optim

# 定义优化器和损失函数

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.01)

# 初始化模型参数

def weight_init(m):

if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):

nn.init.kaiming_uniform_(m.weight)

model.apply(weight_init)

4. 模型训练

4.1 训练函数

下面是训练神经网络模型的函数。

首先,将模型设置为训练模式,为了开启Dropout层的随机失活操作;然后,遍历训练集中的每一批数据,将数据送入模型进行前向传播并计算损失函数的值;接着,将梯度清零,进行反向传播,并调用优化器更新参数;最后,计算在测试集上的准确率。

def train(model, device, trainloader, optimizer, epoch):

model.train()

for batch_idx, (data, target) in enumerate(trainloader):

data, target = data.to(device), target.to(device)

optimizer.zero_grad()

output = model(data)

loss = criterion(output, target)

loss.backward()

optimizer.step()

def test(model, device, testloader):

model.eval()

test_loss = 0

correct = 0

with torch.no_grad():

for data, target in testloader:

data, target = data.to(device), target.to(device)

output = model(data)

test_loss += criterion(output, target).item()

pred = output.argmax(dim=1, keepdim=True)

correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(testloader.dataset)

accuracy = 100. * correct / len(testloader.dataset)

return test_loss, accuracy

4.2 训练过程

使用上面定义的训练函数,遍历训练集10次,同时计算在测试集上的准确率。在每次遍历训练集前,将训练集乱序,以增加模型的稳定性。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

for epoch in range(10):

train_loss = train(model, device, trainloader, optimizer, epoch)

test_loss, accuracy = test(model, device, testloader)

print(f'Epoch {epoch} Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

5. 模型预测

训练完成后,可以使用训练好的模型对新的手写数字进行预测。下面是一个样例代码,输入一张手写数字图片,输出预测结果的概率分布。

import torch.nn.functional as F

from PIL import Image

def predict(model, device, image_path):

model.eval()

image = Image.open(image_path).convert('L')

data_transform = transforms.Compose([

transforms.Resize((28, 28)),

transforms.ToTensor(),

transforms.Normalize((0.1307,), (0.3081,))

])

image_tensor = data_transform(image)

image_tensor = image_tensor.unsqueeze(0)

image_tensor = image_tensor.to(device)

output = model(image_tensor)

probabilities = F.softmax(output, dim=1)

return probabilities.squeeze().tolist()

image_path = './test_image.png'

probabilities = predict(model, device, image_path)

print(probabilities)

完整的代码已经上传至Github:https://github.com/Jerryrong0927/PyTorch-MNIST

免责声明:本文来自互联网,本站所有信息(包括但不限于文字、视频、音频、数据及图表),不保证该信息的准确性、真实性、完整性、有效性、及时性、原创性等,版权归属于原作者,如无意侵犯媒体或个人知识产权,请来电或致函告之,本站将在第一时间处理。猿码集站发布此文目的在于促进信息交流,此文观点与本站立场无关,不承担任何责任。

后端开发标签