最近在學習網紅模型Universal Transformers時接觸到了自適應計算時間(Adaptive Computation Time, ACT)這個新奇的算法。查了一下,這個算法其實並不算新,2016年就被Google DeepMind組的Alex Graves在論文《Adaptive Computation Time for Recurrent Neural Networks》中提出。但是如果從引用量來說,這篇論文確實還算很新的,截止到發稿時止,該論文的引用量僅爲113次。這個引用量對於Transormers模型的紅火程度來說確實太少了。那爲何ACT算法引用數如此少呢?是太難理解沒有流傳?還是適用範圍太小?抑或是並沒有那麼實用的效果呢?本文就來解讀一下這個ACT算法。
自適應計算時間ACT的用途
上來先劃重點,說說這個方法有什麼用吧。這個方法最直接的用途是控制RNN模型中每一個時刻重複運算的次數。如果不太理解的話,可以看做是控制RNN在每一個時刻狀態計算網絡的深度。擴展來說,ACT算法的思想還可以用到自動控制深度網絡的深度,甚至控制模型的複雜度。
自適應計算時間ACT的原理
問題定義
考慮一個傳統的遞歸神經網絡(recurrent neural network),這個網絡由一個輸入權重矩陣(maxtrix of input weights),一個參數化的狀態轉移模型(parametric state transition model),一組輸出權重(a set of output weights),以及一個輸出偏置量(output bias)組成。當輸入一個序列時,在時間到通過迭代如下公式來計算狀態序列和輸出序列:
在這兒,公式(1)中得到的狀態是一個向量,包含着序列的動態信息。我們可以理解公式(1)中模型將當前的輸入和上一時刻的狀態按照某種關係進行結合,從而產生了當前狀態。當前狀態又經過某一種線性變換,即公式(2),(也可以不變換,如GRU中,此時固定爲全1矩陣,固定爲0)得到了最終的輸出結果。這也是RNN網絡的一個基本思路。根據不同的模型和不同的線性變化方式,就可以得到不同的RNN網絡實例化,如LSTM, GRU,NTM等等。
自適應時間計算(ACT)修改了傳統的遞歸神經網絡,使得對於每一個時刻的輸入可以計算多個狀態,以及多個輸出,此處表示在時刻計算的狀態和輸出個數。需要注意的是,這兒的多個狀態之間存在着序列依賴關係,即的值依賴於。形式上,多個狀態和輸出可由如下公式計算得出:
此處是在時刻的輸入的修正值。根據原文中的說法,是一個二值標誌位用以指示當前狀態的輸入量在狀態計算時是否進行了多步,使得神經網絡可以區分是否是更新的輸入量或者是對於一個輸入量重複的計算。此處我還沒有理解透: (1) 是在第一步之後都保持爲+1還是在第一步之後進行累加得到?(2)爲什麼這兒可以使得神經網絡區分出是否進行了多步計算和對一個輸入量的重複計算?是否會對輸入的值產生影響?對輸入值的範圍是否有要求?放上原文供參考:
where is the input at time augmented with a binary flag that indicates whether the input step has just been incremented, allowing the network to distinguish between repeated inputs and repeated computations for the same input.
這兒值得注意的是,對於某一時刻的多步運算,和不同時刻的運算,採用的模型都是一樣的。同樣,對於某一時刻的多個輸出,和不同時刻的輸出,採用的線性變換也是一樣的。雖然可以採用不同的模型和線性變換來處理同一時刻的多步運算以及不同時刻的運算,但是會顯著增加模型的參數、複雜度、運算時間開銷。值得思考的是,如果運用的模型是一樣的話,會對於結果產生什麼影響呢?性能損失會有多大?
爲何要對某一時刻的輸入重複多次狀態運算
在我看來,對於同一個輸入的重複運算,降低了模型的複雜性和學習難度。從直觀上說這相當於把複雜的問題簡單化,複雜的狀態通過多步運算來模擬。從理論上說,相當於通過多個非線性運算來逼近一個複雜的函數,與深度學習加深網絡深度得到更好表示能力的原理是一致的。那麼,接下來的問題是,需要重複多少次運算(需要多深的網絡)才能讓網絡的學習能力適應數據的複雜性?深度太淺,網絡不能很好的學到數據中所有的複雜關係;網絡太深,又可能存在過擬合或者不好訓練的問題。通常情況下,對於一個深度網絡,大家都會將網絡的深度作爲一個超參數,通過不斷調整這個超參數來找到一個適用於數據複雜性的網絡。但是,這畢竟費時費力,特別是對於大數據的學習來說,通常很難在有限時間內找到一個較優的模型深度。因此,讓模型自適應的確定網絡深度在深度學習中顯得尤爲重要。而本文所介紹的自適應計算時間ACT就是用來解決這個問題的。
如何控制某一時刻重複運算的次數(網絡深度)
自適應計算時間ACT通過在某一時刻的每一次重複運算的輸出中引入一個額外的停止單元(halting unit)來自動控制重複運算的次數。某一次重複運算中的停止單元由這次運算得到的狀態和一個單層的以sigmoid爲激活函數的神經網絡決定,具體公式如下:
而後,這個停止單元被用來計算每一次重複運算的停止概率(halting probability):
其中
被稱爲殘餘量(remainder),定義如下:
是一個很小的常量(原文中設定爲0.01)用於使重複計算可以在第一次計算後就終止(如果)。據此,根據公式(6)的定義,是一個概率分佈函數,滿足並且。這兒可以多說兩句,公式(7)可以理解爲當在第次重複運算之後如果前面各次重複運算累計的停止概率非常大則沒有必要繼續再重複運算了(有極大的概率在第詞重複運算前就終止)。此時,第步的停止概率,不僅僅需要考慮計算出來,還要考慮之後所有可能重複運算的概率之和,這兩部分之和就是公式(8)所計算的殘餘量。
由此我們可以看到,根據公式(6)中的概率公式,我們就可以採樣出所需的神經網絡深度啦。看到這兒,我們發現所謂自適應計算時間ACT算法,本質上就是引入了一個額外的神經網絡來計算每一次重複運算(類比於深度網絡中的每一層)的停止概率,而這兒神經網絡的輸入就是上一次運算的輸出(類別與深度網絡中上一層的輸出)。是不是覺得這兒有點像把注意力機制(Attention mechanism)應用在了深度網絡的層數控制上了(由此聯想可以擴展很多工作哦)。
如何得到自適應計算次數後的輸出結果
雖然上文中構建了重複計算次數的控制模型,但是還沒有提到如何採樣,以及如何將採樣融入到深度網絡的訓練之中。實際上,如果按照上述公式進行隨機採樣運算的話會在網絡訓練時導致非常大的噪音梯度值(noisy gradients),因此ACT方法使用了平均場的思想來得到自適應計算次數後的輸出結果。換而言之,ACT算法並沒有對重複運算次數進行採樣,而是根據重複運算的停止概率對多個重複運算結果進行加權平均,即計算了按照重複運算停止概率採樣運算次數得到結果的期望值。具體公式如下:
當然,這裏面是包含着一個假設:狀態和輸出向量是近似線性的。具體的假設合理性這兒我們就不多闡釋了,原文說得很清楚,有興趣的朋友可以翻看。
對於學到重複運算次數(網絡層數)的約束
上文中討論了自適應計算時間ACT算法的原理,本質上就是額外的引入了控制每一次運算(每一層網絡)停止概率運算的一個神經網絡。那麼,這個額外引入的神經網絡如何學習呢?當然,它可以用最終的任務標籤來訓練,隨着整個網絡一起優化學習。但是,可能會出現一個問題,就是學到的重複運算的次數非常非常大(網絡非常非常深)。這個很好理解,因爲神經網絡也想偷懶,想利用儘可能多的運算來減少每一次運算的難度。這樣的話,和ACT算法的設計初衷就背道而馳了,可能導致模型的過擬合併增加計算開銷。因此,ACT算法中引入了對重複運算次數(網絡層數)的約束。這個約束體現在兩處,第一處就是在目標函數中增加一個約束性使得學到的重複運算次數(網絡層數),第二處是限定一個重複運算次數(網絡層數)的最大值。
就第一處約束來說,ACT算法統計了一個思考序列(ponder sequence),其中時刻的值爲:
由這個思考序列,ACT算法計算得到一個思考損失值(ponder cost):
最後,ACT算法將這個思考損失值和原神經網絡的損失函數相加,得到新的損失函數:
此處,被稱爲時間懲罰參數(time penalty parameter)用以控制對於重複運算次數約束的重要性。原文的實驗顯示,ACT算法的結果對於這個參數非常敏感。這應該也是模型可以改進的方向之一吧。
第二處約束在於ACT算法用一個超參數作爲重複運算次數的上限,從而減低剛開始訓練網絡時學到非常多的重複次數造成的額外計算開銷。使用超參數後,公式(7)被修改爲:
算法的實現
原文中給出了ACT算法的目標函數公式(12)如何進行優化,有興趣的同學可以翻看學習。但是在實際應用中,我們一般採用現有的機器學習框架來完成模型自動的優化求解。ACT算法已經在多個現有流行的深度學習框架平臺實現,包括keras實現,TensorFlow實現,PyTorch實現等等,大家可以直接調用或者參考編碼。