一、RNN的前向傳播結構
t時刻輸入: Xt 、St−1
t時刻輸出: ht
t時刻中間狀態: St
上圖是一個RNN神經網絡的時序展開模型,中間t時刻的網絡模型揭示了RNN的結構。可以看到,原始的RNN網絡的內部結構非常簡單。神經元A在t時刻的狀態僅僅是(t-1)時刻神經元狀態St−1,與(t)時刻網絡輸入Xt的雙曲正切函數的值;這個值不僅僅作爲該時刻網絡的輸出,也作爲該時刻網絡的狀態被傳入到下一個時刻的網絡狀態中,這個過程叫做RNN的正向傳播(forward propagation)
傳播中的數學公式(含參數)
上圖表示爲RNN網絡的完整的拓撲結構,以及RNN網絡中相應的參數情況。我們通過對t時刻網絡的行爲進行數學的推導。在如下的內容中,會出現線性狀態和激活狀態兩種表達,線性狀態將用∗號進行標註。
t時刻神經元狀態 :
St=ϕ(St∗)
St∗=(UXt+WSt−1)
t時刻的輸出狀態:
Ot=ψ(Ot∗)
Ot∗=VSt
我們該如何得到RNN模型中的U、V、W三個全局共享參數的具體值呢?在之後的RNN逆向傳播中可以得出具體的情況。
二、BPTT(隨時間變化的反向傳播算法)
1、 損失函數的選取,在RNN中一般選取交叉熵(Cross Entropy),表達式如下:
Loss=−i=0∑nyilnyi∗
上式爲交叉熵的標量的形式,yi是真實的標籤紙,yi∗是模型給出的預測值,在多維輸出值的時,則可以通過累加得出n維損失值。交叉熵在應用於RNN需進行微調:首先,RNN的輸出是向量的形式,沒有必要將所有的維度進行累加一起,直接把損失值用向量進行表達即可;其次,由於RNN模型是序列問題,因此其模型損失不能只是一個時刻的損失,應該包含全部N個時刻的損失。
因此RNN模型在t時刻的損失函數如下:
Losst=−[ytln(Ot)+(yt−1)ln(1−Ot)]
全部N個時刻的損失函數(全局損失)表達爲如下形式:
Loss=−t=1∑NLosst=−t=1∑N[ytln(Ot)+(yt−1)ln(1−Ot)]
2、 softmax函數的求導公式爲(下文用ψ表示)
ψ′(x)=ψ(x)(1−ψ(x))
3、 激活函數的求導公式爲(選取tanh(x)作爲激活函數)
ϕ(x)=tanh(x)
ϕ′(x)=(1−ϕ2(x))
4、 BPTT算法
注: 由於RNN模型與時間序列有關,所以使用Back Propagation Through Time(隨時間變化反向傳播的算法),但依舊遵循鏈式求導法則。在損失函數中,雖然RNN的額全局損失是與N個時刻有關的,但下面的推導僅涉及某個t時刻。
(1)求出t時刻下的損失函數關於Ot∗的微分:
∂Ot∗∂Lt=∂Ot∂Lt∗∂Ot∗∂Ot=∂Ot∂Lt∗∂Ot∗∂ψ(Ot∗)=∂Ot∂Lt∗ψ′(Ot∗)
(2)求出損失函數關於參數V的微分(需要(1)中的結論):
∂V∂Lt=∂(VSt)∂Lt∗∂V∂(VSt)=∂Ot∗∂Lt∗St=∂Ot∂Lt∗ψ′(Ot∗)∗St
因此,全局關於參數V的微分爲:
∂V∂L=t=1∑N∂V∂Lt=t=1∑N∂Ot∂Lt∗ψ′(Ot∗)∗St
(3)求出t時刻的損失函數關於St∗的微分:
∂St∗∂Lt=∂(VSt)∂Lt∗∂St∂(VSt)∗∂St∗∂St=∂Ot∗∂Lt∗V∗ϕ′(St∗)=∂Ot∂Lt∗ψ′(Ot∗)∗V∗ϕ′(St∗)
(4)求出t時刻的損失函數關於St−1的微分
∂St−1∗∂Lt=∂St∗∂Lt∗∂St−1∗∂St∗=∂St∗∂Lt∗∂St−1∗∂[Wϕ(St−1∗)+UXt]=∂St∗∂Lt∗Wϕ′(St−1∗)
(5)求出t時刻關於參數U的偏微分
注:因爲是時間序列模型,因此t時刻關於U
的微分與前(t-1)個時刻都相關,在具體計算時可以限定最遠回溯到前n個時刻,但在推導時需將(t-1)個時刻全部代入計算
∂U∂Lt=k=1∑t∂Sk∗∂Lt∂U∂Sk∗=k=1∑t∂Sk∗∂Lt∂U∂(WSk−1+UXk)=k=1∑t∂Sk∗∂Lt∗Xk
因此,全局關於U的損失偏微分爲:
∂U∂L=t=1∑N∂U∂Lt=t=1∑Nk=1∑t∂Sk∗∂Lt∂U∂Sk∗=t=1∑Nk=1∑t∂Sk∗∂Lt∗Xk
(6)求出t時刻關於參數W的偏微分(同上)
∂W∂Lt=k=1∑t∂Sk∗∂Lt∂W∂Sk∗=k=1∑t∂Sk∗∂Lt∂W∂(WSk−1+UXk)=k=1∑t∂Sk∗∂Lt∗Sk−1
因此,全局關於U的損失偏微分爲:
∂W∂L=t=1∑N∂W∂Lt=t=1∑Nk=1∑t∂Sk∗∂Lt∂W∂Sk∗=t=1∑Nk=1∑t∂Sk∗∂Lt∗Sk−1
(7)由於大多數的輸出爲softmax函數,我們在對Ot∗進行softmax運算後求導可得
ψ′(Ot∗)=Ot(1−Ot)
所以在Ot進行微分求偏導可得(採用交叉熵作爲損失函數)
∂Ot∂Lt=∂Ot−∂[∑t=1N[ytln(Ot)+(yt−1)ln(1−Ot)]=−(Otyt+1−Otyt−Ot)=−Ot(1−Ot)yt−Ot
∂Ot∂Lt∗ψ′(Ot∗)=−Ot(1−Ot)yt−Ot∗Ot(1−Ot)=Ot−yt
∂St∗∂Lt=∂Ot∂Lt∗ψ′(Ot∗)∗V∗ϕ′(St∗)=[V∗(Ot−yt)]∗[1−ϕ2(st∗)]=[V∗(Ot−yt)]∗[1−St2]
∂St−1∗∂Lt=∂St∗∂Lt∗Wϕ′(St−1∗)=∂St∗∂Lt∗W∗[1−St−12]
綜上:
∂V∂L=t=1∑N∂V∂Lt=t=1∑N(Ot−yt)∗St
其餘得類似
(8)我們逐步更新V,U,W三者得參數,直至它們收斂爲之
V:=V−η∗∂V∂L
U:=U−η∗∂U∂L
W:=W−η∗∂W∂L