手把手教你用Pytorch训练图像分类网络

1. Pytorch简介

Pytorch是一个基于Python的科学计算包,它是一个在动态图机制上构建的深度学习框架。它是目前广泛使用的深度学习框架之一,首先由Facebook的人工智能研究团队研发并开源。与TensorFlow相比,Pytorch更加灵活,易于使用,具有更高的编程效率。

2. 图像分类问题

图像分类是计算机视觉任务中的一项重要问题,它的目标是将输入的图片分成不同的类别。对于图像分类问题,我们一般采用卷积神经网络(Convolutional Neural Network,CNN)来进行处理。

2.1 卷积神经网络

卷积神经网络是一种深度学习神经网络,它的主要特点是使用卷积层(Convolutional Layer)和池化层(Pooling Layer)对输入的图像进行卷积和池化操作,从而获取图像的特征信息。卷积层和池化层的操作可以减少网络中的参数量,提高网络对图像的鲁棒性。

Pytorch中已经实现了常见的卷积神经网络模型,我们可以直接使用Pytorch提供的模型进行训练和预测。在本文中,我们将使用Pytorch来训练一个图像分类的卷积神经网络。

3. 数据集

在训练深度学习模型时,选择合适的数据集是非常重要的。对于图像分类问题,我们可以使用CIFAR-10数据集,它包含了10个类别的60000张32*32的彩色图像,每个类别有6000张图片。

在Pytorch中可以使用torchvision来加载CIFAR-10数据集。我们可以使用以下代码来加载训练集和测试集:

import torch

import torchvision

import torchvision.transforms as transforms

transform_train = transforms.Compose(

[transforms.RandomCrop(32, padding=4),

transforms.RandomHorizontalFlip(),

transforms.ToTensor(),

transforms.Normalize((0.4914, 0.4822, 0.4465),

(0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose(

[transforms.ToTensor(),

transforms.Normalize((0.4914, 0.4822, 0.4465),

(0.2023, 0.1994, 0.2010))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,

download=True,

transform=transform_train)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,

shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,

download=True,

transform=transform_test)

testloader = torch.utils.data.DataLoader(testset, batch_size=100,

shuffle=False, num_workers=2)

在上面的代码中,我们使用了transforms来对图像进行预处理。训练集和测试集的预处理方式稍有不同,其中trainset采用了随机裁剪和水平翻转的方式进行数据增强,以提高模型的泛化能力;而testset只进行了简单的归一化操作。

4. 网络构建和训练

在本文中,我们将采用一个简单的卷积神经网络来对CIFAR-10数据集进行分类。该网络包含两个卷积层、两个池化层和三个全连接层。具体的网络结构如下:

import torch.nn as nn

import torch.nn.functional as F

class Net(nn.Module):

def __init__(self):

super(Net, self).__init__()

self.conv1 = nn.Conv2d(3, 64, 3, padding=1)

self.conv2 = nn.Conv2d(64, 128, 3, padding=1)

self.fc1 = nn.Linear(8192, 1024)

self.fc2 = nn.Linear(1024, 512)

self.fc3 = nn.Linear(512, 10)

def forward(self, x):

x = F.relu(self.conv1(x))

x = F.max_pool2d(x, 2)

x = F.relu(self.conv2(x))

x = F.max_pool2d(x, 2)

x = x.view(-1, 8192)

x = F.relu(self.fc1(x))

x = F.relu(self.fc2(x))

x = self.fc3(x)

return x

在上述代码中,我们定义了一个名为Net的继承自nn.Module的类,并在类构造函数中定义了网络的结构。在函数forward中,我们对x进行卷积、池化和全连接操作,最终得到网络预测的结果。

接下来,我们可以在训练集上训练该网络。训练的代码如下:

import torch.optim as optim

net = Net()

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10): # 对数据集进行多次训练

running_loss = 0.0

for i, data in enumerate(trainloader, 0):

inputs, labels = data

optimizer.zero_grad()

outputs = net(inputs)

loss = criterion(outputs, labels)

loss.backward()

optimizer.step()

running_loss += loss.item()

if i % 100 == 99: # 每100个batch打印一次平均loss

print('[%d, %5d] loss: %.3f' %

(epoch + 1, i + 1, running_loss / 100))

running_loss = 0.0

在上述代码中,我们首先定义了损失函数和优化器,然后对训练集进行多次训练。训练时,我们依次从trainloader中加载一个batch的数据,将网络输入为inputs,标签为labels,并进行前向传播和反向传播。每100个batch,我们打印一次平均loss。

5. 测试和预测

在模型训练完成后,我们可以使用测试集来评估模型的性能。测试的代码如下:

correct = 0

total = 0

with torch.no_grad():

for data in testloader:

images, labels = data

outputs = net(images)

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (

100 * correct / total))

在上述代码中,我们使用了torch.no_grad()来关闭梯度计算,从而减少计算量。对于测试集中的每一个batch,我们对网络进行前向传播并求得预测结果和真实标签。最后,我们计算分类准确率。

除了计算分类准确率之外,我们还可以对单张图片进行预测。预测的代码如下:

import torchvision.transforms.functional as F2

from PIL import Image

def predict(image_path, model):

image = Image.open(image_path)

image = F2.resize(image, (32, 32))

image_tensor = transform_test(image).unsqueeze_(0)

output = model(image_tensor)

_, predicted = torch.max(output.data, 1)

return int(predicted[0])

image_path = 'test.jpg'

label = predict(image_path, net)

print('Predicted label:', label)

在上述代码中,我们首先打开一张图片,然后进行缩放和归一化操作,并将其转化为模型输入的tensor。对于tensor,我们调用模型的forward函数进行预测,并返回预测结果。

总结

本文介绍了如何使用Pytorch来训练一个图像分类的卷积神经网络,并给出了具体的代码实现。通过本文的学习,我们可以了解到Pytorch的基本使用方法,以及如何对图像数据集进行预处理和网络构建。在实际应用中,我们可以根据需要对网络结构进行更改,并进行超参数的调节,从而得到更好的分类结果。

后端开发标签