1. 什么是BN
BN是Batch Normalization的缩写,官方文档的定义为:
Batch Normalization applies a transformation that maintains the mean output close to 0 and the output standard deviation close to 1.
简单来说,就是一种归一化处理的方法,用于深度学习中神经网络的训练。
2. 为什么需要BN
深度神经网络的训练过程中,经常会出现梯度消失或爆炸的问题,导致模型无法收敛。而BN的出现解决了这个问题。
2.1 梯度消失和爆炸
在深度神经网络中,每一层都是由一堆神经元构成的,假设有L层网络,第i层的输入为$x^{(i)}$,第i层的输出为$y^{(i)}$,则第i+1层的输入为$z^{(i+1)}$,第i+1层的输出为$a^{(i+1)}$,且$a^{(i+1)}=g(z^{(i+1)})$,其中g为非线性函数。
在BP过程中,需要算出损失函数对每个参数的导数,即梯度,然后反向传播更新所有参数。假设激活函数是sigmoid函数,当$x$取到很大或很小的数时,其导数将会趋于0,这个现象也被称为梯度消失,层数越深,这个问题会越严重。而在反向传播过程中,误差也可能会随着层数增加而指数级爆炸,这个问题被称为梯度爆炸。
2.2 BN的思路
Batch Normalization是一种对数据进行归一化处理的方式,它使得数据满足均值为0,方差为1的高斯分布,使得激活函数的导数变得稳定,防止梯度消失或爆炸。
3. BN的实现
PyTorch中实现BN的层可以使用nn.BatchNorm2d()函数,它的参数有:
num_features:输入的特征数量,对于2D卷积层来说,就是输出通道数。
eps:为避免除以0,加上的一个小值,默认为1e-5。
momentum:用于求解滑动平均的动量,默认为0.1。
affine:是否在归一化的基础上,再进行仿射变换,默认为True。
track_running_stats:是否使用全局均值和方差进行归一化处理,True则采用全局均值和方差,之后每一批次再进行标准化。默认为True。
使用该函数时,需要保证输入数据为4D,即(batch_size, channel, height, width)。
3.1 添加BN的实现
在训练神经网络时,可以将BN层添加到卷积层或线性层之后。
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(32)
self.relu = nn.ReLU(inplace=True)
self.fc1 = nn.Linear(32*28*28, 10, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = x.view(x.size(0), -1)
x = self.fc1(x)
return x
上述代码中,首先创建一个卷积层self.conv1和一个BN层self.bn1,然后将BN层添加到卷积层之后,再添加一个ReLU激活函数self.relu。最后添加一个线性层self.fc1,将输出展平后用于分类。
3.2 调节temperature
在实现中,还可以添加temperature参数,用于控制归一化程度。temperature代表的是温度,当temperature越大时,模型的输出分布就越平滑。当temperature越小时,模型的输出分布就越尖锐,即更加集中在某个特定的值附近。在BN的计算公式中,temperature就扮演了这样的角色,它可以使得模型的输出更有创造性和多样性。
向BN层中添加temperature的代码如下:
class myBN(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, temperature=1.0):
super(myBN, self).__init__()
self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
self.temperature = temperature
def forward(self, x):
out = self.bn(x)
out = self.temperature * out
return out
上述代码中,首先继承nn.Module类,然后在__init__()函数中添加temperature参数,并在forward()函数中使用。在forward()函数中,首先使用nn.BatchNorm2d()计算标准化值out,然后将标准化值乘以temperature来缩小或扩大输出分布的范围。
4. 总结
Batch Normalization是一种常用的神经网络训练的技术,能够解决梯度消失或爆炸的问题,使得模型训练更加稳定和高效。在PyTorch中,可以通过nn.BatchNorm2d()函数来实现BN层的添加,也可以自己定义一个BN层并添加temperature参数,用于控制输出分布的范围。