重溫LSTM和GRU

1. 寫在前面

最近用深度學習做一些時間序列預測的實驗, 用到了一些循環神經網絡的知識, 而當初學這塊的時候,只是停留在了表面,並沒有深入的學習和研究,只知道大致的原理, 並不知道具體的細節,所以導致現在復現一些經典的神經網絡會有困難, 所以這次藉着這個機會又把RNN, GRU, LSTM以及Attention的一些東西複習了一遍,真的是每一遍學習都會有新的收穫,之前學習過也沒有整理, 所以這次也藉着這個機會把這一塊的基礎內容進行一個整理和總結, 順便了解一下這些結構底層的邏輯。

這篇文章基於前面的重溫循環神經網絡(RNN), 通過前面的分析, 我們已經知道了RNN中的梯度消失和爆炸現在究竟是怎麼回事並且也知道了引起梯度消失和爆炸的原因, 而又由於梯度消失, 導致了RNN並不擅長捕捉序列的長期關聯, 所以基於這兩個問題, 導致現在RNN使用的並不是太多, 而是使用它的一些變體, 比如LSTM, GRU這些,所以這篇文章就主要圍繞着這兩個變體進行展開。

首先, 我們先從LSTM開始, 先看一下LSTM和RNN的不同, 然後整理LSTM的工作原理和計算細節, 然後基於這個原理分析一下LSTM是如何解決RNN存在的兩個問題的,爲了更方便理解LSTM底層,依然是基於numpy實現一下LSTM的前向傳播過程,明白了底層邏輯,那麼LSTM到底如何在實際中使用?這裏會簡單介紹一下keras裏面LSTM層的細節, 最後再整理GRU這塊, 這可以說是LSTM的一種簡化版, 那麼到底是如何簡化的, 與LSTM又會有哪些不同? 這篇文章會一一進行剖析。

大綱如下

  • RNN梯度消失怎麼破? LSTM來了
  • LSTM的工作原理和計算細節
  • LSTM是如何解決RNN存在的梯度消失問題的
  • LSTM前向傳播的numpy實現及keras中LSTM層簡單介紹
  • LSTM的變體之GRU一些細節
  • 總結

Ok, let’s go!

2. RNN梯度消失怎麼破? LSTM來了

上面文章提到過, 循環神經網絡(Recurrent Neural Network,RNN)是一種用於處理序列數據的神經網絡。相比一般的神經網絡來說,他能夠處理序列變化的數據。比如某個單詞的意思會因爲上文提到的內容不同而有不同的含義,RNN就能夠很好地解決這類問題。下面再來個RNN的圖回顧一下(會發現和之前的圖又是不一樣, 好多種畫法, 但是萬變不離其宗, 原理不會變,哈哈):
在這裏插入圖片描述
上一篇文章已經詳細分析了這種網絡的工作原理和計算方面的細節, 這裏就不再過多贅述, 這裏看一點新的東西, 就是序列依賴的問題, 上一篇文章中只是提到了循環網絡一個很重要的作用就是能夠捕捉序列之間的依賴關係, 而原理就是RNN在前向傳播的時候時間步之間有隱藏狀態信息的傳遞, 這樣反向傳播修改參數的時候, 前面時刻的一些序列信息會起到一定的作用,從而使得後面某個時刻的狀態會捕捉到前面時刻的一些信息。 這在語言模型中非常常見。

比如我有個語言模型, 該模型根據前面的詞語來預測下一個單詞要預測一句話中的一部分, 如果我們試圖預測“the clouds are in the sky”的最後一個單詞, 這時候模型會預測出sky, 因爲RNN會利用過去的歷史信息clouds
在這裏插入圖片描述
這是一種局部的依賴, 即在相關信息和需要該信息的距離較近的時候,RNN往往工作的效果還可以, 但如果是吳恩達老師舉得那個例子:The cat, which already ate…, was full. 如果是要預測後面的這個was, 我們的語言模型這時候得考慮更多的上下文信息, 就不能是單單局部的信息了, 得需要從最開始獲取cat的信息, 這種情況就屬於相關信息和需要該信息的地方距離非常遠。就是下面這種情況:
在這裏插入圖片描述
這時候, 我們的RNN表現的就不是那麼出色了, 至於原因, 上一篇文章中我們分析了一點, 很重要的一點就是梯度的消失, 也就是時間步一旦很長, 就會出現連乘現象, 在反向傳播的時候,這種連乘很容易會導致梯度消失, 一旦梯度消失, 後面的參數更新就無法再獲取到前面時刻的關鍵信息,所以“長依賴”這個問題, 在RNN中是沒法很好處理的。

那麼, LSTM就來了, 這個東西其實不是最新的了,1997年的時候就引入了, 並且在各種各樣的工作中工作效果不錯,也廣泛被使用, 雖然現在可能是Attention的天下了,甚至超越了LSTM, 但是LSTM依然可以解決很多的問題,是一個非常有力的工具,並且學習好LSTM, 對於理解Attention可能也會起到幫助, 總之,我覺得LSTM是肯定需要掌握的,哈哈。

LSTM的全稱是Long short-term memory(長短期記憶), 是一種特殊的RNN網絡, 該網絡的設計是爲了解決RNN不能解決的長依賴問題, 所以首先知道它是幹啥用的? 那麼它是如何做到的呢? 那麼我們就需要對比一下LSTM和RNN的結構, 看看它到底改了什麼東西才變得這麼強大的呢?

循環神經網絡都具有神經網絡的重複模塊鏈的形式, 標準的RNN中,該重複模塊將具有非常簡單的結構,例如單個tanh層。標準的RNN網絡如下圖所示
在這裏插入圖片描述
而LSTM既然是RNN網絡, 那麼也是採用的這種鏈式結構, 而與RNN不同的是每一個單元內部的運算邏輯, 下面先宏觀上看一個LSTM的結構圖, 在後面的運算細節那更能夠看出這種運算邏輯:
在這裏插入圖片描述
很明顯可以看到, LSTM與RNN相比,其實整體鏈式結構是沒有改變的, 改變的是每個單元的內部的計算邏輯, LSTM這裏變得複雜了起來, 而正式因爲這種複雜, 才使得LSTM解決了RNN解決不了的問題, 比如梯度消失, 比如長期依賴。

下面就看看LSTM的原理和計算細節。

3. LSTM的工作原理和計算細節

所謂LSTM的工作原理,LSTM其實是在做一個這樣的事情, 先嚐試白話的描述一下, 然後再分析它是怎麼實現。

我們前面說過,LSTM要解決的問題就是一種長期依賴問題, 也就是如果序列長度很長, 後面的序列就無法回憶起前面時刻序列的信息, 這樣就很容易導致後面序列的預測出現錯誤,就跟人大腦一樣, 如果時間很長, 就會出現遺忘一樣, 記不清之前的一些事情,不利於後面的決策了。而出現這種情況的原因,就是我們在記憶的過程中, 干擾信息太多,記住了一些對後面決策沒有用的東西, 時間一長, 反而把對後面決策有用的東西也忘掉了。

RNN其實也是一樣, 普通的RNN後面更新的時候, 要回憶前面所有時刻的序列信息,往往就導致回憶不起來(梯度消失), 而我們知道, 對於未來做某個決策的時候, 我們並不需要回憶前面發生過的所有的事情,同理,對於RNN來說, 我要預測的這個單詞需要考慮的上下文也並不是前面所有序列都對我當前的預測有用, 就比如上面的那個例子, 我要預測was, 我只需要最前面的cat即可, 中間那一串which巴拉巴拉的, 對我的預測沒有用, 所以我預測was根本沒有必要記住which的這些信息, 只需要記住cat即可, 這個在普通的RNN裏面是沒法做到的(不懂得可以看看它的前向傳播過程), 它根本沒有機會做出選擇記憶, 而LSTM的核心,就是它比RNN, 多了一個可選擇性的記憶cell, 在LSTM的每個時間步裏面,都有一個記憶cell,這個東西就彷彿給與了LSTM記憶功能, 使得LSTM有能力自由選擇每個時間步裏面記憶的內容, 如果感覺當前狀態信息很重要, 那麼我就記住它, 如果感覺當前信息不重要, 那麼我就不記, 繼續保留前一時刻傳遞過來的狀態, 比如cat的那個, 在cat的時刻,我把這個狀態的信息保留下來, 而像which那些, 我不保留,這樣was的時候就很容易看到cat這個狀態的信息,並基於這個信息更新, 這樣就能夠進行長期依賴的學習了。

上面就是LSTM一個宏觀工作原理的體現, 當然還有一些細節,比如這個記憶是怎麼進行選擇的, 這個記憶是怎麼在時間步中傳遞的, 又是怎麼保持的等, 下面從數學的角度詳細的說說:

首先, 是那條記憶線到底在單元裏面長什麼樣子:
在這裏插入圖片描述
LSTM的關鍵就是每個時間步之間除了隱藏狀態hth_t的傳遞之外,還有這麼一條線貫穿整條鏈(可以看上面的鏈狀圖), 這個東西就彷彿一條傳動帶, 幫助後面時刻的序列回憶前面某些時刻的序列信息。 比如上面的例子, 把cat這一時刻的狀態信息存儲到cell裏面, 然後就可以通過這條鏈子一直傳遞到was那, 中間可以忽略掉那些干擾信息, 這樣就保證了長期依賴。 這就是記憶如何進行的傳遞, 是通過了一條這樣的cell鏈子。

那麼, LSTM是怎麼做到自由選擇記憶的東西的呢?這個就是LSTM裏面那幾個門發生的作用了, LSTM的cell狀態存儲是由被稱爲門的結構精細控制, 門是一種讓信息可選地通過的方法。它們由一個sigmoid神經網絡層和一個點乘操作組成。
在這裏插入圖片描述
這裏我標出來了, 看到這三個門了嗎?那麼就看看這三個門是如何起作用的, 首先, 我們解決另一個問題,就是cell裏面到底存儲的是什麼東西, 看個圖:
在這裏插入圖片描述
看上面這個圖, 右邊是CtC_t的更新公式, 可以發現, 這個CtC_t也就是我們要記住的東西,其實會包含兩部分, 第一部分是Ct1C_{t-1}, 這個表示的是前面時刻記住的信息, 比如預測was的例子, 這個Ct1C_{t-1}, 就可能表示前面的cat信息, 而後面的C~t\tilde{C}_{t}, 這個表示的是當前時刻狀態的信息(後面會看到公式),這裏就是was這個時刻的狀態, 就是說在某個時刻LSTM的記憶, 由當時時刻輸入的狀態信息和前面時刻輸入的狀態信息兩部分組成, 而這兩部分又不是必須要記憶的, 因爲我們看到了Ct1,C~tC_{t-1}, \tilde{C}_{t}前面還有一個ft,itf_t, i_t, 這倆都是0-1之間的值,就是爲了控制當前時刻的記憶有多少是來自於前面的時刻, 有多少是來自當前的時刻, 這個選擇權交給網絡本身。 當網絡感覺某個時刻的狀態需要記住,比如cat時刻, 那麼就讓iti_t爲1, ftf_t爲0, 這樣當前時刻的記憶單元就記住了目前的值, 而到了which時刻, 網絡覺得which時刻的輸入狀態不用記住, 那麼ftf_t爲1, iti_t爲0, 網絡的記憶依然是前面時刻的狀態信息。 這就是LSTM宏觀原理在數學形式上的表現。

那麼上面既然提到了ft,itf_t, i_t, 這倆到底是什麼東西呢? 這兩個就是門控了, 首先是左邊的門,看下圖:

在這裏插入圖片描述
最左邊這個叫做遺忘門(forget gate), 這個門決定着我們還需不需要記住前面狀態的信息,即當前時刻的記憶狀態有多少是來自於前面的記憶,比如對於一個基於上文預測最後一個詞的語言模型。cell的狀態可能包含當前主題的信息,來預測下一個準確的詞。而當我們得到一個新的語言主題的時候,我們會想要遺忘舊的主題的記憶,應用新的語言主題的信息來預測準確的詞。 右邊是它的計算公式, 輸入是ht1,xth_{t-1}, x_t, 由於是sigmoid函數, 輸出一個0和1之間的數, 1代表我們保留之前記住的信息, 0代表我們不用記住前面的信息。 動圖看一下運算過程:
在這裏插入圖片描述

右邊的一個門, 如下圖:
在這裏插入圖片描述
這個門叫做輸入門或者更新門, 名字不重要, 幹什麼纔是重要的, 這個門就是控制當前時刻的記憶有多少會來自於當前時刻的輸入本身, 因爲這個 C~t=tanh(WC[ht1,xt]+bC)\tilde{C}_{t}=\tanh \left(W_{C} \cdot\left[h_{t-1}, x_{t}\right]+b_{C}\right)這個公式應該會眼熟, 其實RNN那個地方的隱藏狀態的更新就是這個公式:
在這裏插入圖片描述
只不過LSTM這裏是將前面的那個加法改成了向量乘積的形式, 所以之類就很容易理解這個C~t\tilde{C}_{t}的含義了, 就是當前時刻隱藏狀態的信息。有了隱藏狀態的信息, 又有了前面某些時刻的信息, 又有了兩個門控制記憶的量, 那麼記憶cell的更新就很容易了, 下面的方程就非常容易理解了吧:
Ct=ftCt1+itC~tC_{t}=f_{t} * C_{t-1}+i_{t} * \tilde{C}_{t}
所以LSTM這個名字可能看起來很嚇人, 但可能是一個紙老虎。 依然是看一下運算過程:
在這裏插入圖片描述
通過上面的兩個門,我們就可以把cell更新到一個我們想要的狀態了,
在這裏插入圖片描述

但是光更新這個東西是沒有意義的啊, 因爲我們分析了was這個時候, 要記住cat的狀態, 但記住的目的是要進行預測, 所以說我們的cell是爲當前時刻的輸出服務的。

下面就看看輸出部分到底是個啥?
在這裏插入圖片描述
這纔是LSTM自由選擇記憶之後的目的, 就是爲了能有一個更好的輸出。 這裏首先是一個輸出門, 依然是一個sigmoid, 取值0-1, 這個是控制我們的輸出有多少是來自於我們的記憶,並不一定是全部的記憶哦。 使得LSTM更靈活了,連輸出都可以進行選擇了。這個意思差不多就是雖然我權衡了一下前面的狀態信息和當前的狀態信息, 更新了我的記憶, 但是這個記憶我不一定要全用上, 用一部分就可以搞定當前的預測。 並且如果我發現我更新的記憶對當前的預測並沒有用, 反而會效果更差,這時候我還可以選擇不用這個記憶, 所以非常的靈活。
在這裏插入圖片描述
這就是LSTM的原理和計算細節了, 通過LSTM和RNN對比的方式再來總結一下LSTM:
在這裏插入圖片描述
看這個對比就能發現, LSTM比RNN更加複雜, RNN這個在前向傳播的時候, 是記住了每個時刻的狀態信息, 然後往後傳,這種網絡帶來的結果就是易發生梯度消失,無法捕捉長期依賴, 因爲傳遞的過程中有一些干擾信息, 導致後面時刻參數更新沒法借鑑距離遠的前面時刻的值。
而LSTM在記憶這方面更加的靈活, 長短期記憶嘛, 功能如其名,就是既可以長期記憶也可以短期記憶,它在RNN的基礎上增加了自由的選擇記憶功能, 也就是會有一個記憶cell, 這裏面會只存儲和當前時刻相關的一些重要信息, 畢竟每個時刻關注的上下文點可能不一樣, 這個交給網絡自己選擇, 光有cell也不能起到好作用, 還得有兩個門協助它完成選擇和過濾的功能, 所以遺忘門幫助它衡量需要記住多少前面時刻的狀態信息, 更新門幫助它衡量需要記住當前時刻的多少狀態信息, 這倆一組合就是比較理想的記憶了。 但是即使是這樣, LSTM依然不放心把這個記憶作爲輸出, 又加入了一個輸出門, 來自由的選擇我用多少記憶的信息作爲最後的輸出, 所以LSTM有了這三個門, 有了記憶cell, 使它變得更加的靈活, 既可以捕捉短期依賴, 也可以捕捉長期依賴, 並且也緩解了梯度消失(後面會分析)。

下面我們就來看看LSTM是怎麼解決梯度消失的問題的。

4. LSTM是如何解決RNN存在的梯度消失問題的

在上一篇文章中, 我們詳細分析了RNN爲什麼會存在梯度消失現象, 本質上就是因爲反向傳播的時候, 有j=k+1tSjSj1\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}, 因爲這個連乘, 纔會有梯度消失或者爆炸現象,當然這裏的梯度消失現象不是說後面時刻參數更新的時候梯度爲0, 而是說後面時刻梯度更新的時候, 前面更遠時刻的序列對參數更新是起不到作用的,梯度被近距離梯度主導, 所以才說RNN無法捕捉長期依賴。 而解決上面這個問題的根本,那就是讓這個連乘保持一個常量。 這個是怎麼做到的呢?

我們先來看看這個LSTM裏面那個參數相當於RNN裏面的這個Sj,Sj1S_j, S_{j-1}, 在RNN中, 這兩個表示的是某個時刻當前的隱藏狀態與前一個隱藏狀態, 在LSTM中, 哪個參數是這個作用呢? 很明顯,就是這裏的Cj,Cj1C_j, C_{j-1},如果你說aj,aj1a_j, a_{j-1}, 那麼看看公式就會知道, 這裏的最終隱態值與RNN中的隱態值扮演的角色差異還是很大的, 還不如Ct,Ct1C_t, C_{t-1}, 爲什麼呢?

由LSTM的結構可知, 在每個迭代週期, CtC_t是需要不斷更新的, 一部分是由Ct1C_{t-1}演化而來, 一部分是本時刻加入的新信息, 這個其實就和RNN的SjS_j類似(因爲我這裏可能符號總是換, 希望能明白我在說啥哈哈, 這裏的S和a,有時候還有h, 都表示的隱藏狀態的值, 而C表示的是記憶單元cell, S一般是花書上的表示方法, 而a是吳恩達老師喜歡的表示, h是李宏毅老師喜歡的方式,其實是一個東西, 符號不同而已)。

而如果根據反向傳播把LSTM的梯度結構展開, 也會包含連乘項, 正是這裏的(j=k+1tCjCj1)\left(\prod_{j=k+1}^{t} \frac{\partial C_{j}}{\partial C_{j-1}}\right), 所以我們在這裏研究一下這一項, 因爲LSTM的反向傳播會有很多條路, 推導公式也非常的複雜, 就不在這裏花時間推這個東西了, 後面會放出一個數學公式推導的鏈接, 感興趣的可以看一下,自己去推。 我們就研究一下這一條路, 就可以看出LSTM的偉大之處, 先上公式:
it=σ(Wi[ht1,xt]+bi)ft=σ(Wf[ht1,xt]+bf)ot=σ(Wo[ht1,xt]+bo)C~t=tanh(WC[ht1,xt]+bC)Ct=ftCt1+itC~tht=ottanh(Ct)\begin{aligned} i_{t} &=\sigma\left(W_{i} \cdot\left[h_{t-1}, x_{t}\right]+b_{i}\right) \\ f_{t} &=\sigma\left(W_{f} \cdot\left[h_{t-1}, x_{t}\right]+b_{f}\right) \\ o_{t} &=\sigma\left(W_{o}\left[h_{t-1}, x_{t}\right]+b_{o}\right) \\ \tilde{C}_{t} &=\tanh \left(W_{C} \cdot\left[h_{t-1}, x_{t}\right]+b_{C}\right) \\ C_{t} &=f_{t} * C_{t-1}+i_{t} * \tilde{C}_{t} \\ h_{t} &=o_{t} * \tanh \left(C_{t}\right) \end{aligned}
回憶一下上面LSTM中的CtC_t,它是 ftf_t(遺忘門)、 iti_t(輸入門)和 C~t\tilde{C}_{t} (候選單元狀態)的函數,而這些變量又都是 Ct1C_{t-1}的函數(因爲它們都是ht1h_{t-1}的函數)。通過多變量的鏈式法則,我們得到:
CtCt1=Ctftftht1ht1Ct1+Ctititht1ht1Ct1+CtC~tC~tht1ht1Ct1+CtCt1\begin{aligned} \frac{\partial C_{t}}{\partial C_{t-1}}=\frac{\partial C_{t}}{\partial f_{t}} \frac{\partial f_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial i_{t}} \frac{\partial i_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+ \frac{\partial C_{t}}{\partial \tilde{C}_{t}} \frac{\partial \tilde{C}_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial C_{t-1}} \end{aligned}
把上面的導數化簡出來:
CtCt1=Ct1σ()Wfot1tanh(Ct1)+C~tσ()Wiot1tanh(Ct1)+ittanh()WCot1tanh(Ct1)+ft\begin{aligned} \frac{\partial C_{t}}{\partial C_{t-1}}=& C_{t-1} \sigma^{\prime}(\cdot) W_{f} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) +\tilde{C}_{t} \sigma^{\prime}(\cdot) W_{i} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) +i_{t} \tanh ^{\prime}(\cdot) W_{C} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right)+f_{t} \end{aligned}
而加上連乘符號, 也就是個這樣(j=k+1tCjCj1)\left(\prod_{j=k+1}^{t} \frac{\partial C_{j}}{\partial C_{j-1}}\right), 即這個連乘就是在k個時間步的反向傳播過程, 這個其實和普通的RNN差不多, 拿過那個公式來看一下就明白了:
j=k+1tSjSj1=j=k+1ttanhWs \prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh ^{\prime} W_{s}
下面就看看這兩個的區別, RNN上一篇文章已經分析了, 這裏的tanhWs\tanh ^{\prime} W_{s}, 始終是一個大於1或者是[0,1]之間的數, 由WsW_s具體決定, 那麼連乘之後就可以引起梯度消失或者爆炸, 而LSTM中的這個偏導數CtCt1\frac{\partial C_{t}}{\partial C_{t-1}}, 根據後面這一長串, 我們會發現這個數在任何時間步都可以取大於1或者落在[0,1]之間的數, 所以即使這裏連乘, 也不一定會梯度消失或者爆炸, 因爲這個地方並不是取決於某個參數的大小, 而是很多個參數共同決定, 並且這裏的這些參數ft,ot,it,C~tf_t, o_t, i_t, \tilde C_t都是通過網絡學習來設置的(取決於當前輸入和隱狀態)。 因此, 這樣的網絡通過調節對應門的值來決定何時讓梯度消失, 何時保留梯度, 這就是LSTM超級牛的地方。上一篇文章中我們有個例子, 直接截圖過來了:
在這裏插入圖片描述
如果換成LSTM的話, t=20的時候參數更新的公式中,後面那些就不一定都是0了, RNN的時候是0, 是因爲越往前, 連乘越厲害, 導致了梯度消失, 前面時刻的信息對於t=20的時候不起作用。 而LSTM的話, 由於LSTM會自動控制CtCt1\frac{\partial C_{t}}{\partial C_{t-1}}的大小, 所以即使這裏連乘,也不一定會出現梯度消失, 當LSTM感覺前面某個時刻的信息不重要, 比如was時刻的時候會覺得which巴拉巴拉那一些都不重要, 這時候就可以讓這個連乘等於0, 把which這一些給過濾掉,而如果感覺某個時刻重要, was時刻感覺cat時刻信息重要, 那麼就通過調控各個參數, 使得這裏的連乘不是0, 這樣cat時刻的信息就對當前was時刻的參數更新起到了幫助。 這裏就變成了這樣的一個感覺:
L20Wx=L20O20O20C20C20Wx+L20O20O20C20C20C19C19Wx+0....+0+L20O20O20C20(j=k+1tCjCj1)C2Wx+0\begin{aligned} \frac{\partial L_{20}}{\partial W_{x}}=\frac{\partial L_{20}}{\partial O_{20}} \frac{\partial O_{20}}{\partial C_{20}} \frac{\partial C_{20}}{\partial W_{x}}+\frac{\partial L_{20}}{\partial O_{20}} \frac{\partial O_{20}}{\partial C_{20}} \frac{\partial C_{20}}{\partial C_{19}} \frac{\partial C_{19}}{\partial W_{x}} + 0....+ 0 + \frac{\partial L_{20}}{\partial O_{20}} \frac{\partial O_{20}}{\partial C_{20}}(\prod_{j=k+1}^{t} \frac{\partial C_{j}}{\partial C_{j-1}})\frac{\partial C_{2}}{\partial W_{x}}+ 0\end{aligned}
這裏也更加看到了門函數的強大功能, 門函數賦予了網絡決定梯度消失程度的能力, 以及能夠在每一個時間步設置不同的值, 它們的值是當前的輸入和隱藏狀態的習得函數。 當然這裏還有起作用的一個東西就是那一長串裏面的加法運算, 這種加法運算不想乘法那麼果斷(一個0就整體0), 加性的單元更新狀態使得導數表現得更加“良性”。

當然這裏還有個細節就是LSTM的反向傳播並不是只有C這一條路,其實在其他路上依然會有梯度消失或者梯度爆炸的現象發生, 但LSTM只要改善了一條路徑上的梯度, 就拯救了總體的遠距離的依賴捕捉。至於詳細的反向傳播算法推導, 下面的鏈接給出了一篇, 當然下面的numpy實現LSTM的前向傳播和反向傳播的過程也稍微涉及一點。

5. LSTM前向傳播的numpy實現及keras中LSTM層簡單介紹

這裏分兩塊, 第一塊是用numpy簡單的實現一下LSTM的前向傳播和反向傳播, 這樣可以更好的弄清楚上面公式中各個變量的維度變化和LSTM的底層計算原理。 第二塊是keras的LSTM層, 會介紹實際中如何使用LSTM。

5.1 LSTM的前向傳播的numpy實現

關於LSTM的前向傳播,同樣我們需要先從單個的單元進行分析
在這裏插入圖片描述
右邊是前向傳播的公式, 看左邊的示意圖我們發現, 該單元的輸入是xt, a_prev, c_prev, 輸出是ct, at, yt_pred。 依然假設每個時間步我們輸入10個樣本, input_dim是3, units是5, 那我們的輸入(3, 10), a_prev是(5, 10), c_prev(5, 10)(這倆其實就和DNN那的輸出一樣, 輸入是(3,10), units是5, 那麼輸入和第一層之間的W就是(5, 3), 那麼WX之後的a就是(5,10)), 下面主要看看每個門中參數的維度:

  • 輸入門參數: Wf(units, units+input_dim), 這是因爲上面先把a_prev和xt羅列了一下, (5, 10)和(3,10)第一維度拼接, 就是(8,10), 所以這裏的Wf是(5, 8), 這樣兩者一乘纔是(5,10)。 這裏的bf是(5,1), 通過廣播之後,得到的輸出依然是(5,10)
  • 更新門參數: 分析和上面同理, 所以Wi(5, 8), bi(5, 1)
  • 輸出門參數: Wo(5, 8), bo(5,1)
  • C~t: Wc(5,8), bc(5,1)
  • 而輸出Wy, 這個的維度就是(n_y, units), by(n_y, 1), 因爲這個和最終的輸出有關了。
  • 輸出a_t: (5, 10), c_t(5, 10)

所以我們會發現這個Ct的維度是(5, 10), 也就是每個樣本在每個神經元都有自己的記憶, 並且互不影響。
基於上面的分析, 就可以實現一步cell的前向傳播了:

def lstm_cell_forward(xt, a_prev, c_prev, parametes):
	# 得到參數
	Wf = parameters["Wf"]
    bf = parameters["bf"]
    Wi = parameters["Wi"]
    bi = parameters["bi"]
    Wc = parameters["Wc"]
    bc = parameters["bc"]
    Wo = parameters["Wo"]
    bo = parameters["bo"]
    Wy = parameters["Wy"]
    by = parameters["by"]

	# 得到輸入和輸出維度
	input_dim, m = xt.shape
	n_y, units = Wy.shape

	# 拼接a和x
	concat = np.zeros([units+input_dim, m)
	concat[:units, :] = a_prev
	concat[units:, :] = xt

	# 根據公式前向傳播
	ft = sigmoid(np.dot(Wf, concat) + bf)
	it = sigmoid(np.dot(Wi, concat) + bi)
	cct = np.tanh(np.dot(Wc, concat) + bc)
	c_t = ft * c_prev + it * cct
	ot = sigmoid(np.dot(Wo, concat) + bo)
	a_t = ot * np.tanh(ct)

	yt_pred = softmax(np.dot(Wy, a_t) + by)

	# 存一下結果
	cache = (a_t, c_t, a_prev, c_prev, ft, it, cct, ot, xt, parameters)
	return a_t, c_t, yt_pred, cache

使用的時候, 按照維度初始化這些參數, 然後傳入即可得到一個時間步的輸出信息。 有了一個時間步的輸出信息, 多個時間步無非就是一個循環:
在這裏插入圖片描述
這裏參數的維度沒有變化, 但是輸入需要加上時間步的信息, 也就變成了3維, (input_dims, m, T_x)。 同理的這裏的a, y, c也都變成了3維(units, m, T_x), (n_y, m, T_x), 因爲每個時間步都會有a, y, c的輸出

def lstm_forward(x, a0, parameters):
	caches = []

	# 獲取輸入和輸出維度
	input_dim, m, T_x = x.shape
	n_y, units = parameters['Wy'].shape

	# 初始化輸入和輸出
	a = np.zeros((units, m, T_x))
	c = np.zeros((units, m, T_x))
	y = np.zeros((n_y, m, T_x))

	# 初始化開始的a c
	a_next = a0
	c_next = np.zeros([units, m])   # 初始記憶爲0

	# 前向傳播
	for i in range(T_x):
		a_next, c_next, yt, cache = lstm_cell_forward(x[:, :, t], a_next, c_next, parameters)
		a[:, :, t] = a_next
		c[:, :, t] = c_next
		y[:, :, t] = yt

		caches.append(cache)

	caches = (caches, x)
	return a, y, c, caches

至於LSTM的反向傳播底層, 這裏也不多說了,這個比較複雜, 大部分時間都是在求導, 而實際使用的時候, 比如keras,Pytorch, tf等其實都把反向傳播給實現了, 我們並不需要自己去寫。 所以我們重點需要知道的是在實際中LSTM到底應該怎麼用。

下面就拿最簡單實用的keras的LSTM舉例。

5.2 Keras的LSTM

keras裏面搭建一個LSTM網絡非常簡單, LSTM層的表示如下:

keras.layers.recurrent.LSTM(units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0)

這裏面有幾個核心的參數需要說一下, 其實RNN那個地方也作了鋪墊:

  • units: 這個指的就是隱藏層神經元的個數, 也是該層的輸出維度
  • input_dim: 這個是輸入數據的特徵數量, 當該層作爲模型首層時, 就需要指定這個
  • return_sequences:布爾值,默認False,控制返回類型。若爲True則返回整個序列,否則僅返回輸出序列的最後一個輸出, 這個也就是說如果是True, 就返回每個時間步的輸出, 而False,只返回最後一個時間步的輸出。 這個參數來自於LSTM的父類, 這裏的輸出指的是hidden state裏面的值, 也就是上面符號裏面的h或者a。
  • return_state: 默認爲False, 表示是否返回輸出之外的最後一個狀態, 這個和return_sequences不一樣, 最後一個狀態, 其實是包含兩個值的, 一個是hidden state的值, 一個是cell state的值,也就是h和c。
    • return_sequences的True表示返回所有時間步中的h, False表示返回最後一個時間步的h
    • return_state的True表示返回最後一個時間步的h和t, False表示返回。
      詳細的看下面的keras中的LSTM內部機制代碼理解的鏈接
  • timesteps: 這個就是時間序列的長度或者說時間步有多少個。 比如I love China。 時間序列的長度爲3, 所以這裏的timesteps就是3.
  • input_length: 這個其實對應timesteps, 表示輸入序列的長度。 當需要在該層後連接Flatten層,然後又要連接Dense層時,需要指定該參數,否則全連接的輸出無法計算出來。

LSTM層接收的輸入, 是(samples, timesteps, input_dim)的3D張量, 輸出的維度, 如果return_sequences=True, 那麼就返回(samples, timesteps, units)的3D張量, 否則就是(samples, units)的2D張量。 這個還是舉個例子吧: 比如我們輸入100個句子, 每個句子有5個單詞, 而每個參數是64維詞向量embedding了。 那麼samples=100, timesteps=5, input_dim=64。
在這裏插入圖片描述
所以, 只要根據規定的輸入去構造自己的數據, 然後就可以進行神經網絡的搭建, 下面也給出一個小demo:

X = Input(shape=[trainx.shape[1], trainx.shape[2], ])
h = LSTM(
            units=10,
            activation='relu',
            kernel_initializer='random_uniform',
            bias_initializer='zeros'
        )(X)

Y = Dense(1)(h)
model = Model(X, Y)

上面這個是最簡單的一層LSTM網絡, 當然也可以搭多層, 多層的話一般前面的層return_sequences爲True, 最後一層return_sequences爲false。

model = Sequential()
model.add(LSTM(128, input_dim=64, input_length=5, return_sequences=True))
model.add(LSTM(256, return_sequences=False))

就是一個這樣的感覺,
在這裏插入圖片描述

6. LSTM的變體GRU

GRU是LSTM網絡的一種效果很好的變體,2014年提出, 它較LSTM網絡的結構更加簡單,而且效果也很好,因此也是當前非常流形的一種網絡。GRU既然是LSTM的變體,因此也是可以解決RNN網絡中的長依賴問題。

首先是看一下GRU內部的一個計算邏輯:

與上面的LSTM相比, 我們會發現這裏成了兩個門, 一個是ztz_t,這個叫做更新門, 看第四個公式就會發現, 這個門其實組合了LSTM中的輸入門和遺忘門,既控制需要從前一時刻的隱藏層ht1h_{t-1}中遺忘多少信息, 也控制加入多少當前時刻隱藏層的記憶信息h~t\tilde h_t 來得到最後的隱藏層狀態hth_t, 一個是rtr_t, 這個叫做重置門, 用來控制當前時刻的記憶更新需要保留多少之前的記憶。
另一個改變是這裏的hidden state和cell state進行了合併, 兩者保持了一致, 成了一個輸出, 而LSTM那裏這倆是不一樣的。下面就看看這些公式到底在幹啥:
h~t=tanh(W[rtht1,xt])ht=(1zt)ht1+zth~t\tilde{h}_{t}=\tanh \left(W \cdot\left[r_{t} * h_{t-1}, x_{t}\right]\right) \\ h_{t}=\left(1-z_{t}\right) * h_{t-1}+z_{t} * \tilde{h}_{t}

前兩個公式是兩個門的計算公式, sigmoid函數, 把這兩個門的輸出控制到了0-1, 看上面h~t\tilde h_{t}的更新, 這個表示的是當前時刻的新信息, 如果rt=1r_t=1, 這個其實就和LSTM的C~t\tilde C_t是一樣的了, 也更容易理解, 就是表示當前時刻的信息, 而加了這麼一個門之後, 就表示我當時時刻的信息更新,可以自由的選擇應該回憶多少前面的信息, 所以變得更加靈活了些, 但依然表示當前時刻的信息。 而後面的隱態更新公式就是在說我當前時刻最後的狀態是有多少取決於前面時刻,多少取決於當前時刻信息, 如果明白了LSTM的話,這個地方估計比較好理解。 比如還是那個例子The cat, which…, was hungry! was時刻的時候, 既然要用到cat那裏的信息, 那在which…這些中途傳播的時候, 直接讓ztz_t幾乎爲0, 那麼就得到了ht=ht1h_t=h_{t-1}, 就可以把cat那裏的狀態傳到was。並且由於這裏時間步之間傳遞的時候,都是一些這樣的等式傳遞, 即使時間步很長, 但也不太容易梯度消失。

關於GRU的太多細節, 這裏就不多說了,很多都是和LSTM類似,畢竟是基於LSTM改變的一個變體, 與LSTM相比,GRU的優勢就是內部少了一個”門控“,參數比LSTM少,因而訓練稍快或需要更少的數據來泛化, 達到的效果往往能和LSTM差不多, 但是GRU不如LSTM靈活, 如果有足夠的數據, LSTM的強大表達能力可能會產生更好的效果。 至於使用, keras裏面也有GRU層可以幫助我們搭建GRU網絡。 核心參數和LSTM的基本一樣, 可以參考LSTM那裏。

6. 總結

這篇文章, 把RNN的兩個常用變體LSTM和GRU整理了一遍, 重點放在了LSTM上, 因爲GRU可以看成一個LSTM的簡化版本,是在LSTM上的改進,有很多思想借鑑了LSTM, 所以LSTM的原理和細節作爲了重點整理。 下面簡單梳理:

首先, 從RNN的梯度消失和不能捕捉長期依賴開始引出了LSTM, 這個結構就是爲了解決RNN的這兩個不足, 然後介紹了RNN的內部細節及計算邏輯, LSTM的關鍵就是引入了可選擇性的記憶單元和三個門控, 使得它變得更加靈活,可以自由的管理自己的記憶, 每一步的隱態更新都會衡量過去的信息與當前信息, 通過門控機制更合理的去更新記憶,然後去更新隱藏狀態。 有了門控, 有了加性機制,也幫助了LSTM減緩梯度消失, 使得反向傳播過程中的連乘現象變得自己可控, 當前時刻的某些參數更新取決於過去哪些時刻讓LSTM自己來選擇。 最後通過numpy實現了一下LSTM的前向傳播過程更好的幫助我們去了解細節,比如各個變量的維度信息。

最後簡單介紹了LSTM的一個變體叫做GRU, GRU在LSTM的基礎上把遺忘門和輸入門進行了合併, 改成了一個更新門, 依賴這一個門就可以自由的選擇當前時刻的信息取決於多少過去,多少當前。 然後還加入了一個重置門, 來控制當前時刻的信息更新有多少依賴於前一時刻的隱態, 增加了一定的靈活性, 並且還把cell 和hidden合併成了一個輸出。 這個結構使得網絡更加容易訓練, 參數較少, 但是表達能力不如LSTM強。

參考

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章