持續學習 (continual learning/ life-long learning)詳解

作者:夕小瑤的賣萌屋—— 「小鹿鹿鹿 」

問題定義

我們人類有能夠將一個任務的知識用到另一個任務上的能力,學習後一個任務時也不會忘記如何做前一個任務。這種能力叫持續學習 (continual learning/ life-long learning) 。而這個能力歸結起來主要有兩個問題:

  • 如何能把之前任務的經驗用上,使得更快更好的學習當前任務;

  • 學習當前任務時,不會忘記之前已經學會的任務。

用更專業的術語來講就是可塑性(學習新知識的能力)和穩定性(舊知識的記憶能力)。

可是,神經網絡不同於人類,由於其自身的設計天然存在災難性遺忘問題。當學習一個新任務的時候,需要更新網絡中的參數,但是上一個任務提取出來的知識也是儲存在這些參數上的呀。於是,神經網絡在學習新任務的時候,舊任務的知識就會被覆蓋。所以如果你試圖教阿法狗去打鬥地主,那麼當它能與小夕一戰的時候,它就再也不是柯潔的對手了。

 

 

用已有的斧頭解決問題

神經網絡算法的災難性遺忘在黑早黑早以前就有研究關注這個問題了,衆所周知(最近發現這個詞超好用誒)的大佬Goodfellow在2013年的時候針對訓練方法(是否使用dropout)和不同的激活函數(logistic sigmoid/ rectified linear/ LWTA/ Maxout)對災難性遺忘問題的影響做了詳盡的實驗分析。

 

那年的實驗

讓神經網絡學習兩個任務(新任務和舊任務),兩個任務的關係有三種:

  • Input reformatting:任務目標一樣,只改變數據輸入格式。類比不同語言的學習,意大利語和西班牙語非常接近,有相似的語法結構。最大的不同是輸入單詞的形式,比如你好,意大利語是buon giorno,西班牙語是Hola。如果神經網絡能學會從buon giorno映射到Hola,就能輕鬆的基於意大利語學習西班牙語啦~~基於這個假設,作者以MNIST數字識別爲例,舊任務是原始的數字分類任務,新任務是將32*32像素打亂的數字分類任務。

  • Similar tasks:任務目標不一致,但是相似。這個就非常好理解啦,符合我們對持續學習最自然的認知。用對不同商品評價的情感分類作爲新舊兩個任務,比如舊任務是對手機評論的情感分類,新任務是對掃地機器人評論的情感分類( •̀ ω •́ )y。

  • Dissimilar tasks:任務目標不相似。設定舊任務爲評論的情感分類,新任務是MNIST數字識別,完全風馬牛不相及的兩個任務。

對每一個task pair,我們有2×4組實驗設置(是否加dropout和四種不同的激活函數)。針對每一個設置,跑25組實驗(隨機初始化超參數),記錄新舊兩個實驗的test error。

 

實驗結果和結論

  • dropout

相信訓練過神經網絡的小夥伴都知道,dropout是一個提高模型準確性和魯棒性的一個利器。dropout的原理非常簡單,就是網絡中有非常多的連接,我們在每一次參數更新的時候,隨機的對這些連接做mask,mask掉的權重參數置零,不參與網絡更新。dropout可以理解成一個簡易的assemble,每次更新一個子網絡,最終的預測結果是所有子網絡預測的均值。

      
      

以相似任務爲例,在8種試驗設置下的25組試驗結果如上圖所示。我們可以得到一條經驗性的結論:Dropout有助於緩解災難性遺忘問題(無論使用哪種激活函數和不同的任務關係)。

 

爲什麼呢?一個比較簡單的理解是dropout強迫網絡把每一層的模式相對均勻的記憶在各個神經元中(不加dropout時容易導致網絡退化,一層中的神經元可能真正起作用的只有幾個)。這樣相當於增加了模型的魯棒性,後續任務對其中的小部分神經元破壞時,不會影響整體的輸出結果,對比之下,如果不加dropout,那麼一旦關鍵的神經元被後續任務破壞,則前面的任務就完全崩了。使用dropout訓練的模型size遠大於不加dropout的模型大小。

      
       

  • 激活函數

但是遺憾的是激活函數的選擇沒有一致的結論。在三種不同的task pair下,激活函數的選擇排序是不同的,大佬建議我們使用cross-validation來選擇網絡中使用的激活函數。(是的,這也可以作爲一個結論🙃)

      

 

 

造新斧頭解決問題

持續學習方法

當前主流的針對神經網絡模型的持續學習方法可以分爲以下五類:

  • Regularization:在網絡參數更新的時候增加限制,使得網絡在學習新任務的時候不影響之前的知識。這類方法中,最典型的算法就是EWC。

  • Ensembling: 當模型學習新任務的時候,增加新的模型(可以是顯式或者隱式的方式),使得多個任務實質還是對應多個模型,最後把多個模型的預測進行整合。增加子模型的方式固然好,但是每多一個新任務就多一個子模型,對學習效率和存儲都是一個很大的挑戰。google發佈的PathNet是一個典型的ensembling算法。

  • Rehearsal:這個方法的idea非常的直觀,我們擔心模型在學習新任務的時候忘了舊任務,那麼可以直接通過不斷複習回顧的方式來解決呀(ง •_•)ง。在模型學習新任務的同時混合原來任務的數據,讓模型能夠學習新任務的同時兼顧的考慮舊任務。不過,這樣做有一個不太好的地方就是我們需要一直保存所有舊任務的數據,並且同一個數據會出現多次重複學習的情況。其中,GeppNet是一個基於rehearsal的經典算法。

  • Dual-memory:這個方法結合了人類記憶的機制,設計了兩個網絡,一個是fast-memory(短時記憶),另一個slow-memory(長時記憶),新學習的知識存儲在fast memory中,fast-memory不斷的將記憶整合transfer到slow-memory中。其中GeppNet+STM是rehearsal和dual-memory相結合的一個算法。

  • Sparse-coding: 災難性遺忘是因爲模型在學習新任務(參數更新)時,把對舊任務影響重大的參數修改了。如果我們在模型訓練的時候,人爲的讓模型參數變得稀疏(把知識存在少數的神經元上),就可以減少新知識記錄對舊知識產生干擾的可能性。Sensitivity-Driven是這類方法的一個經典算法。

這個方法的idea確實是挺合理的,當有效知識儲存在少數的節點上,那麼新知識我們就大概率可以存儲在空的神經元上。

等等!

還記得前面我們說過,dropout是說我們把信息備份在更多的神經元上,當我們在學習新任務的時候就算破壞了其中的幾個也不會影響最終的決策。那麼這兩個推論不就自相矛盾了麼??所以到底應該是稀疏還是稠密,還得通過實驗才能知道呀~~

 

實驗設計

  • 數據集

      

作者使用了三個數據集,MNIST就不用說了,CUB-200(Caltech-UCSD Birds-200)同樣也是一個圖片分類數據集,不過比MNSIT更加複雜,裏面有200類不同種類的鳥類。而AudioSet則是來源於youtube的音頻數據集,它同樣也是一個分類數據集,有632個類。

 

Data Permutation Experiment:和前文設置一樣,新任務就是將舊任務的輸入數據做置換。

Incremental Class Learning:實際中我們總是會有這樣的訴求,就是當模型可以對花🌺進行分類的時候,希望通過持續學習可以認識各種類別的樹🌴。所以這個任務設計就是不斷增加模型可分類類別。以MNSIT數字分類爲例,先讓模型學習一次性學習一半的類別(識別0-5),再逐個增加讓模型能夠識別6-9。

Multi-Modal Learning:多模學習是希望模型能夠實現視覺到聽覺的轉換。作者分別嘗試讓模型先學習圖像分類CUB-200再學習音頻任務AudioSet,以及先學習AudioSet再學習CUB-200。(大概是作者覺得MNSIT任務太簡單,就直接忽略了╮( ̄▽ ̄"")╭)

  • 任務

Data Permutation Experiment:和前文設置一樣,新任務就是將舊任務的輸入數據做置換。

Incremental Class Learning:實際中我們總是會有這樣的訴求,就是當模型可以對花🌺進行分類的時候,希望通過持續學習可以認識各種類別的樹🌴。所以這個任務設計就是不斷增加模型可分類類別。以MNSIT數字分類爲例,先讓模型學習一次性學習一半的類別(識別0-5),再逐個增加讓模型能夠識別6-9。

Multi-Modal Learning:多模學習是希望模型能夠實現視覺到聽覺的轉換。作者分別嘗試讓模型先學習圖像分類CUB-200再學習音頻任務AudioSet,以及先學習AudioSet再學習CUB-200。(大概是作者覺得MNSIT任務太簡單,就直接忽略了╮( ̄▽ ̄"")╭)

  • 評估

          
      

前文提到持續學習的兩個主要問題是學習能力記憶能力,所以作者用Omiga_new來評估模型學習新任務的能力,Omiga_base評估模型的記憶能力,Omiga_all是這種兩個能力的綜合考量。alpha_new,i是模型剛學完任務i對任務i的準確率,alpha_base,i是模型剛學完任務i,對第一個任務(base任務)的準確率,ahpha_all_i是模型剛學完任務i,對學習過的所有任務的準確率。爲了保證不同任務之間的可比性,作者用不使用任何持續學習方法,直接學習base任務的準確率做了歸一。

 

實驗結果和結論

      
       

上表中除了上面提到的五種持續學習方法以外,還有MLP是不加持續學習方法的baseline。

 

在數據輸入格式變換任務下,GeppNet和GeppNet+STM能保持記憶但是喪失了學習新任務的能力;FEL學習新任務能力強,但是不能保持記憶;PathNet和EWC都能一定程度改善災難性遺忘問題,pathnet比ewc要稍好一些~~

       
       

在逐漸增加分類類別實驗下,隨着分別數的增加(新任務的不斷學習),已學習過的分類類別平均準確率在不斷的下降,其中EWC算法完全失效,準確率曲線和baseline重合,而其他持續學習算法,GeppNet 優於 GeepNet+STM 優於 FEL。但是在多模實驗下,EWC卻又是唯一一個方法在兩個順序下都有效的。(剛被打臉又長臉了emmm)

       
       

最終實驗結論!!在實驗了各大主流方法的經典算法之後,發現。。並沒有一個統一的方法可以一致的解決不同場景下的問題。╮(╯▽╰)╭

 

由此可知,解決災難性遺忘問題還是非常任重而道遠的呀。後續小夕和小夥伴們還會更加詳細的介紹文中提到的以及沒有提到的持續學習方法更爲具體的算法細節。希望大家多多支持鴨~~

 

 

 


 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章