RNN、LSTM和GRU

一、循環神經網絡

傳統的神經網絡並不能做到保持信息的持久性,RNN(Recurrent Neural Retwork) 解決了這個問題。RNN 是包含循環的網絡,允許信息的持久化。
RNN
在上面的示例圖中,神經網絡的模塊,,正在讀取某個輸入 ,並輸出一個值 。循環可以使得信息可以從當前步傳遞到下一步。

RNN 可以被看做是同一神經網絡的多次複製,每個神經網絡模塊會把消息傳遞給下一個。所以,如果我們將這個循環展開:
RNN
RNN 的關鍵點之一就是他們可以用來連接先前的信息到當前的任務上,例如使用過去的視頻段來推測對當前段的理解。

但是LSTM存在長期依賴(Long-Term Dependencies)問題,也就是當序列特別長的時候,RNN 會丟失之前的信息。於是提出了LSTM來改進這些問題。

同時RNN還存在梯度消失和梯度爆炸的問題。可以通過改變激活函數來解決,同時LSTM也可以解決梯度消失和梯度爆炸的問題。

1.1 標準RNN的前向輸出流程

下面這是一個RNN詳細的結構圖,其中各個符號的含義:x是輸入,h是隱層單元,o爲輸出y爲訓練集的標籤,L爲損失函數。這些元素右上角帶的t代表t時刻的狀態,其中需要注意的是,因策單元h在t時刻的表現不僅由此刻的輸入決定,還受t時刻之前時刻的影響。V、W、U是權值,同一類型的權連接權值相同。
RNN
前向傳播算法其實非常簡單,t時刻隱狀態h(t)h^{(t)}爲:h(t)=ϕ(Ux(t)+Wh(t1)+b)h^{(t)}=\phi(Ux^{(t)}+Wh^{(t-1)}+b)
其中ϕ()\phi()爲激活函數,一般來說會選擇tanh函數(注意這個tanh函數,它是引起梯度消失和梯度爆炸的原因,下面會細講),b爲偏置。

t時刻的輸出o(t)o^{(t)}就更爲簡單(c爲偏置):
o(t)=Vh(t)+co^{(t)}=Vh^{(t)}+c

t時刻模型的預測輸出y(t)y^{(t)}爲:
y(t)=σ(o(t))y^{(t)}=\sigma(o^{(t)})
其中σ()\sigma()爲激活函數,通常RNN用於分類,故這裏一般用softmax函數。

1.2 RNN的訓練方法—BPTT

BPTT(back-propagation through time)算法是常用的訓練RNN的方法,其實本質還是BP算法,只不過RNN處理時間序列數據,所以要基於時間反向傳播,故叫隨時間反向傳播。BPTT的中心思想和BP算法相同,沿着需要優化的參數的負梯度方向不斷尋找更優的點直至收斂。綜上所述,BPTT算法本質還是BP算法,BP算法本質還是梯度下降法,那麼求各個參數的梯度便成了此算法的核心。

RNN
再次拿出這個結構圖觀察,需要尋優的參數有三個,分別是U、V、W。與BP算法不同的是,其中W和U兩個參數的尋優過程需要追溯之前的歷史數據,參數V相對簡單隻需關注目前時刻t。
(1)那麼我們就來先求解參數V的偏導數:
L(t)V=L(t)o(t)o(t)V\frac{\partial L^{(t)}}{\partial V}=\frac{\partial L^{(t)}}{\partial o^{(t)}}\cdot \frac{\partial o^{(t)}}{\partial V}
這個式子看起來簡單但是求解起來很容易出錯,因爲其中嵌套着激活函數函數,是複合函數的求道過程。RNN的損失也是會隨着時間累加的,所以不能只求t時刻的偏導,要把所有時刻的偏導都求出來再累加:
L=t=1nL(t)L=\sum_{t=1}^n L^{(t)}
LV=t=1nL(t)o(t)o(t)V\frac{\partial L}{\partial V}=\sum_{t=1}^n \frac{\partial L^{(t)}}{\partial o^{(t)}}\cdot \frac{\partial o^{(t)}}{\partial V}

(2)W和U的偏導的求解由於需要涉及到歷史數據,其偏導求起來相對複雜,我們先假設只有三個時刻
那麼在第三個時刻 L對W的偏導數爲:
L(3)W=L(3)o(3)o(3)h(3)h(3)W+L(3)o(3)o(3)h(3)h(3)h(2)h(2)W+L(3)o(3)o(3)h(3)h(3)h(2)h(2)h(1)h(1)W\frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial W}

相應的,L在第三個時刻對U的偏導數爲:
L(3)U=L(3)o(3)o(3)h(3)h(3)U+L(3)o(3)o(3)h(3)h(3)h(2)h(2)U+L(3)o(3)o(3)h(3)h(3)h(2)h(2)h(1)h(1)U\frac{\partial L^{(3)}}{\partial U}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial U}

可以觀察到,在某個時刻的對W或是U的偏導數,需要追溯這個時刻之前所有時刻的信息,這還僅僅是一個時刻的偏導數,上面說過損失也是會累加的,那麼整個損失函數對W和U的偏導數將會非常繁瑣。雖然如此但好在規律還是有跡可循,我們根據上面兩個式子可以寫出L在t時刻對W和U偏導數的通式
L(t)W=k=0tL(t)o(t)o(t)h(t)(j=k+1th(j)h(j1))h(k)W\frac{\partial L^{(t)}}{\partial W}=\sum_{k=0}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial W}
L(t)U=k=0tL(t)o(t)o(t)h(t)(j=k+1th(j)h(j1))h(k)U\frac{\partial L^{(t)}}{\partial U}=\sum_{k=0}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial U}

整體的偏導公式就是將其按時刻再一一加起來。注意這個累乘裏面是t時刻的h對t-1時刻的h求導。

1.3 梯度消失和梯度爆炸

前面說過激活函數是嵌套在裏面的,如果我們把激活函數放進去,拿出中間累乘的那部分:
j=k+1thjhj1=j=k+1ttanhWs\prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{tanh^{'}}\cdot W_{s}

或是
j=k+1thjhj1=j=k+1tsigmoidWs\prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{sigmoid^{'}}\cdot W_{s}

於是這個tanh的導數(或者sigmod的導數)就以累乘的形式參與到梯度的計算中去。但是我們來看看tanh的導數和sigmod的導數的特徵:

  • tanh的函數圖像和導數圖像:
    tanh

  • sigmoid的函數圖像和導數圖像:
    sigmoid

它們二者是何其的相似,都把輸出壓縮在了一個範圍之內。他們的導數圖像也非常相近,我們可以從中觀察到,sigmoid函數的導數範圍是(0,0.25],tanh函數的導數範圍是(0,1],他們的導數最大都不大於1。這就是會帶來幾個問題:
(1)如果 WsW_{s} 也是一個大於0小於1的值,使得tanhWs<1tanh' * W_s < 1,則當t很大時,梯度累乘的值就會趨近於0,和 (0.9*0.8)^50趨近與0是一個道理。
(2)同理當 WsW_{s} 很大時,具體指(比如 tanh=0.1tanh' = 0.1,而 Ws=99W_s=99,則相乘爲9.9),使得tanhWs>1tanh' * W_s > 1,則當t很大時,梯度累乘的值就會趨近於無窮。

這就是RNN中梯度消失和爆炸的原因。其實RNN的時間序列與深層神經網絡很像,在較爲深層的神經網絡中使用sigmoid函數做激活函數也會導致反向傳播時梯度消失,梯度消失就意味消失那一層的參數再也不更新,那麼那一層隱層就變成了單純的映射層,毫無意義了,所以在深層神經網絡中,有時候多加神經元數量可能會比多家深度好。但是tanh函數相對於sigmoid函數來說梯度較大,收斂速度更快且引起梯度消失更慢

sigmoid函數還有一個缺點,Sigmoid函數輸出不是零中心對稱。sigmoid的輸出均大於0,這就使得輸出不是0均值,稱爲偏移現象,這將導致後一層的神經元將上一層輸出的非0均值的信號作爲輸入。而關於原點對稱的輸入和中心對稱的輸出,網絡會收斂地更好

RNN的特點本來就是能“追根溯源“利用歷史數據,現在告訴我可利用的歷史數據竟然是有限的,這就令人非常難受,解決“梯度消失“是非常必要的。解決“梯度消失“的方法主要有:

  • 選取更好的激活函數
  • 改變傳播結構

關於第一點,一般選用ReLU函數作爲激活函數,ReLU函數的圖像爲:
RELU

左側恆爲1的導數避免了“梯度消失“的發生。但是容易導致“梯度爆炸“,設定合適的閾值可以解決這個問題。

但是如果左側橫爲0的導數有可能導致把神經元學死,出現這個原因可能是因爲學習率太大,導致w更新巨大,使得輸入的所有訓練樣本數據在經過這個神經元的時候,所有輸出值都小於0,從而經過激活函數Relu計算之後的輸出爲0,從此不梯度(所有梯度之和)再更新。所以relu爲激活函數,學習率不能太大,設置合適的步長(學習率)也可以有效避免這個問題的發生。

二、長短期記憶神經網絡

LSTM(Long short-term Memory):a very special kind of Recurrent Neural Retwork.長短期記憶(Long short-term memory, LSTM)是一種特殊的RNN,主要是爲了解決長序列訓練過程中的梯度消失和梯度爆炸問題。簡單來說,就是相比普通的RNN,LSTM能夠在更長的序列中有更好的表現

2.1 LSTM 內部結構

上面介紹的RNN可以用下圖表示,內部只有一個tanh激活函數:
rnn
LSTM 同樣是這樣的結構,但是重複的模塊擁有一個不同的結構。不同於單一神經網絡層,整體上除了h在隨時間流動,細胞狀態c也在隨時間流動,細胞狀態c就代表着長期記憶
lstm
現在,我們先來熟悉一下圖中使用的各種元素的圖標:
icon

  • 黃色的矩形是學習得到的神經網絡層
  • 粉色的圓形表示一些運算操作,諸如加法乘法
  • 黑色的單箭頭表示向量的傳輸
  • 兩個箭頭合成一個表示向量的連接
  • 一個箭頭分開表示向量的複製

LSTM 的關鍵就是細胞狀態,水平線在圖上方貫穿運行。細胞狀態類似於傳送帶。直接在整個鏈上運行,只有一些少量的線性交互。信息在上面流傳保持不變會很容易:
cell
LSTM 有通過精心設計的稱作爲“門”的結構來去除或者增加信息到細胞狀態的能力。門是一種讓信息選擇式通過的方法。他們包含一個 sigmoid 神經網絡層和一個 pointwise 乘法操作:
sigmoid

LSTM 擁有三個門,來保護和控制細胞狀態。

2.2 分步理解LSTM

LSTM內部主要有三個階段:

(1) 忘記階段
這個階段主要是對上一個節點傳進來的Ct1C_{t-1}進行選擇性忘記。簡單來說就是會 “忘記不重要的,記住重要的”。具體來說是通過計算得到的 ft(f表示forget,也就是下面的ft)來作爲忘記門控,來控制上一個狀態的Ct1Ct-1 哪些需要留哪些需要忘。輸出的ft是一個在 0 到 1 之間的數值,描述每個部分有多少量可以通過。0 代表“不許任何量通過”,1 就指“允許任意量通過”,小數就是以前百分之多少的內容記住。然後這個ft和Ct1C_{t-1}進行一個pointwise 乘法操作,從而達到遺忘的效果。
ft=σ(Wf[ht1,xt]+bf)f_t=\sigma(W_f \cdot [h_{t-1},x_t]+b_f)
forget

(2) 選擇記憶階段
將當前這個t階段的輸入xtx_t有選擇性地進行“記憶”到細胞CtC_t中。(上一步是對前一個輸入Ct1C_{t-1}進行選擇記憶)。哪些重要則着重記錄下來,哪些不重要,則少記一些,這裏的it充當了一個記憶門控的作用。要記住的內容暫存爲C~t\tilde C_t
it=σ(Wi[ht1,xt]+bi)i_t=\sigma(W_i \cdot [h_{t-1},x_t]+b_i)
C~t=tanh(Wc[ht1,xt]+bC)\tilde C_t=tanh(W_c \cdot [h_{t-1},x_t]+b_C)
information
具體記憶的方法如下:先計算iti_tC~t\tilde C_t,再將兩者相乘。

經過(1)(2)兩個步驟之後,我們就可以更新細胞狀態CtC_t了。我們有了要從Ct1C_{t-1}遺忘的和要從xtx_t記住的內容,顯而易見,把兩個內容相加就是更新之後的細胞狀態CtC_t
Ct=ftCt1+itC~tC_t=f_t*C_{t-1}+i_t * \tilde C_t
update
其實這裏的ftf_titi_t可以看做兩個權重,一個是遺忘(Ct1C_{t-1})權重,一個是記憶(xtx_{t})權重。

(3)輸出階段
最終,我們需要確定輸出什麼值。這個輸出將會基於我們的細胞狀態,但是也是一個(經過tanh)過濾後的版本。

  • 首先,我們運行一個 sigmoid 層來xtx_t哪些內容將輸出出去;
  • 接着,我們把(更新後的)細胞狀態CtC_t通過 tanh 進行處理(得到一個在 -1 到 1 之間的值)並將它和 sigmoid 門的輸出相乘;
  • 最終我們僅僅會輸出我們確定輸出的那部分。

ot=σ(Wo[ht1,xt]+bO)o_t=\sigma(W_o \cdot [h_{t-1},x_t]+b_O)
ht=ottanh(Ct)h_t=o_t * tanh(C_t)
o

上面三步對應了三個門控,這三個門雖然功能上不同,但在執行任務的操作上是相同的。他們都是使用sigmoid函數作爲選擇工具,tanh函數作爲變換工具,這兩個函數結合起來實現三個門的功能。

三個步驟裏的權重Wf,Wi,WoW_f,W_i,W_o都不一樣且都是要學習的。

再看一下輸出的CtC_thth_t:
Ct=ftCt1+itC~tC_t=f_t*C_{t-1}+i_t*\tilde C_t
ht=ottanh(Ct)h_t=o_t*tanh(C_t)

2.3 總結

以上,就是LSTM的內部結構。通過門控狀態來控制傳輸狀態,記住需要長時間記憶的,忘記不重要的信息;而不像普通的RNN那樣只能夠“呆萌”地僅有一種記憶疊加方式。對很多需要“長期記憶”的任務來說,尤其好用。
但也因爲引入了很多內容,導致參數變多,也使得訓練難度加大了很多。因此很多時候我們往往會使用效果和LSTM相當但參數更少的GRU來構建大訓練量的模型。

2.4 LSTM 如何避免梯度消失和梯度爆炸?

上面說了,RNN的梯度消失和爆炸主要是由這個 j=k+1thjhj1=j=k+1ttanhWs\prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{tanh^{'}}\cdot W_{s} 引起的,對於LSTM同樣也包含這樣的一項,但是在LSTM中是這樣的:j=k+1thjhj1=j=k+1ttanhσ(Wfxt+bf)01\prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{tanh^{'}}\cdot \sigma(W_f x_t+b_f)\approx0|1

很顯然裏面這個tanhσ(Wfxt+bf){tanh^{'}}\cdot \sigma(W_f x_t+b_f)相乘的結果不可能發生梯度消失和爆炸。

三、GRU

GRU(Gate Recurrent Unit)是循環神經網絡(Recurrent Neural Network, RNN)的一種。和LSTM(Long-Short Term Memory)一樣,也是爲了解決長期記憶和反向傳播中的梯度等問題而提出來的。

GRU和LSTM在很多情況下實際表現上相差無幾,那麼爲什麼我們要使用新人GRU(2014年提出)而不是相對經受了更多考驗的LSTM(1997提出)呢。因爲GRU實驗的實驗效果與LSTM相似,但是更易於計算。

3.1 GRU的結構

GRU的輸入輸出結構與普通的RNN是一樣的。只不過內部結構是在LSTM的基礎上優化了。內部結構圖如下:
gru

(1)首先介紹GRU的兩個門,分別是重置的門控rtr_t(reset gate) 和更新門控ztz_t(update gate) ,計算方法和LSTM中門的計算方法一致:
rt=σ(Wr[ht1,xt])r_t=\sigma(W_r \cdot [h_{t-1},x_t])
zt=σ(Wz[ht1,xt])z_t=\sigma(W_z \cdot [h_{t-1},x_t])

(2)然後是計算候選隱藏層 h~t\tilde h_t(candidate hidden layer) ,這個候選隱藏層和LSTM中的 CtC_t是類似,可以看成是當前時刻的新信息,其中 rtr_t 用來控制需要保留多少之前的記憶,比如如果 rtr_t 爲0,那麼 h~t\tilde h_t 只包含當前詞的信息:
h~t=tanh(W[rtht1,xt])\tilde h_t=tanh(W \cdot [r_t*h_{t-1},x_t])

h~t\tilde h_t 的計算按下面這樣看更清晰一些,黃色的線是rtht1r_t*h_{t-1},藍色的線是[rtht1,xt][r_t*h_{t-1},x_t],紅色的線是W[rtht1,xt]W \cdot [r_t*h_{t-1},x_t],然後再經過一層tanh:
h~
(3)最後 ztz_t 控制需要從前一時刻的隱藏層 ht1h_{t-1}遺忘多少信息,需要加入多少當前時刻的隱藏層信息h~t\tilde h_t,最後得到當前位置的隱藏層信息hth_t , 需要注意這裏與LSTM的區別是GRU中沒有output gate:
ht=zth~t+(1zt)ht1h_t=z_t*\tilde h_t+(1-z_t)*h_{t-1}


參考:
【1】RNN
【2】RNN梯度消失和爆炸的原因
【3】人人都能看懂的LSTM
【4】[譯] 理解 LSTM 網絡

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章