Batch Normalization:Accelerating Deep Network Training by Reducing Internal Covariate Shift
BN在神經網絡中很常見,BN是什麼?爲什麼要用BN? BN有什麼作用?接下來圍繞幾個點對BN進行總結,並附上BN層forward和backward代碼。正所謂,無總結,不進步。
一、BatchNormalization的引入
1、BN是2015年提出來的,論文題目是:《Batch Normalization:Accelerating Deep Network Training by Reducing Internal Covariate Shift》。題目有一個詞:internal Covariate Shift,它在原文的意思爲that small changes to the network parameters amplify as the network becomes deeper.The change in the distributions of layers’ inputs presents a problem because the layers need to continuously adapt to the new distribution. When the input distribution to a learning system changes, it is said to experience covariate shift.換句話說就是,當網絡越來越深的時候,參數的小小改變能使網絡產生很大的變化。也就是BN能夠降低參數改變帶來的變化。
2、若是網絡中沒有BN,會出現什麼問題?
根據求導法則:
當W很大的時候,根據鏈式求導法則,梯度g = g_local * g_out,即本次運算的梯度與前面梯度的乘積,當W小於1的時候,假設W爲0.9,經過100層的運算後,梯度爲0.9的100次方,是一個很小的值;當W大於1的時候,經過100層的運算,梯度是一個非常大的值,這兩種情況分別叫做梯度消失與梯度爆炸。
3、另一方面,對於激活層,有無BN對於激活的影響是什麼?
以激活函數sigmoid爲例,當數據接近於1或者0的時候,它的曲線是平緩的,也就是梯度會接近於0,當反向傳播的時候,梯度幾乎是不更新的,網絡無法達到訓練的目的。文中提到:BN constrain them to the linear regime of the nonlinearity. BN將值從約束到非線性區的相對線性區內,讓它們有梯度可以傳播。
二、什麼是BatchNormalization
根據論文公式,對於m個mini-batch的數據集,先獲取該batch的均值和方差,進行normalize,最後進行scale and shift。對於前面三步基本可以看懂,但是最後的scale and shift有什麼用,文章中也沒有明確說出來。我自己的理解是scale and shift是對歸一化後的數據進行偏移和尺度縮放,單一的normalize並不能滿足數據分佈的要求, scale and shift可以提高數據的信息表達,例如對於激活函數relu,小於0的部分不激活,但是如果數據用了scale and shift,使數據偏移,relu的激活量因此可以得到改變。當然,參數gamma和beta都是網絡可以學習的。
Batch Normalization的反向傳播,也是用求導的鏈式法則進行求梯度,求偏導中,loss對xi求偏導,偏導的結果與w無關,可以返回到第一部分提到的梯度消失和梯度爆炸的問題。
三、BN的優點有哪些
1、Batch Normalization enables higher learning rates
large learning rates may increase the scale of layer parameters, which then amplify the gradient during backpropagation and lead to the model explosion. However, with Batch Normalization, back-propagation through a layer is unaffected by the scale of its parameters.
2、Batch Normalization regularizes the model
BN有一定的正則化效果。在每一次的訓練中我們用的是mini-batch,用mini-batch的mean和variance來代表整個dataset的mean和variance,雖然用mini-batch是具有代表性的,但是它還不完全是dataset,等於給網絡增加了隨機噪音。有一定的正則化效果。
3、Accelerating BN Networks
提高了網絡的學習速度,每一層的輸入數據均值方差在一定的範圍內,使下一層網絡不必去適應輸入的變化,允許每一層進行獨立學習,有利於提高整個神經網絡的學習速度。
4、in some cases eliminating the need for Dropout.
減少對dropout的使用
5、reduce overfitting
降低過擬合,道理同2
四、BN的缺點有哪些
1、效果容易受batch size大小的影響。batch size越大,mini-batch的數據越有代表性,它的mean and variance越接近dataset的mean and variance。但是batch太大,內存不一定夠放。
2、難以在RNN中使用,RNN中更多的是使用Layer norm。
五、代碼實現
def batchnorm_forward(x, gamma, beta, bn_param):
"""
Forward pass for batch normalization.
During training the sample mean and (uncorrected) sample variance are
computed from minibatch statistics and used to normalize the incoming data.
During training we also keep an exponentially decaying running mean of the
mean and variance of each feature, and these averages are used to normalize
data at test-time.
At each timestep we update the running averages for mean and variance using
an exponential decay based on the momentum parameter:
running_mean = momentum * running_mean + (1 - momentum) * sample_mean
running_var = momentum * running_var + (1 - momentum) * sample_var
Note that the batch normalization paper suggests a different test-time
behavior: they compute sample mean and variance for each feature using a
large number of training images rather than using a running average. For
this implementation we have chosen to use running averages instead since
they do not require an additional estimation step; the torch7
implementation of batch normalization also uses running averages.
Input:
- x: Data of shape (N, D)
- gamma: Scale parameter of shape (D,)
- beta: Shift paremeter of shape (D,)
- bn_param: Dictionary with the following keys:
- mode: 'train' or 'test'; required
- eps: Constant for numeric stability
- momentum: Constant for running mean / variance.
- running_mean: Array of shape (D,) giving running mean of features
- running_var Array of shape (D,) giving running variance of features
Returns a tuple of:
- out: of shape (N, D)
- cache: A tuple of values needed in the backward pass
"""
mode = bn_param['mode']
eps = bn_param.get('eps', 1e-5)
momentum = bn_param.get('momentum', 0.9)
N, D = x.shape
running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))
out, cache = None, None
if mode == 'train':
#######################################################################
# TODO: Implement the training-time forward pass for batch norm. #
# Use minibatch statistics to compute the mean and variance, use #
# these statistics to normalize the incoming data, and scale and #
# shift the normalized data using gamma and beta. #
# # #
# Note that though you should be keeping track of the running #
# variance, you should normalize the data based on the standard #
# deviation (square root of variance) instead! #
# Referencing the original paper (https://arxiv.org/abs/1502.03167) #
# might prove to be helpful. #
#######################################################################
sample_mean = np.mean(x, axis=0) #[D]
sample_var = np.var(x, axis=0) #[D]
x_hat = (x - sample_mean) / np.sqrt(sample_var+eps)
out = gamma*x_hat+beta
cache = (gamma, x, sample_mean, sample_var, eps, x_hat) #why is this
#store the global mean and var
running_mean = momentum* running_mean +(1-momentum)*sample_mean
running_var = momentum*running_var+(1-momentum)*sample_var
elif mode == 'test':
#######################################################################
# TODO: Implement the test-time forward pass for batch normalization. #
# Use the running mean and variance to normalize the incoming data, #
# then scale and shift the normalized data using gamma and beta. #
# Store the result in the out variable. #
#######################################################################
scale = gamma / np.sqrt(running_var+eps)
shift = beta- scale*running_mean
out = scale * x + shift
else:
raise ValueError('Invalid forward batchnorm mode "%s"' % mode)
# Store the updated running means back into bn_param
bn_param['running_mean'] = running_mean
bn_param['running_var'] = running_var
return out, cache
def batchnorm_backward(dout, cache):
"""
Backward pass for batch normalization.
For this implementation, you should write out a computation graph for
batch normalization on paper and propagate gradients backward through
intermediate nodes.
Inputs:
- dout: Upstream derivatives, of shape (N, D)
- cache: Variable of intermediates from batchnorm_forward.
Returns a tuple of:
- dx: Gradient with respect to inputs x, of shape (N, D)
- dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
- dbeta: Gradient with respect to shift parameter beta, of shape (D,)
"""
dx, dgamma, dbeta = None, None, None
###########################################################################
# TODO: Implement the backward pass for batch normalization. Store the #
# results in the dx, dgamma, and dbeta variables. #
# Referencing the original paper (https://arxiv.org/abs/1502.03167) #
# might prove to be helpful. #
###########################################################################
gamma, x, sample_mean, sample_var, eps, x_hat = cache
N = x.shape[0]
dx_hat = dout * gamma
dvar = np.sum(dx_hat* (x - sample_mean) * -0.5 * np.power(sample_var + eps, -1.5), axis = 0)
dmean = np.sum(dx_hat * -1 / np.sqrt(sample_var +eps), axis = 0) + dvar * np.mean(-2 * (x - sample_mean), axis =0)
dx = 1 / np.sqrt(sample_var + eps) * dx_hat + dvar * 2.0 / N * (x-sample_mean) + 1.0 / N * dmean
dgamma = np.sum(x_hat * dout, axis = 0)
dbeta = np.sum(dout , axis = 0)
return dx, dgamma, dbeta
# test
np.random.seed(231)
N, D = 4, 5
x = 5 * np.random.randn(N, D) + 12
gamma = np.random.randn(D)
beta = np.random.randn(D)
dout = np.random.randn(N, D)
bn_param = {'mode': 'train'}
_, cache = batchnorm_forward(x, gamma, beta, bn_param)
dx, dgamma, dbeta = batchnorm_backward(dout, cache)