MetaPruning: 基於元學習的自動化神經網絡通道剪枝

本文出自論文MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning,提出來一個最新的元學習方法,對非常深的神經網絡進行自動化通道剪枝。

本文提出來一個最新的元學習方法,對非常深的神經網絡進行自動化通道剪枝。首先訓練出一個PruningNet,對於給定目標網絡的任何剪枝結構都可以生成權重參數。我們使用一個簡單的隨機結構採樣方法來訓練PruningNet,然後應用一個進化過程來搜索性能好的剪枝網絡。這個搜索方法是非常高效的,因爲權重直接通過訓練好的PruningNet生成,並不需要在搜索時間中進行任何的微調。只需要爲目標網絡訓練處一個簡單的PruningNet,我們可以在不同的人工約束下搜索不同的剪枝網絡。與當前最先進的剪枝方式相比,MetaPruning在MobileNet V1/V2和ResNet上有着最好的性能表現。



一、簡介

  1. 通道剪枝作爲一種神經網絡的壓縮方法被廣泛的實現和應用,它通常包含三個階段:訓練一個大的超參數化網絡,修剪次重要的權重或者通道,最後通過微調或者重訓練剪枝網絡來獲得最終的剪枝網絡。第二個階段通常執行迭代式的逐層剪枝,然後快速微調或者權重重建來重獲得精度。AutoML中利用自動尋找最優結構的特點,基於一個反饋循環或者強化學習,脫離了人工設計的侷限並彌補了剪枝算法依賴數據的不足。
  2. 最近的研究表明通道剪枝的本質是找到好的剪枝結構-逐層的通道數量。由於迭代式尋找最優結構的計算代價很高,因此提出來訓練PruningNet,來對於所有候選的剪枝網絡架構,可以生成權重參數,這樣我們能通過在驗證數據集上評估其精度來搜索到性能好的架構。爲了訓練PruningNet,我們使用一個隨機結構採樣方法,使用相應的網絡編碼向量來生成剪枝網絡的權重,即每一層的通道數量。通過隨機輸入不同的網絡編碼向量,PruningNet逐步學習生成不同剪枝結構的權重參數。在訓練過程結束後,我們通過一個進化搜索方法來搜索到性能好的剪枝網絡,可以靈活地結合到不同的約束例如計算浮點數或者硬件延遲。另外,通過決定每一層或每個階段的通道,可以直接搜索到最好的剪枝網絡,因此可以在shortcut結構中修剪通道。這種神經網絡壓縮方法被稱作MetaPruning。
  3. MetaPruning的兩個階段:(1)訓練一個PruningNet:在每次迭代過程中,一個網絡編碼向量(每層的通道數量)被隨機生成,剪枝網絡也相應地被構建出來,PruningNet將網絡編碼向量作爲輸入,來生成剪枝網絡的權重參數;(2)搜索最佳剪枝網絡:通過不同的網絡編碼向量構建了許多剪枝網絡,並利用剪枝網絡預測的權重對驗證集的優劣進行了評估,在搜索時間內無微調或重訓練過程。
    MetaPruning兩階段
  4. 將此方法應用於MobieNets和ResNet上,在相同的浮點數下,我們的精度比MobileNet V1高2.2%到6.6%,比MobileNet V2高0.7%到3.7%,比ResNet-50高0.6%到1.4%。在相同的延遲下,比MobileNet V1高2.1%到9.0%,比MobileNet V2高1.2%到9.9%。
  5. 本文的主要貢獻點:(1)提出一個元學習方法MetaPruning來用於通道剪枝,其中心思想是學習一個元網絡來生成不同剪枝結構的權重參數;(2)節省了超參數調優中的人力過程,允許使用所需要的度量標準來直接優化;(3)可以很容易地在搜索所需結構時實施約束,而不需要手動調整強化學習的參數;(4)可以不費力地修剪像ResNet結構這樣的short-cuts的通道。

二、相關工作

  1. Pruning:網絡剪枝對於深度網絡的冗餘度去除是一個普遍的方法。在權重修剪過程中,通常會剪去單個權重來壓縮模型大小,但同時會導致非結構化的稀疏過濾器。傳統的通道剪枝方法是根據每個通道的重要程度,以迭代方式修剪通道,或者添加一個數據驅動的稀疏度。
  2. AutoML:該方法將多設備上的實時推理延遲考慮在內,通過強化學習或者一個自動化的反饋循環在一個網絡的不同層上迭代式修剪通道。與先前的AutoML剪枝方法相比,MetaPruning方法在精度滿足約束條件方面具有較高的靈活性,並具有對short-cut中的通道進行修剪的能力。
  3. Meta Learning:它指代着學習觀察不同的機器學習方法如何在不同的學習任務上執行。在本文中我們使用meta learning來進行權重預測,權重預測表示一個神經網絡的權重被另一個神經網絡所預測,而不是直接學習得到。
  4. Neural Architecture Search:使用強化學習、遺傳算法或者基於梯度的方法找到最優的網絡結構和超參數。通過與drop-path聯合訓練多項選擇,它可以在訓練過的網絡中搜索到最高精度的路徑。調整通道寬度也包含在一些神經架構搜索方法中。我們所提出的針對通道剪枝的MetaPruning方法能夠通過訓練PruningNet進行權重預測來解決這一連續的通道剪枝挑戰問題。

三、方法

  1. 我們將通道剪枝問題用公式表示爲:(c1,c2,...cl)=argminc1,c2,...clL(A(c1,c2,...cl;w))C<constraint(c_1,c_2,...c_l)^*=\mathop{\arg\min}\limits_{c_1,c_2,...c_l}{L}(A(c_1,c_2,...c_l;w))\quad C<constraint, 其中A是剪枝前的網絡,我們嘗試找到剪枝網絡的通道寬度(從第一層到第L層),在權重被訓練後有着最小的損失,同時C滿足所規定的的約束(FLOPs或者延遲)。爲此,我們提出構建一個PruningNet,一種元網絡,可以通過在驗證集上的評估快速獲得所有可能剪枝網絡結構的優劣度。然後我們可以應用任何搜索方法(比如進化算法)來搜索到最佳的剪枝網絡。
  2. PruningNet訓練:PruningNet是一個元網絡,將一個網絡編碼向量(c1,c2,...clc_1,c_2,...c_l)作爲輸入,然後輸出剪枝網絡的權重,可表示爲:W=PruningNet(c1,c2,...cl).W=PruningNet(c_1,c_2,...c_l). 一個PruningNet block由兩個全連接層組成,在前向傳遞過程中,PruningNet將網絡編碼向量作爲輸入,然後生成權重矩陣。與此同時,一個剪枝網絡被構造出,其每一層的輸出通道寬度等同於網絡編碼向量中的元素。生成的權重矩陣被裁剪來匹配剪枝網絡中輸入輸出通道的數量。在後向傳遞過程中,並沒有更新剪枝網絡的權重,而是計算PruningNet裏權重的梯度。爲了訓練PruningNet,我們提出了隨機結構採樣,在訓練階段的每次迭代過程中,網絡編碼向量被隨機生成來選擇每層的通道數量。有着不同的網絡編碼,不同的剪枝網絡被構建出來,相應的權重由PruningNet來提供。通過使用不同的編碼向量隨機訓練,PruningNet學習預測各種不同剪枝網絡的合理權重。

PruningNet隨機訓練方法網絡架構以及reshape操作
5. 剪枝網絡搜索:在PruningNet訓練完後,我們可以通過輸入網絡編碼到PruningNet中,生成相應的權重和在驗證集上進行評估工作,來獲取每個可能剪枝網絡的精度。由於網絡編碼向量數量巨大的問題,爲了在約束條件下找到高精度的剪枝網絡,我們使用一個進化搜索,可以很容易地合併軟硬性約束。每個剪枝網絡被編碼成一個包含每層通道數量的向量,被命名爲剪枝網絡的基因。我們首先隨機選擇大量的基因,通過做評估來獲得相應剪枝網絡的精度。然後前K個最高精度的基因被挑選出來,使用交叉和變異方法生成新的基因。變異操作通過隨機改變基因中的元素比例來執行,交叉操作通過隨機重組兩個親本基因的基因來產生後代。通過迭代進行這個過程,我們可以獲得滿足約束條件的基因,同時得到最高精度。
進化搜索算法

四、實驗結果

  1. MetaPruning on MobileNets and ResNet:對於沒有short-cut結構的網絡MobileNet V1,我們裁剪原始權重矩陣的左上邊,來匹配輸入和輸出通道。在MobileNet V2中,每個階段都從匹配兩個階段之間的維度瓶頸塊開始。爲了修剪包含shortcut的結構,我們生成兩個網絡編碼向量,一個對總體階段的輸出通道進行編碼來匹配shortcut裏的通道,另一個對每個block的中間通道進行編碼。在PruningNet中,我們首先將網絡編碼向量解碼爲每個塊的輸入輸出和中間通道壓縮比,然後我們生成那個block塊中的相應權重矩陣。ResNet和MobileNet V2的構建過程相同。在這裏插入圖片描述
  2. FLOPs約束下的剪枝效果比較:使用MetaPruning學習得到的剪枝方法,與0.25x下的MobileNet V1相比我們獲得了6.6%的精度提升,同樣與MobileNet V2和ResNet相比均獲得了很好的提升。在與最先進的AutoML剪枝方法相比中,MetaPruning獲得了較好的效果,它還消除了人工調整強化學習超參數的劣勢。
    FLOPs約束下的性能比較
  3. 延遲約束下的剪枝效果比較:在合理的假設下,每一層的執行時間是獨立的,我們可以通過將網絡中所有層的運行時間相加來得到網絡延遲。通過估計在目標設備上執行不同輸入和輸出通道寬度的卷積層的延遲,我們首先構建一個look-up表。然後我們可以從這個look-up表中計算得到構建網絡的延遲。在相同延遲下利用MetaPruning得到的剪枝網絡可以獲得顯著的更高精度。延遲約束下的性能比較
  4. 剪枝網絡可視化:(1)當向下採樣以步長爲2的深度卷積方式進行時,需要使用更多的通道數量來攜帶信息,因此MetaPruning自動學會在下采樣過程中保存更多的通道。(2)MetaPruning方法自動學會在靠後的階段中修剪較少的shortcut通道數。可視化效果
  5. 權重預測的效果:元學習中的權重預測機制對不同剪枝結構的權重進行了有效的去相關處理,從而使PruningNet獲得更高的精度。權重預測效果展示

五、結論

本文我們提出了MetaPruning用作通道剪枝,其具有以下優點:(1)與統一剪枝基線以及最先進的通道剪枝方法(傳統的和AutoML)相比,其具有更高的精度;(2)在不引入額外超參數的情況下,它可以靈活地針對不同的約束條件進行優化;(3)像ResNet這樣的架構可以有效地被處理;(4)整個過程是非常高效的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章