BP神經網絡中應用到了梯度下降法,所以先介紹下梯度下降法。梯度下降法
梯度下降法,也叫最速下降法,是求解無約束最優化問題的一種常用方法。標量場中,某一點上的梯度指向標量場增長最快的方向,梯度的長度是這個最大的變化率。梯度下降是迭代算法,每一步需要求解目標函數的梯度向量。由於沿着負梯度方向,函數值下降最快所以在迭代的每一步,以梯度負方向更新X的值。
求出第k+1次迭代值 :
其中, 是搜索方向,取負梯度方向 , 是步長,由一維搜索確定,即 使得
BP算法
假設
神經網絡的輸出層爲第K層
符號定義
: 第l層第j個節點的輸入。
: 第l-1層第i個節點到第l層第j個節點的權值。
:Sigmoid函數
:第l層第j個節點的偏置,亦有稱神經元內部閾值。
:第l層第j個節點的輸出Output
:輸出層第j個節點的目標值
神經元的激活函數採用:
(1)
給定訓練集( )和模型輸出 ,輸出層的均方誤差爲:
(2)
訓練的目標是通過調整權值 ,使誤差 最小化。BP算法基於梯度下降法,以目標的負方向梯度方向對參數進行調整。所以先求解梯度。
輸出層權值反饋調整
對於輸出層權值
,由(1)式知:
由於權值
僅影響輸出層的節點k一個節點,所以:
由於
,所以:
對於Sigmoid函數
,有
,所以:
因爲
,
,所以:
最終:
(3)
其中:
隱藏層權值反饋調整
對於輸出層權值
,由(1)式知:
由於權值
僅影響輸出層的所具節點,所以上式的∑是不可去除的,所以:
同樣 ,所以
變形:
其中
由於
,所以:
由於 , ,所以:
由於
最終可得:
或者:
(3)
其中:
偏置(閾值)反饋調整
輸出層:
對於輸出層閾值
,由(1)式知:
因爲
,所以:
最終:
(4)
其中:
隱藏層:
對於輸出層權值 ,由(1)式知:
變形:
其中
由於
,所以:
由於
或者:
(5)
其中:
綜上所述:
(6)
(7)
(8)
(9)
其中
(10)
(11)
BP算法過程:
-------------------------------------------------------------------------------------------------
輸入:訓練集D={
},學習率
過程:
1:在(0,1)範圍內隨機初始化網絡中所有鏈接權值和閾值
2:repeat
3: for all do
4: 根據當前參數和式(1),計算當前樣本的輸出 ;
5: 根據式(10計算神經元的梯度項 ;
6 根據式(11算神經元的梯度項 ;
8: endfor
9:until 達到停止條件
輸出:連接權和閾值確定的多層前饋神經網絡
-------------------------------------------------------------------------------------------------
簡單BP算法代碼:一個使用反向傳播訓練的神經網絡嘗試使用輸入去預測輸出:
0
|
0
|
1
|
0
|
1
|
1
|
1
|
1
|
1
|
0
|
1
|
1
|
0
|
1
|
1
|
0
|
import numpy as np
# sigmoid function
def nonlin(x,deriv=False):
if(deriv==True):
return x*(1-x)
return 1/(1+np.exp(-x))
# input dataset
X = np.array([ [0,0,1],
[0,1,1],
[1,0,1],
[1,1,1] ])
# output dataset
y = np.array([[0,0,1,1]]).T
# seed random numbers to make calculation
# deterministic (just a good practice)
np.random.seed(1)
# initialize weights randomly with mean 0
syn0 = 2*np.random.random((3,1)) - 1
for iter in xrange(10000):
# forward propagation
l0 = X
l1 = nonlin(np.dot(l0,syn0))
# how much did we miss?
l1_error = y - l1
# multiply how much we missed by the
# slope of the sigmoid at the values in l1
l1_delta = l1_error * nonlin(l1,True)
# update weights
syn0 += np.dot(l0.T,l1_delta)
print "Output After Training:"
print l1