循環神經網絡(RNN)之網絡結構解析

一、RNN的前向傳播結構

在這裏插入圖片描述
在這裏插入圖片描述

t時刻輸入: XtX_{t}St1S_{t-1}
t時刻輸出: hth_{t}
t時刻中間狀態: StS_{t}

上圖是一個RNN神經網絡的時序展開模型,中間t時刻的網絡模型揭示了RNN的結構。可以看到,原始的RNN網絡的內部結構非常簡單。神經元A在t時刻的狀態僅僅是(t-1)時刻神經元狀態St1S_{t-1},與(t)時刻網絡輸入XtX_t的雙曲正切函數的值;這個值不僅僅作爲該時刻網絡的輸出,也作爲該時刻網絡的狀態被傳入到下一個時刻的網絡狀態中,這個過程叫做RNN的正向傳播(forward propagation)

傳播中的數學公式(含參數)

在這裏插入圖片描述

上圖表示爲RNN網絡的完整的拓撲結構,以及RNN網絡中相應的參數情況。我們通過對t時刻網絡的行爲進行數學的推導。在如下的內容中,會出現線性狀態和激活狀態兩種表達,線性狀態將用*號進行標註。
t時刻神經元狀態
St=ϕ(St)S_t= {\phi}{(S{_t^*})}
St=(UXt+WSt1)S{_t^*}=(UX_t+WS_{t-1})
t時刻的輸出狀態
Ot=ψ(Ot)O_t=\psi{(O{_t^*})}
Ot=VStO{_t^*} = VS_t
我們該如何得到RNN模型中的U、V、W三個全局共享參數的具體值呢?在之後的RNN逆向傳播中可以得出具體的情況。

二、BPTT(隨時間變化的反向傳播算法)

1、 損失函數的選取,在RNN中一般選取交叉熵(Cross Entropy),表達式如下:
Loss=i=0nyilnyiLoss = -{\sum_{i=0}^{n}y_ilny_i^*}
上式爲交叉熵的標量的形式,yiy_i是真實的標籤紙,yiy_i^*是模型給出的預測值,在多維輸出值的時,則可以通過累加得出n維損失值。交叉熵在應用於RNN需進行微調:首先,RNN的輸出是向量的形式,沒有必要將所有的維度進行累加一起,直接把損失值用向量進行表達即可;其次,由於RNN模型是序列問題,因此其模型損失不能只是一個時刻的損失,應該包含全部N個時刻的損失。
因此RNN模型在t時刻的損失函數如下:
Losst=[ytln(Ot)+(yt1)ln(1Ot)]{Loss}_t = -[y_tln(O_t) + (y_t-1)ln(1-O_t)]
全部N個時刻的損失函數(全局損失)表達爲如下形式:
Loss=t=1NLosst=t=1N[ytln(Ot)+(yt1)ln(1Ot)]Loss = -{\sum_{t=1}^NLoss_t}= -{\sum_{t=1}^N[y_tln(O_t) + (y_t-1)ln(1-O_t)]}

2、 softmax函數的求導公式爲(下文用ψ\psi 表示
ψ(x)=ψ(x)(1ψ(x))\psi'(x)=\psi(x)(1-\psi(x))

3、 激活函數的求導公式爲(選取tanh(x)作爲激活函數)
ϕ(x)=tanh(x)\phi(x) = tanh(x)
ϕ(x)=(1ϕ2(x))\phi'(x)=(1-{\phi^2(x)})

4、 BPTT算法
注: 由於RNN模型與時間序列有關,所以使用Back Propagation Through Time(隨時間變化反向傳播的算法),但依舊遵循鏈式求導法則。在損失函數中,雖然RNN的額全局損失是與N個時刻有關的,但下面的推導僅涉及某個t時刻。
(1)求出t時刻下的損失函數關於OtO_t^*的微分:
LtOt=LtOtOtOt=LtOtψ(Ot)Ot=LtOtψ(Ot)\frac{\partial{L_t}}{\partial{O_t^*}} =\frac{\partial{L_t}}{\partial{O_t}} * \frac{\partial{O_t}} {\partial{O_t^*}}=\frac{\partial{L_t}}{\partial{O_t}} * \frac{\partial{\psi{(O_t^*)}}} {\partial{O_t^*}}=\frac{\partial{L_t}}{\partial{O_t}} * \psi'(O_t^*)
(2)求出損失函數關於參數V的微分(需要(1)中的結論):
LtV=Lt(VSt)(VSt)V=LtOtSt=LtOtψ(Ot)St\frac{\partial{L_t}}{\partial{V}} = \frac{\partial{L_t}}{\partial{(VS_t)}} * \frac{\partial{(VS_t)}} {\partial{V}}=\frac{\partial{L_t}}{\partial{O_t^*}} * S_t=\frac{\partial{L_t}}{\partial{O_t}} * \psi'(O_t^*)* S_t
因此,全局關於參數V的微分爲:
LV=t=1NLtV=t=1NLtOtψ(Ot)St\frac{\partial{L}}{\partial{V}}={\sum_{t=1}^{N}}\frac{\partial{L_t}}{\partial{V}}={\sum_{t=1}^{N}}\frac{\partial{L_t}}{\partial{O_t}} * \psi'(O_t^*)* S_t
(3)求出t時刻的損失函數關於StS_t^*的微分:
LtSt=Lt(VSt)(VSt)StStSt=LtOtVϕ(St)=LtOtψ(Ot)Vϕ(St)\frac{\partial{L_t}}{\partial{S_t^*}} = \frac{\partial{L_t}}{\partial{(VS_t)}} * \frac{\partial{(VS_t)}} {\partial{S_t}} * \frac{\partial{S_t}} {\partial{S_t^*}}=\frac{\partial{L_t}}{\partial{O_t^*}}*V*\phi'(S_t^*)=\frac{\partial{L_t}}{\partial{O_t}}*\psi'(O_t^*)*V*\phi'(S_t^*)
(4)求出t時刻的損失函數關於St1S_{t-1}的微分
LtSt1=LtStStSt1=LtSt[Wϕ(St1)+UXt]St1=LtStWϕ(St1)\frac{\partial{L_t}}{\partial{S_{t-1}^*}}=\frac{\partial{L_t}}{\partial{S_t^*}} *\frac{\partial{S_t^*}}{\partial{S_{t-1}^*}}= \frac{\partial{L_t}}{\partial{S_t^*}} *\frac{\partial{[W\phi(S_{t-1}^*)}+UX_t]}{\partial{S_{t-1}^*}} = \frac{\partial{L_t}}{\partial{S_t^*}} *W\phi'(S_{t-1}^*)
(5)求出t時刻關於參數U的偏微分
注:因爲是時間序列模型,因此t時刻關於U
的微分與前(t-1)個時刻都相關,在具體計算時可以限定最遠回溯到前n個時刻,但在推導時需將(t-1)個時刻全部代入計算
LtU=k=1tLtSkSkU=k=1tLtSk(WSk1+UXk)U=k=1tLtSkXk\frac{\partial L_t}{\partial U}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}\frac{\partial S_k^*}{\partial U}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}\frac{\partial ({WS_{k-1}}+UX_k)}{\partial U}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}*X_k
因此,全局關於U的損失偏微分爲:
LU=t=1NLtU=t=1Nk=1tLtSkSkU=t=1Nk=1tLtSkXk\frac{\partial L}{\partial U}=\sum_{t=1}^{N}\frac{\partial L_t}{\partial U}=\sum_{t=1}^{N}\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}\frac{\partial S_k^*}{\partial U}=\sum_{t=1}^{N}\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}*X_k
(6)求出t時刻關於參數W的偏微分(同上)
LtW=k=1tLtSkSkW=k=1tLtSk(WSk1+UXk)W=k=1tLtSkSk1\frac{\partial L_t}{\partial W}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}\frac{\partial S_k^*}{\partial W}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}\frac{\partial ({WS_{k-1}}+UX_k)}{\partial W}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}*S_{k-1}
因此,全局關於U的損失偏微分爲:
LW=t=1NLtW=t=1Nk=1tLtSkSkW=t=1Nk=1tLtSkSk1\frac{\partial L}{\partial W}=\sum_{t=1}^{N}\frac{\partial L_t}{\partial W}=\sum_{t=1}^{N}\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}\frac{\partial S_k^*}{\partial W}=\sum_{t=1}^{N}\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}*S_{k-1}
(7)由於大多數的輸出爲softmax函數,我們在對OtO_t^*進行softmax運算後求導可得
ψ(Ot)=Ot(1Ot)\psi'(O_t^*)=O_t(1-O_t)
所以在OtO_t進行微分求偏導可得(採用交叉熵作爲損失函數)
LtOt=[t=1N[ytln(Ot)+(yt1)ln(1Ot)]Ot=(ytOt+ytOt1Ot)=ytOtOt(1Ot)\frac{\partial L_t }{\partial O_t}=\frac{-\partial [\sum_{t=1}^N[y_tln(O_t) + (y_t-1)ln(1-O_t)]}{\partial O_t}=-(\frac {y_t}{O_t}+\frac{y_t-O_t}{1-O_t})=-{\frac{y_t-O_t}{O_t(1-O_t)}}
LtOtψ(Ot)=ytOtOt(1Ot)Ot(1Ot)=Otyt\frac{\partial L_t }{\partial O_t}*\psi'(O_t^*)=-{\frac{y_t-O_t}{O_t(1-O_t)}}*O_t(1-O_t)=O_t-y_t
LtSt=LtOtψ(Ot)Vϕ(St)=[V(Otyt)][1ϕ2(st)]=[V(Otyt)][1St2]\frac{\partial{L_t}}{\partial{S_t^*}} =\frac{\partial{L_t}}{\partial{O_t}}*\psi'(O_t^*)*V*\phi'(S_t^*)= [V*(O_t-y_t)]*[1-{\phi^2(s_t^*)}]= [V*(O_t-y_t)]*[1-S_t^2]
LtSt1=LtStWϕ(St1)=LtStW[1St12]\frac{\partial{L_t}}{\partial{S_{t-1}^*}}= \frac{\partial{L_t}}{\partial{S_t^*}} *W\phi'(S_{t-1}^*)= \frac{\partial{L_t}}{\partial{S_t^*}}*W*[1-S_{t-1}^2]

綜上:
LV=t=1NLtV=t=1N(Otyt)St\frac{\partial{L}}{\partial{V}}={\sum_{t=1}^{N}}\frac{\partial{L_t}}{\partial{V}}={\sum_{t=1}^{N}}(O_t-y_t) *S_t
其餘得類似
(8)我們逐步更新V,U,W三者得參數,直至它們收斂爲之
V:=VηLVV:=V-\eta*\frac{\partial L}{\partial V}
U:=UηLUU:=U-\eta*\frac{\partial L}{\partial U}
W:=WηLWW:=W-\eta*\frac{\partial L}{\partial W}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章