基於生成對抗的結構剪枝——Generative Adversarial Learning

"Towards Optimal Structured CNN Pruning via Generative Adversarial Learning" 這篇文章提出了非常新穎的結構剪枝策略:基於生成對抗的思路,將剪枝網絡設置爲生成器(Generator),其輸出特徵作爲Fake,並設置Soft Mask門控一些異質結構的輸出(如通道、分支、網絡層或模塊等);將預訓練模型設置爲Baseline,Baseline的輸出特徵作爲Real;再引入判別器(Discriminator)與正則化約束,一方面對齊生成器與Baseline的輸出,另一方面驅使生成器中的Soft Mask稀疏化(mask value介於0到1之間),最終達到低精度損失的結構剪枝的目的。基於GAL(Generative Adversarial Learning)的剪枝策略總體如下圖所示:

基於GAL的剪枝策略能夠克服現有結構剪枝技術的不足,不足之處具體表現在:1)相對耗時的多階段優化,迭代執行剪枝與fine-tuning;2)通常採用hard pruning mask,不夠鬆弛、較難優化學習;3)訓練或正則化過程依賴於樣本標註。針對這些不足,基於GAL的剪枝策略,首先通過Baseline與Discriminator的輔助作用,能夠在對抗學習過程中避免樣本標註的使用;其次,Soft Pruning Mask的使用,使得正則化過程變得更加鬆弛、更容易學習收斂;另外,對抗訓練與正則化過程是端到端的、非逐層實施的,並且能夠自動完成最優網絡結構探索、以及類似於知識蒸餾的特徵遷移(Baseline -> Generator)。基於GAL的剪枝策略涉及的符號標記如下,fb(x)與fg(x)分別表示Baseline與Generator輸出的特徵矢量(非Softmax層):

通過Soft Mask(標記爲m)的稀疏化,可以剪除包括通道、分支或Block等在內的基本結構。爲了確保剪枝之後,剪枝模型仍能獲得與Baseline相接近的推理精度,基於GAL的剪枝策略首先對Soft Mask施加L1正則化;其次引入判別器(Discriminator),與剪枝模型(Generator)構成了生成對抗學習,在對抗學習過程中將Baseline輸出的特徵矢量作爲監督信息,用以對齊Baseline與剪枝模型的特徵輸出。在對抗學習與正則化過程中,Baseline的參數固定、不需要更新,而剪枝模型參數WG、Soft Mask以及判別器參數WD需要更新,具體的優化問題如下:

上式中,表示判別器損失,用來引導判別器提升鑑別能力,Baseline的輸出表示Real,而剪枝模型(Generator)的輸出表示Fake,當二者輸出真假難辨時,達到對齊到輸出特徵的目的:

式(1)中數據損失用來進一步對齊Baseline與Generator的輸出特徵,具體表示爲Baseline與Generator輸出特徵之間的MSE損失:

式(1)中正則化損失主要分爲三部分,分別表示對WGmWD的正則化約束:

上式中R(WG)表示一般的weight decay,且通常是L2正則化;R(m)表示對Soft Mask的L1正則化;R(WD)表示對判別器的正則化約束,用以防止判別器主導訓練學習,並且主要採用對抗正則化,促進判別器與生成器之間的對抗競爭:

如果直接採用SGD求解式(1)的優化問題,Soft Mask較難稀疏化(零值較難獲得)。此時通常需要設置一個閾值,並將低於閾值的Mask Value或Scaling Factor置零,達到剪枝的目的,然而剪枝網絡的推理精度會明顯低於Baseline。爲解決該問題,文章引入FISTA方法用以求解式(1)的優化問題,具體如下(i=j=1):

優化策略主要包含兩個交替執行的階段:1)第一個階段固定Gm,通過對抗訓練更新判別器D,損失函數包含對抗損失與對抗正則項;2)第二階段固定D,更新生成器G與Soft Mask,損失函數包含對抗損失中的fg相關項、fb與fg的MSE損失以及Gm的正則項。最終,完成Soft Mask的稀疏化之後,便可以按照門控方式,完成channel、branch或block的規整剪枝。

實驗結果具體見文章的實驗部分,值得注意的是:1)對判別器(Discriminator)施加正則化約束時,對抗正則化(Adversarial Regularization)相比於L1、L2正則化,能夠起到更好的正則化效果,即達到更高稀疏度的同時,網絡的推理精度也更高;2)相比於不使用GAN的特徵遷移學習,GAL能夠起到更好的監督效果,並且GAL是label-free的,能夠更好地激勵Generator輸出與Baseline相接近的特徵。

Paper地址:https://arxiv.org/abs/1903.09291

GitHub地址(PyTorch):https://github.com/ShaohuiLin/GAL

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章