pytorch SENet实现案例

1. SENet介绍

SENet即Squeeze-and-Excitation Networks,是由Jie Hu等人在2017年提出的一种神经网络结构,它通过将Global Average Pooling(GAP)Fully Connected(FC)结合构建了一种通用的网络模块,该模块可以有效地提升深度卷积神经网络的性能。

1.1 Squeeze:信息压缩

在SENet中,Squeeze指的是将GAP操作应用于卷积特征图的每个通道,将每个通道的特征压缩成一个实数,即每个通道产生一个长度为C的向量,然后在通道维度上对所有的通道特征做一个元素普通的加权平均,以获取整张特征图的有助于区分不同类别的全局感知信息。

import torch

import torch.nn as nn

class SEBlock(nn.Module):

def __init__(self, in_channels, reduction=16):

super(SEBlock, self).__init__()

self.avg_pool = nn.AdaptiveAvgPool2d(1)

self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=False)

self.relu = nn.ReLU(inplace=True)

self.fc2 = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=False)

self.sigmoid = nn.Sigmoid()

def forward(self, x):

b, c, h, w = x.size()

y = self.avg_pool(x)

y = self.fc1(y)

y = self.relu(y)

y = self.fc2(y)

y = self.sigmoid(y)

return x * y

上面的代码展示了SEBlock的实现,它的输入是一个四维张量(B,C,H,W),其中,B表示batch size,C表示输入通道数,H和W分别表示输入的高和宽。SEBlock可以从输入通道数进行首次压缩,然后再将维度增加到原来的大小,以使通道之间的关系更加复杂。

1.2 Excitation:特征重新加权

Excitation指的是使用一组可学习的参数对SEBlock的每个通道特征进行加权。 SEBlock的输出是每个通道的加权的特征图,该过程可以看作是一种知道每个通道对分类的作用的动态特征选择机制。

2. pytorch实现SENet

在pytorch中,我们可以很容易地实现SENet,下面是一个使用SENet模块的卷积神经网络模型:

from torch import nn

class SENet(nn.Module):

def __init__(self, num_classes=10):

super(SENet, self).__init__()

self.features = nn.Sequential(

nn.Conv2d(3, 64, kernel_size=3, padding=1),

nn.BatchNorm2d(64),

nn.ReLU(inplace=True),

nn.MaxPool2d(kernel_size=2, stride=2),

nn.Conv2d(64, 128, kernel_size=3, padding=1),

nn.BatchNorm2d(128),

nn.ReLU(inplace=True),

nn.MaxPool2d(kernel_size=2, stride=2),

nn.Conv2d(128, 256, kernel_size=3, padding=1),

nn.BatchNorm2d(256),

nn.ReLU(inplace=True),

nn.Conv2d(256, 512, kernel_size=3, padding=1),

nn.BatchNorm2d(512),

nn.ReLU(inplace=True),

nn.MaxPool2d(kernel_size=2, stride=2),

nn.Conv2d(512, 1024, kernel_size=3, padding=1),

SEBlock(1024),

nn.BatchNorm2d(1024),

nn.ReLU(inplace=True),

nn.MaxPool2d(kernel_size=2, stride=2),

)

self.avg_pool = nn.AdaptiveAvgPool2d(1)

self.classifier = nn.Sequential(

nn.Linear(1024, num_classes)

)

def forward(self, x):

x = self.features(x)

x = self.avg_pool(x)

x = x.view(x.size(0), -1)

x = self.classifier(x)

return x

该模型包含一个SEBlock,可以自定义通道数和压缩比。下面是使用该模型对CIFAR10数据集进行训练的代码:

import torch

import torch.nn as nn

import torchvision.datasets as dsets

import torchvision.transforms as transforms

from torch.autograd import Variable

# hyperparameters

num_epochs = 10

batch_size = 128

learning_rate = 0.001

# CIFAR-10 dataset

train_dataset = dsets.CIFAR10(root='./data/',

train=True,

transform=transforms.ToTensor(),

download=True)

test_dataset = dsets.CIFAR10(root='./data/',

train=False,

transform=transforms.ToTensor())

# Data loader

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,

batch_size=batch_size,

shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,

batch_size=batch_size,

shuffle=False)

# Model

model = SENet(num_classes=10)

# Loss and Optimizer

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the Model

for epoch in range(num_epochs):

for i, (images, labels) in enumerate(train_loader):

images = Variable(images)

labels = Variable(labels)

# Forward + Backward + Optimize

optimizer.zero_grad()

outputs = model(images)

loss = criterion(outputs, labels)

loss.backward()

optimizer.step()

if (i+1) % 100 == 0:

print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'

% (epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.data))

# Test the Model

model.eval()

correct = 0

total = 0

for images, labels in test_loader:

images = Variable(images)

outputs = model(images)

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += (predicted == labels).sum()

print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))

运行以上代码训练模型,可以得到在测试集上准确率为80%以上的结果。

免责声明:本文来自互联网,本站所有信息(包括但不限于文字、视频、音频、数据及图表),不保证该信息的准确性、真实性、完整性、有效性、及时性、原创性等,版权归属于原作者,如无意侵犯媒体或个人知识产权,请来电或致函告之,本站将在第一时间处理。猿码集站发布此文目的在于促进信息交流,此文观点与本站立场无关,不承担任何责任。

后端开发标签