PyTorch中的Batch Normalization

Pytorch中的BatchNorm的API主要有:

1 torch.nn.BatchNorm1d(num_features,
2 
3 eps=1e-05,
4 
5 momentum=0.1,
6 
7 affine=True,
8 
9 track_running_stats=True)

一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。

同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainningaffinetrack_running_stats

  • 其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则0,并且不能学习被更新。一般都会设置成affine=True[10]
  • trainningtrack_running_statstrack_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。

一般来说,trainningtrack_running_stats有四种组合[7]

  1. trainning=True, track_running_stats=True。这个是期望中的训练阶段的设置,此时BN将会跟踪整个训练过程中batch的统计特性。
  2. trainning=True, track_running_stats=False。此时BN只会计算当前输入的训练batch的统计特性,可能没法很好地描述全局的数据统计特性。
  3. trainning=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_meanrunning_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。[6,8]
  4. trainning=False, track_running_stats=False 效果同(2),只不过是位于测试状态,这个一般不采用,这个只是用测试输入的batch的统计特性,容易造成统计特性的偏移,导致糟糕效果。

同时,我们要注意到,BN层中的running_meanrunning_var的更新是在forward()操作中进行的,而不是optimizer.step()中进行的,因此如果处于训练状态,就算你不进行手动step(),BN的统计特性也会变化的。如

 1 model.train() # 处于训练状态
 2 
 3 
 4 for data, label in self.dataloader:
 5 
 6 pred = model(data)
 7 
 8 # 在这里就会更新model中的BN的统计特性参数,running_mean, running_var
 9 
10 loss = self.loss(pred, label)
11 
12 # 就算不要下列三行代码,BN的统计特性参数也会变化
13 
14 opt.zero_grad()
15 
16 loss.backward()
17 
18 opt.step()

这个时候要将model.eval()转到测试阶段,才能固定住running_meanrunning_var。有时候如果是先预训练模型然后加载模型,重新跑测试的时候结果不同,有一点性能上的损失,这个时候十有八九是trainningtrack_running_stats设置的不对,这里需要多注意。 [8]

[1]. 用pytorch踩过的坑

[2]. Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456.

[3]. <深度学习优化策略-1>Batch Normalization(BN)

[4]. 详解深度学习中的Normalization,BN/LN/WN

[5]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24

[6]. https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870

[7]. BatchNorm2d增加的参数track_running_stats如何理解?

[8]. Why track_running_stats is not set to False during eval

[9]. How to train with frozen BatchNorm?

[10]. Proper way of fixing batchnorm layers during training

[11]. 大白话《Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift》