不僅搞定“梯度消失”,還讓CNN更具泛化性:港科大開源深度神經網絡訓練新方法

原文鏈接:不僅搞定“梯度消失”,還讓CNN更具泛化性:港科大開源深度神經網絡訓練新方法

paper: https://arxiv.org/abs/2003.10739

code: https://github.com/d-li14/DHM

該文是港科大李鐸、陳啓峯提出的一種優化模型訓練、提升模型泛化性能與模型精度的方法,相比之前Deeply-Supervised Networks方式,所提方法可以進一步提升模型的性能。值得一讀。

Abstract

時間見證了深度神經網絡的深度的迅速提升(自LeNet的5層到ResNet的上千層),但尾端監督的訓練方式仍是當前主流方法。之前有學者提出採用深度監督(Deeply-supervised,DSN)方式緩解深度網絡的訓練難度問題,但是它不可避免的會影響深度網絡的分層特徵表達能力,同時會導致前後矛盾的優化目標。

作者提出一種動態分層模仿機制(Dynamic Hierarchical Mimicking,一種廣義特徵學習機制)加速CNN訓練同時使其具有更強的泛化性能。所提方法部分受DSN啓發,對給定神經網絡的中間特徵進行巧妙的設置邊界分支(side branches)。每個分支可以動態的出現在主分支的特定位置,它不僅可以保留骨幹網絡的特徵表達能力,同時還可以研其通路產生更多樣性的特徵表達。與此同時,作者提出採用概率預測匹配損失進一步提升多分支的多級交互影響,它可以確保優化過程的魯棒性,同時具有更好的泛化性能。

最後作者在分類與實例識別任務上驗證了所提方法的性能,均可取得一致性的性能提升。

Method

該部分內容首先簡單介紹一下深度監督及存在的問題,最後給出所提方法。由於該部分內容公式較多,文字較多,故這裏僅進行粗略的介紹,在後面對進行一些個人理解分析。

Analysis of Deep Supervision

對於深度網絡而言,其優化目標可以描述爲:
argminWmLm(Wm;D)+γR(Wm) argmin_{W_m} \mathcal{L}_m(W_m; \mathcal{D}) + \gamma \mathcal{R}(W_m)
其中Lm(Wm;D)\mathcal{L}_m(W_m; \mathcal{D})表示待優化的整體損失函數,而R(Wm)\mathcal{R}(W_m)表示針對參數添加的一些正則化處理。對於圖像分類而言,上述損失函數可以定義爲:
Lm(Wm;D)=1Ni=1Nfm(Wm;xi)(yi) \mathcal{L}_m(W_m; \mathcal{D})=-\frac{1}{N} \sum_{i=1}^{N} f_m(W_m;x_i)^{(y_i)}
另,由於正則項僅與參數有關,而與網絡結構無關,故在後續介紹中對上述公式進行簡化,得到:
argminWmLm(Wm;D) argmin_{W_m} \mathcal{L}_m(W_m; \mathcal{D})
一般而言,在圖像分類任務中,往往僅在網絡的head後進行損失計算。這種處理方式對於比較淺的網絡而言並沒有什麼問題,但是對於極深網絡而言則會由於梯度反向傳播過程中的“梯度消失”問題導致網絡收斂緩慢或者不收斂或收斂到局部最優。

針對上述現象,Deeply-Supervised Nets提出了多級監督方式進行訓練。該訓練方式的優化目標函數可以描述爲:
argminWm,WsL(Wm;D)+Ls(Wm,Ws;D) argmin_{W_m,\mathcal{W}_s} \mathcal{L}(W_m; \mathcal{D}) + \mathcal{L}_s(W_m, \mathcal{W}_s; \mathcal{D})
其中Ls\mathcal{L}_s表示額外監督信息的損失。注:GoogLeNet一文采用的訓練方式就是它的一種特例。

通過上述上述訓練方式,中間層不僅可以從頂層損失獲取梯度信息,還可以從分支損失獲取提取信息,這使得其具有緩解“梯度消失”,加速網絡收斂的功能。

然而,直接在中間層添加額外的監督信息的方式在訓練極深網絡時可能會導致模型性能下降。衆所周知,深度網絡具有極強的分層特徵表達能力,其特徵會隨網絡深度而變化(底層特徵聚焦邊緣特徵而缺乏語義信息,而高層特徵則聚焦於語義信息)。在底層添加強監督信息會導致深度網絡的上述特徵表達方式被破壞,進而導致模型的性能下降。這從某種程度上解釋了爲何上述監督方式對模型的性能提升比較小(大概在0.5%左右,甚至無提升)。

Dynamic Hierarchical Mimicking

作者重新對上述優化目標進行了分析並給出猜測:“最本質的原因在於損失函數中相加的兩塊損失優化目標不一致”。以分類爲例,儘管兩者均意在優化交叉熵損失,但兩者在中間層的優化方向是不一致的,存在矛盾點,進而導致對最終模型性能產生負面影響。

針對上述問題,作者提出一種新穎的知識匹配損失用於正則化訓練過程,並使得不同損失對中間層的優化目標相一致,從而確保了模型的魯棒性與泛化性能。

image-20200522140134909

所提方法的優化目標函數可以描述如下公式,其示意圖見上圖。
argminWm,WsL(Wm;D)+Ls(WΦ~;IΦ,D)+Lk(WΦ~;IΦ,D) argmin_{W_m, \mathcal{W}_s} \mathcal{L}(W_m;\mathcal{D}) + \mathcal{L}_s(\mathcal{W}_{\tilde{\Phi}};I_{\Phi},\mathcal{D}) + \mathcal{L}_k(\mathcal{W}_{\tilde{\Phi}};I_{\Phi}, \mathcal{D})
其中比較關鍵在於第三項的引入,也就是所提到的知識匹配損失。注:由於全文公式太多,本人只是相對粗略的看來一遍,沒有過於深度去研究。應該不會影響對其的認知,見後續的對比分析。

Experiments

爲驗證所提方法的有效性,作者在多個數據集(Cifar,ImageNet,Market1501等)上的機型了實驗對比分析。

首先,給出了CIFAR-100數據集上所提方法與DSL的性能對比,見下圖。儘管DSL可以提升模型的性能,但提提升比較少,而作者所提DHM可以得到更高的性能提升。該實驗證實了所提方法的有效性。

image-20200522160723944

然後,作者給出了ImageNet數據集上的性能對比,見下圖。可以得到與前面類似的結論,但同時可以看到:對於極深網絡(如ResNe152),DSL的性能提升非常有限,而所提方法仍能極大的提升模型的性能超1%。

image-20200522160942760

其次,作者給出了Market1501數據集上的性能對比,見下圖。結論同前,不再贅述。

image-20200522161222169

最後,作者還提供了其實驗過程中的網絡架構,這裏僅提供一個參考模型(MobileNet)作爲示例以及分析說明。除了MobileNet外,作者還提供了DenseNet、ResNet、WRN等實驗模型。

image-20200522161404152

Discusion

實事求是的說,本人在看到最後的網絡結構和代碼之前是沒看明白這篇論文該怎麼應用的。只是大概瞭解DSL破壞了深度網絡的分層特徵表達能力,針對該問題而提出的解決方案。

看了論文和代碼後,基本上明白了作者是怎麼做的。就一點:既然DSL破壞了深度網絡的分層特徵表達能力,那麼就想辦法去補償以不同損失反向傳播到中間層與底層時優化方向是一致的。那麼該怎麼去補償呢?下圖給出了圖示,中間主幹分支表示預定義好的網絡結構,左右兩個分支表示作者補償的結構,通過這樣的方式可以確保主損失與右分支損失傳播到layer3的優化方向一致,主損失與做分支損失傳播到layer2的優化方向一致。當然圖中兩個顏色layer3表示這是不同的處理過程,分支的處理過程肯定要比主分支的計算量小,否則豈不是加大了訓練難度?

image-20200522162508515

我想,看到這裏大家基本上都明白了DHM這篇論文所要表達的思想了。接下來,將嘗試將其與其他類似的方法進行一下對比分析。首先給出傳統訓練方式、DSL訓練方式與DHM的對比圖(注:圖中暗紅色區域表示損失計算,具體怎麼計算不詳述)。

上圖給出了常規訓練過程、DSL訓練過程以及DHM的訓練成果對比。常規訓練過程僅在head部分有一個損失;而DSN(即DSL)則有多個損失,不同的損失回傳的速度時不一樣的,比如左分支損失直接傳給了layer2,這明顯快於中間的主損失,這是緩解“梯度消失”的原因所在;DHM類似於DSL具有多個損失,但同時爲防止不同損失對中間層優化方向的不一致,而添加了額外的輔助層,用於模擬深度網絡的分層特徵表達。

那麼DHM是如何緩解“梯度消失”現象的呢?個人認爲,它有兩種方式:(1) ResNet與DenseNet中的緩解“梯度消失”的方式,這與網路結構有關;(2)分支層數少於主幹層數,一定程度上緩解了“梯度消失”。

最後,再補上一個與DHM極爲相似的方法DML,兩者的流程圖如下所示。論文原文確實提到了DML方法,但並未與之進行對比。從圖示可以看到兩者還是比較相似的,儘管DML初衷是兩個網絡採用知識蒸餾的方式進行訓練,而DHM則是針對DSL存在的缺陷進行的改進。

image-20200522163733594

私認爲DHM是DML的特例(注:僅僅從上述圖示出發),有這麼三點原因:

  • 損失函數方面,以圖像分類爲例,DML與DHM均採用交叉熵損失+KL散度計算不同分支損失;
  • 分支數方面:儘管DML原文是借鑑識蒸餾方式,但其分支可以不止兩個,比如擴展到三個呢,四個呢?這兩種方式是不是就一樣了呢?
  • 網路結構方面:儘管DML提到的是兩個網絡,但是兩個網絡如果共享stem+layer1+layer2部分呢?從這個角度來看,DHM與DML殊途同歸了。

做完上述記錄後,本人厚着臉皮去騷擾了一下李鐸大神,請教了一下。經允許,現將作者的理解摘錄如下:

DSL存在的問題:(1) 特徵逐級提取問題,如果像上述圖中googlenet/dsn那樣把head直接接在中間層立刻再接classifier,那麼強制要求layer2、layer3、layer4都提取high-level語意特徵,這和一般網絡裏layer2、layer3可能還在提取更low-level的特徵相違背;(2) 不同分支的gradient都會回傳到shared的主支上,如果這些gradient相互衝突甚至抵消,對於整個網絡的優化是產生負面影響的。

DHM的解決方案:(1)第一個問題通過圖中的分支網絡結構的改進來解決;(2)第二個問題則是通過KL散度損失隱式約束梯度來解決。

OK,關於DHM的介紹,全文到底結束!碼字不易,思考更不易,還請給個贊。

Reference

  1. Going Deeper with Convolutions. https://arxiv.org/abs/1409.4842
  2. Deeply Supervised Networks. https://arxiv.org/abs/1409.5185
  3. Deep Mutual Learning. https://arxiv.org/abs/1706.003384

關注極市平臺公衆號(ID:extrememart),獲取計算機視覺前沿資訊/技術乾貨/招聘面經等
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章