Batch Normalization做了什么?
在数据在最初进来的时候,都希望是(IID)独立同分布的。
但是batch Normalization的作者觉得不够,应该在deep learning中的每层都进行一次处理,保证在每层都是同分布。
他是这么想的:假设网络有n层,网络正在训练,还没有收敛。这时候被输入,经过了第一层,但是第一层还没有学到正确的weight,所以经过weight的矩阵乘法后,第二层的数会不会很乱?会不会第二层有些节点值是个位数,有些节点值蹦到好几百?细想一下,确实挺有可能啊,内部的参数都是随机初始化的,那蹦啥结果确实不好说啊。然后恐怖的事情来了,第二层这些乱蹦的数,又输到了第三层,那第三层的输入就是乱蹦的数,输出当然好不了,以此类推。
所以主要产生了两个问题:
1.所以在前面的网络没有收敛的时候,后面的网络其实并学不到什么。(一栋大楼底部都是晃的,那上面也好不了。所以必须要等前面的层收敛后,后面层的训练才有效果。)
2.因为一般来说网络内部每层都需要加一层激活来增加非线性化嘛,那么如果值比较大,它通过激活以后在S曲线上会比较接近0或1,梯度很小,收敛会很慢。
所以batch Normalization就想在每层都加一个norm进行标准化,让每层的数分布相同,变成均值0,方差1的标准分布。高斯分布的标准化公式就是下面式子中的部分。值减去均值()再除以方差(),能够得到均值为0,方差为1的标准正态分布。至于γ和β,是需要学习的两个参数,γ对数据的方差再进行一个缩放,β对数据的均值产生一个偏移。
为什么归一化成均值0,方差1后,还要再修改方差和均值?那归一化还有意义吗?
这是因为我们并不能保证这层网络学到的特征是什么,如果简单的归一化,很有可能会被破坏。比如说S型激活函数,如果这层学到的特征在S的顶端那块,那么我们做归一化以后,强行把特征带到了S的中间位置,特征就被破坏了。要注意γ和β是被训练的参数,且每层都不一样,所以针对每一层的实际情况,它会去尝试恢复这层网络所学到的特征。
实际工作:
Batch Norm 只是插入在隐藏层和下一个隐藏层之间的另一个网络层。 它的工作是从第一个隐藏层获取输出并在将它们作为下一个隐藏层的输入传递之前对其进行标准化。
参数:两个可学习的参数, beta 和 gamma。
Batch Norm 层的计算:
(建议Batch Norm在激活函数前加——防止激活时梯度弥散)
1. 激活:来自前一层的激活作为输入传递给 Batch Norm。数据中的每个特征都有一个激活向量。
2. 计算均值和方差:每个激活向量分别计算 mini-batch 中所有值的均值和方差。
3. 规范化:使用相应的均值和方差计算每个激活特征向量的归一化值。这些归一化值现在有零均值和单位方差。
4. 规模和转移:这一步是 Batch Norm 引入的创新点。与要求所有归一化值的均值和单位方差为零的输入层不同,Batch Norm 允许将其值移动(到不同的均值)和缩放(到不同的方差)。它通过将归一化值乘以因子 gamma 并添加因子 beta 来实现此目的。这里是逐元素乘法,而不是矩阵乘法。创新点在于,这些因素不是超参数(即模型设计者提供的常数),而是网络学习的可训练参数。每个 Batch Norm 层都能够为自己找到最佳因子,因此可以移动和缩放归一化值以获得最佳预测。
5. 移动平均线:Batch Norm 还保持对均值和方差的指数移动平均线 (EMA) 的运行计数。训练期间它只是计算这个 EMA,但不做任何处理。在训练结束时,它将该值保存为层状态的一部分,以在推理阶段使用。移动平均线计算使用由下面的 alpha 表示的标量“动量”。这是一个仅用于 Batch Norm 移动平均线的超参数,不应与优化器中使用的动量混淆。
代码实现:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.layer1=nn.Sequential(
nn.Conv2d(3,3,3,1,1,bias=False),
#因为BatchNorm的计算里面加了bias,所以一开始要bias=False
nn.BatchNorm2d(3),#对卷积使用BatchNorm2d,通道数必须大于1
nn.ReLU(),
nn.Conv2d(3,3,3,1,1)
)
self.layer2=nn.Sequential(
nn.Linear(3*5*5,20,bias=False),
nn.BatchNorm1d(20),#对卷积使用BatchNorm2d,批次数必须大于1
nn.Linear(20,1),
)
def forward(self,x):
OUT=self.layer1(x)
OUT=OUT.reshape(-1,3*5*5)#NCHW-NV
return self.layer2(OUT)
if __name__ == '__main__':
net=Net()
x=torch.randn(2,3,5,5)
y=net(x)
print(y)
print(y.shape)
版权归原作者 小羊头发长 所有, 如有侵权,请联系我们删除。