pytorch之添加BN的实现

1. 什么是BN

BN是Batch Normalization的缩写,官方文档的定义为:

Batch Normalization applies a transformation that maintains the mean output close to 0 and the output standard deviation close to 1.

简单来说,就是一种归一化处理的方法,用于深度学习中神经网络的训练。

2. 为什么需要BN

深度神经网络的训练过程中,经常会出现梯度消失或爆炸的问题,导致模型无法收敛。而BN的出现解决了这个问题。

2.1 梯度消失和爆炸

在深度神经网络中,每一层都是由一堆神经元构成的,假设有L层网络,第i层的输入为$x^{(i)}$,第i层的输出为$y^{(i)}$,则第i+1层的输入为$z^{(i+1)}$,第i+1层的输出为$a^{(i+1)}$,且$a^{(i+1)}=g(z^{(i+1)})$,其中g为非线性函数。

在BP过程中,需要算出损失函数对每个参数的导数,即梯度,然后反向传播更新所有参数。假设激活函数是sigmoid函数,当$x$取到很大或很小的数时,其导数将会趋于0,这个现象也被称为梯度消失,层数越深,这个问题会越严重。而在反向传播过程中,误差也可能会随着层数增加而指数级爆炸,这个问题被称为梯度爆炸。

2.2 BN的思路

Batch Normalization是一种对数据进行归一化处理的方式,它使得数据满足均值为0,方差为1的高斯分布,使得激活函数的导数变得稳定,防止梯度消失或爆炸。

3. BN的实现

PyTorch中实现BN的层可以使用nn.BatchNorm2d()函数,它的参数有:

num_features:输入的特征数量,对于2D卷积层来说,就是输出通道数。

eps:为避免除以0,加上的一个小值,默认为1e-5。

momentum:用于求解滑动平均的动量,默认为0.1。

affine:是否在归一化的基础上,再进行仿射变换,默认为True。

track_running_stats:是否使用全局均值和方差进行归一化处理,True则采用全局均值和方差,之后每一批次再进行标准化。默认为True。

使用该函数时,需要保证输入数据为4D,即(batch_size, channel, height, width)。

3.1 添加BN的实现

在训练神经网络时,可以将BN层添加到卷积层或线性层之后。

import torch.nn as nn

class Net(nn.Module):

def __init__(self):

super(Net, self).__init__()

self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False)

self.bn1 = nn.BatchNorm2d(32)

self.relu = nn.ReLU(inplace=True)

self.fc1 = nn.Linear(32*28*28, 10, bias=False)

def forward(self, x):

x = self.conv1(x)

x = self.bn1(x)

x = self.relu(x)

x = x.view(x.size(0), -1)

x = self.fc1(x)

return x

上述代码中,首先创建一个卷积层self.conv1和一个BN层self.bn1,然后将BN层添加到卷积层之后,再添加一个ReLU激活函数self.relu。最后添加一个线性层self.fc1,将输出展平后用于分类。

3.2 调节temperature

在实现中,还可以添加temperature参数,用于控制归一化程度。temperature代表的是温度,当temperature越大时,模型的输出分布就越平滑。当temperature越小时,模型的输出分布就越尖锐,即更加集中在某个特定的值附近。在BN的计算公式中,temperature就扮演了这样的角色,它可以使得模型的输出更有创造性和多样性。

向BN层中添加temperature的代码如下:

class myBN(nn.Module):

def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, temperature=1.0):

super(myBN, self).__init__()

self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)

self.temperature = temperature

def forward(self, x):

out = self.bn(x)

out = self.temperature * out

return out

上述代码中,首先继承nn.Module类,然后在__init__()函数中添加temperature参数,并在forward()函数中使用。在forward()函数中,首先使用nn.BatchNorm2d()计算标准化值out,然后将标准化值乘以temperature来缩小或扩大输出分布的范围。

4. 总结

Batch Normalization是一种常用的神经网络训练的技术,能够解决梯度消失或爆炸的问题,使得模型训练更加稳定和高效。在PyTorch中,可以通过nn.BatchNorm2d()函数来实现BN层的添加,也可以自己定义一个BN层并添加temperature参数,用于控制输出分布的范围。

后端开发标签