Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift


       Batch normalization (BN) is everywhere in neural networks. What is BN? Why do we use it? What does it actually do? This post summarizes BN around these questions and closes with forward and backward code for a BN layer. As the saying goes: no summary, no progress.

I. Introducing Batch Normalization

1. BN was introduced in 2015, in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. The title contains the term internal covariate shift, which the paper explains as follows: small changes to the network parameters amplify as the network becomes deeper; the change in the distributions of layers' inputs presents a problem because the layers need to continuously adapt to the new distribution, and when the input distribution to a learning system changes, it is said to experience covariate shift. In other words, as the network gets deeper, small changes to the parameters can cause large changes inside the network, and BN reduces the impact of such parameter changes.

2. What goes wrong if the network has no BN?
According to the chain rule, the gradient that reaches a given layer is g = g_local * g_out, i.e., the product of the layer's local gradient and the gradient accumulated from the layers after it. If each layer effectively multiplies the gradient by a factor W, repeated multiplication takes over: when W is less than 1, say 0.9, the gradient after 100 layers is on the order of 0.9^100, a vanishingly small number; when W is greater than 1, the gradient after 100 layers is enormous. These two cases are called vanishing gradients and exploding gradients.
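
A minimal numeric sketch of this effect (my own illustration, not from the paper): repeatedly multiplying a gradient by a per-layer factor of 0.9 or 1.1 across 100 layers.

import numpy as np

for w in (0.9, 1.1):
    grad = 1.0
    for _ in range(100):      # 100 layers, each scaling the gradient by a factor of w
        grad *= w
    print('w = %.1f: gradient after 100 layers = %.3e' % (w, grad))
# w = 0.9 gives about 2.7e-05 (vanishing); w = 1.1 gives about 1.4e+04 (exploding)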

3. On the other hand, what effect does BN (or its absence) have on the activations?
Take the sigmoid activation as an example. When its output is close to 0 or 1 the curve is flat, so the gradient is close to 0; during backpropagation the weights barely update and the network cannot train. The paper notes that BN constrains the inputs to the (relatively) linear regime of the nonlinearity, so gradients can still propagate.
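
A short sketch (my own, not from the paper) of how the sigmoid gradient collapses in the saturated regime and stays usable near zero, which is where BN keeps the inputs:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# The derivative of sigmoid is s * (1 - s): it peaks at 0.25 around z = 0
# and falls toward 0 once |z| is large (the flat, saturated part of the curve).
for z in (0.0, 2.0, 5.0, 10.0):
    s = sigmoid(z)
    print('z = %4.1f  sigmoid = %.4f  gradient = %.6f' % (z, s, s * (1 - s)))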

II. What is Batch Normalization

       Following the formulas in the paper, for a mini-batch of m examples we first compute the batch mean and variance, then normalize, and finally apply a scale and shift. The first three steps are easy to follow, but the paper never spells out what the final scale and shift is for. My own understanding is that scale and shift apply a learned rescaling and offset to the normalized data: normalization alone may not match the distribution a layer actually needs, and scale and shift let the data keep more of its expressive power. For example, with the ReLU activation everything below 0 is not activated, but if the data are shifted by the learned offset, the fraction of values that ReLU passes through can change. The parameters gamma and beta are, of course, learned by the network.
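
For reference, the four steps of the BN transform over a mini-batch B = {x_1, ..., x_m}, as given in Algorithm 1 of the paper, are:

$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i-\mu_B)^2, \qquad \hat{x}_i = \frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}}, \qquad y_i = \gamma\hat{x}_i + \beta$$
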
The backward pass of Batch Normalization also uses the chain rule to compute gradients. In the derivation, the partial derivative of the loss with respect to x_i does not depend on the scale of W, which ties back to the vanishing and exploding gradients discussed in part I.
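
For reference, the gradients from the paper (the same ones implemented in batchnorm_backward below) are:

$$\frac{\partial \ell}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i}\cdot\gamma, \qquad \frac{\partial \ell}{\partial \sigma_B^2} = \sum_{i=1}^{m}\frac{\partial \ell}{\partial \hat{x}_i}\,(x_i-\mu_B)\cdot\left(-\tfrac{1}{2}\right)(\sigma_B^2+\epsilon)^{-3/2}$$

$$\frac{\partial \ell}{\partial \mu_B} = \sum_{i=1}^{m}\frac{\partial \ell}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma_B^2+\epsilon}} + \frac{\partial \ell}{\partial \sigma_B^2}\cdot\frac{\sum_{i=1}^{m}-2(x_i-\mu_B)}{m}$$

$$\frac{\partial \ell}{\partial x_i} = \frac{\partial \ell}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma_B^2+\epsilon}} + \frac{\partial \ell}{\partial \sigma_B^2}\cdot\frac{2(x_i-\mu_B)}{m} + \frac{\partial \ell}{\partial \mu_B}\cdot\frac{1}{m}, \qquad \frac{\partial \ell}{\partial \gamma} = \sum_{i=1}^{m}\frac{\partial \ell}{\partial y_i}\,\hat{x}_i, \qquad \frac{\partial \ell}{\partial \beta} = \sum_{i=1}^{m}\frac{\partial \ell}{\partial y_i}$$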

III. What are the advantages of BN

1. Batch Normalization enables higher learning rates
       Large learning rates may increase the scale of layer parameters, which then amplifies the gradient during backpropagation and can lead to model explosion. With Batch Normalization, however, back-propagation through a layer is unaffected by the scale of its parameters.
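
A small numeric check of this claim (my own sketch, not from the paper): normalizing the output of a linear layer makes the normalized result essentially insensitive to a rescaling of its weights, apart from the tiny eps term.

import numpy as np

rng = np.random.default_rng(0)
u = rng.normal(size=(64, 10))          # a batch of inputs
W = rng.normal(size=(10, 5))           # weights of a linear layer

def normalize(z, eps=1e-5):
    # per-feature standardization over the batch, as BN does before its scale and shift
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

# Normalizing W*u gives (almost) the same result as normalizing (7*W)*u: the scale cancels.
print(np.allclose(normalize(u @ W), normalize(u @ (7.0 * W))))   # True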

2. Batch Normalization regularizes the model
       BN has a mild regularizing effect. Each training step uses a mini-batch, and the mini-batch mean and variance stand in for the mean and variance of the whole dataset. A mini-batch is representative, but it is not the full dataset, so this amounts to injecting a bit of random noise into the network, and that noise acts as a regularizer.
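
A tiny sketch (my own) of the noise this introduces: the statistics of each mini-batch jitter around the statistics of the full dataset from step to step.

import numpy as np

rng = np.random.default_rng(1)
data = rng.normal(loc=3.0, scale=2.0, size=10000)
print('dataset mean: %.3f' % data.mean())
for _ in range(3):
    batch = rng.choice(data, size=32, replace=False)
    print('mini-batch mean: %.3f' % batch.mean())   # a slightly different value at every step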

3. Batch Normalization accelerates training
       BN speeds up learning: it keeps the mean and variance of every layer's inputs within a controlled range, so the next layer does not have to keep adapting to a shifting input distribution. Each layer can learn more independently of the others, which speeds up training of the whole network.

4. In some cases it eliminates the need for Dropout
       BN reduces the reliance on dropout.

5. It reduces overfitting
       BN reduces overfitting, for the same reason as point 2.

IV. What are the disadvantages of BN

1. The effect is sensitive to the batch size. The larger the batch, the more representative the mini-batch statistics are, i.e., the closer its mean and variance get to the mean and variance of the dataset. But a batch that is too large may not fit in memory.
2. It is hard to use in RNNs, where Layer Normalization is used instead; a rough sketch of the difference between the two follows below.
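
A rough sketch (my own, simplified) of that difference: BN normalizes each feature across the batch, while Layer Normalization normalizes each sample across its features and therefore does not depend on the batch size at all.

import numpy as np

x = np.random.randn(8, 16)   # (batch, features)
eps = 1e-5

# Batch norm: statistics over axis 0 (across the batch, per feature)
bn_hat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

# Layer norm: statistics over axis 1 (across the features, per sample)
ln_hat = (x - x.mean(axis=1, keepdims=True)) / np.sqrt(x.var(axis=1, keepdims=True) + eps)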

V. Code implementation

import numpy as np


def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.

    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.

    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:

    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Note that the batch normalization paper suggests a different test-time
    behavior: they compute sample mean and variance for each feature using a
    large number of training images rather than using a running average. For
    this implementation we have chosen to use running averages instead since
    they do not require an additional estimation step; the torch7
    implementation of batch normalization also uses running averages.

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var: Array of shape (D,) giving running variance of features

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    out, cache = None, None
    if mode == 'train':
        #######################################################################
        # TODO: Implement the training-time forward pass for batch norm.      #
        # Use minibatch statistics to compute the mean and variance, use      #
        # these statistics to normalize the incoming data, and scale and      #
        # shift the normalized data using gamma and beta.                     #
        #                                                                     #
        # Note that though you should be keeping track of the running         #
        # variance, you should normalize the data based on the standard       #
        # deviation (square root of variance) instead!                        # 
        # Referencing the original paper (https://arxiv.org/abs/1502.03167)   #
        # might prove to be helpful.                                          #
        #######################################################################
      
        sample_mean = np.mean(x, axis=0)   # per-feature mean, shape (D,)
        sample_var = np.var(x, axis=0)     # per-feature variance, shape (D,)
        x_hat = (x - sample_mean) / np.sqrt(sample_var + eps)
        out = gamma * x_hat + beta
        cache = (gamma, x, sample_mean, sample_var, eps, x_hat)  # saved for the backward pass
        
        # update the running estimates of the global mean and variance
        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var
     
    elif mode == 'test':
        #######################################################################
        # TODO: Implement the test-time forward pass for batch normalization. #
        # Use the running mean and variance to normalize the incoming data,   #
        # then scale and shift the normalized data using gamma and beta.      #
        # Store the result in the out variable.                               #
        #######################################################################
        
        # fold the normalization and the learned affine transform into one scale and shift
        scale = gamma / np.sqrt(running_var + eps)
        shift = beta - scale * running_mean

        out = scale * x + shift
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
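
A quick usage sketch (my own, using the function above): the same bn_param dictionary accumulates running statistics during training and is then reused in 'test' mode.

x_demo = np.random.randn(64, 5) * 4 + 10
gamma_demo, beta_demo = np.ones(5), np.zeros(5)
bn_param_demo = {'mode': 'train', 'momentum': 0.9}

for _ in range(50):                              # accumulate the running mean / variance
    _out, _ = batchnorm_forward(x_demo, gamma_demo, beta_demo, bn_param_demo)

bn_param_demo['mode'] = 'test'                   # now normalize with the stored running statistics
x_new = np.random.randn(8, 5) * 4 + 10
out_test, _ = batchnorm_forward(x_new, gamma_demo, beta_demo, bn_param_demo)
print(out_test.mean(axis=0), out_test.std(axis=0))   # approximately standardized per feature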

def batchnorm_backward(dout, cache):
    """
    Backward pass for batch normalization.

    For this implementation, you should write out a computation graph for
    batch normalization on paper and propagate gradients backward through
    intermediate nodes.

    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Variable of intermediates from batchnorm_forward.

    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    dx, dgamma, dbeta = None, None, None
    ###########################################################################
    # TODO: Implement the backward pass for batch normalization. Store the    #
    # results in the dx, dgamma, and dbeta variables.                         #
    # Referencing the original paper (https://arxiv.org/abs/1502.03167)       #
    # might prove to be helpful.                                              #
    ###########################################################################

    gamma, x, sample_mean, sample_var, eps, x_hat = cache
    N = x.shape[0]
    # chain rule through x_hat, then through the batch variance and the batch mean
    dx_hat = dout * gamma
    dvar = np.sum(dx_hat * (x - sample_mean) * -0.5 * np.power(sample_var + eps, -1.5), axis=0)
    dmean = np.sum(dx_hat * -1 / np.sqrt(sample_var + eps), axis=0) + dvar * np.mean(-2 * (x - sample_mean), axis=0)
    dx = dx_hat / np.sqrt(sample_var + eps) + dvar * 2.0 / N * (x - sample_mean) + dmean / N
    dgamma = np.sum(x_hat * dout, axis=0)
    dbeta = np.sum(dout, axis=0)

    return dx, dgamma, dbeta
# Quick check: one training-mode forward pass, then backprop a random upstream gradient through it.
np.random.seed(231)
N, D = 4, 5
x = 5 * np.random.randn(N, D) + 12
gamma = np.random.randn(D)
beta = np.random.randn(D)
dout = np.random.randn(N, D)
bn_param = {'mode': 'train'}
_, cache = batchnorm_forward(x, gamma, beta, bn_param)
dx, dgamma, dbeta = batchnorm_backward(dout, cache)
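
To sanity-check the backward pass, here is a simple centered-difference gradient check (my own sketch; the helper below is generic, not taken from any library):

def numeric_gradient(f, p, df, h=1e-5):
    # Numerically estimate d(sum(f(p) * df)) / dp with centered differences.
    grad = np.zeros_like(p)
    it = np.nditer(p, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        old = p[idx]
        p[idx] = old + h
        pos = f(p).copy()
        p[idx] = old - h
        neg = f(p).copy()
        p[idx] = old
        grad[idx] = np.sum((pos - neg) * df) / (2 * h)
        it.iternext()
    return grad

fx = lambda x_: batchnorm_forward(x_, gamma, beta, {'mode': 'train'})[0]
fg = lambda g_: batchnorm_forward(x, g_, beta, {'mode': 'train'})[0]
fb = lambda b_: batchnorm_forward(x, gamma, b_, {'mode': 'train'})[0]

# The analytic gradients should agree with the numeric ones to within roughly 1e-8.
print(np.max(np.abs(dx - numeric_gradient(fx, x, dout))))
print(np.max(np.abs(dgamma - numeric_gradient(fg, gamma, dout))))
print(np.max(np.abs(dbeta - numeric_gradient(fb, beta, dout))))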