神經網絡壓縮的方式與實驗

前言

二年級我在參加全國大學生集成電路創新創業大賽的時候,有幸見到了將CNN網絡Deploy到FPGA的設計,在這之後我便一直想完成該設計。寫下本文的時間是2020年4月份,三年級寒假剛開始,我便爲了完成這項工作開始從頭學起Machine Learning的理論基礎,並且在瞭解過一些開源的Verilog-CNN的項目之後,一直存在着一些疑惑,但由於開發FPGA的設備都在校園,所以一直沒有機會實踐證實。

深度學習已被證明在包括圖像分類(Image Classification),目標檢測(Object Detection),自然語言處理(Natural Language Processing)等任務上能夠取得相當不錯的效果。現如今,大量的應用程序都配備了與之相關的深度學習算法,而人們使用這些應用程序甚至無法察覺到這種技術的存在。

目前看到過的將CNN網絡Deploy到FPGA上的優秀項目是:CNN-FPGA

該項目解答了兩個我困惑的問題:

  • 複雜的CNN網絡具有如此多的參數、學習板卡的LUT資源是否足夠

    答案是不夠的,只能實現相當小的網絡,該項目的輸入圖像大小爲28 * 28的RGB圖像,但好在其使用的是標準的CNN架構,還有優化的可能,這個過程我們叫做網絡壓縮(Network Compression)。

  • 如何處理FPGA不擅長的浮點數運算

    答案是將Bias和Weight轉化成有符號的定點數,這個過程叫做網絡量化(Network Quantization),屬於網絡壓縮衆多方法的一類。

這無疑說明了網絡壓縮是完成該設計必不可少的一步,本文主要記錄我在網絡壓縮(Network Compression)的主要四種方案的實踐經歷:

  • Knowledge Distillation:讓小的Model藉由觀察大Model的行爲讓自己學習的更好。
  • Network Pruning:將已經學習好的大Model做裁剪,再訓練裁減後的Model。
  • Weight Quantization:用更好的方案表示Model中的參數,以此來降低運算量和存儲。
  • Design Architecture:將原始的Layer用更小的參數來表現,比如深度可分離卷積的架構。

正文

本次實踐所使用的數據集是李宏毅老師2020機器學習作業三的數據集 food-11,相關程式的編寫也圍繞着作業三展開,即對作業三的模型進行壓縮。

網絡壓縮理論:點擊前往

Kaggle地址:點擊前往

作業說明:點擊前往

原始的網絡結構如下圖:

該圖是一個標準的CNN架構,如果因爲網站圖牀壓縮文件導致圖片失真, 可以點擊這裏查看原圖:

40個epoch之後,訓練集的準確率收斂在0.9左右,驗證集的準確率收斂在0.7左右,表現優秀。

Solutions Flops Param Train Acc Val Acc epochs
Standard CNN 1100.19M 12.833M 0.9 0.7 40

然後本次實驗的目的是使用網絡壓縮的方案,將上文中的網絡進行優化,並觀察優化過後的模型表現如何。

在下面的文章中,我會首先嚐試用Design Architecture的辦法重新設計Model,其次該Model還可以用作Knowledge DistillationNetwork Pruning,最後我們做Weight Quantization

Design Architecture

在本次設計中,我採用了深度可分離卷積的架構,關於其實現原理與理論不做贅述。它的核心思想是將原來的Convolution操作轉化成depthwise(DW)和pointwise(PW)兩個操作。DW是指每一個filter只對一個channel進行卷積,PW是將DW操作後的feature map進行加權組合。總的來說,我們需要將原來的卷積單元進行改造即可,Pytorch中爲了實現卷積的DW操作,爲Conv層提供了group參數。

將原來的網絡進行深度可分離卷積的重構之後,結構如下圖

同樣,你可以點擊這裏查看原圖。

Solutions Flops Param Train Acc Val Acc epochs
Standard CNN 1100.19M 12.833M 0.9 0.7 40
Design Architecture 31.21M 256.78K 0.85 0.6 100

雖然收斂的速度變慢了很多,並且在Training Set 以及Validation Set上的準確度有所下降,但是計算量和參數量都有了顯著的下降。

Knowledge Distillation

也許我們設計的網絡在Validation Set上的表現不夠好是因爲訓練的數據集有一些噪聲。Knowledge Distillation,也叫知識蒸餾能夠解決該問題。核心思想是準備一個pretrained的好Model,然後讓我們的網絡學習該Model的輸出,這樣我們的網絡不僅能學習圖片中是何種食物,也能學習到概率分佈。

首先面對的問題是pretrained的Model從哪裏來,在這裏我踩了很多坑,嘗試使用了從頭開始訓練的resnet18、VGG16、GoogleNet等網絡,在經過很多次的epoch之後驗證集上的準確率仍然只有0.7,最後使用Transfer Learning的方式,移植了resnet34網絡,並且爲了省時間只訓練了全連接層,僅僅5個epoch,就達到了0.86的準確率,最終迭代了20次後在訓練集與驗證集上的準確率都達到了0.9。

Knowledge Distillation實現的方式是將訓練數據預先丟進我們有的pretrained的好Model(以後簡稱TeacherNet)即上文中提到的resnet34網絡,將它得到的輸出送給我們的網絡,不過還需要重新定義我們的Loss函數,既要考慮到StudentNet的輸出,也要考慮到TeacherNet的輸出,在李宏毅老師的投影片裏,Loss的定義如下:
Loss=αT2×KL(Teacher’s LogitsTStudent’s LogitsT)+(1α)(原本的Loss) Loss = \alpha T^2 \times KL(\frac{\text{Teacher's Logits}}{T} || \frac{\text{Student's Logits}}{T}) + (1-\alpha)(\text{原本的Loss})

Solutions Flops Param Train Acc Val Acc epochs
Standard CNN 1100.19M 12.833M 0.9 0.7 40
Design Architecture 31.21M 256.78K 0.85 0.6 100
Knowledge Distillation 31.21M 256.78K 0.82 0.8 80

根據測試的結果可以看到深度可分離卷積架構學習TeacherNet之後,準確率有了很大的提高,並且他們所需要的參數和計算量相同。

Network Pruning

所謂的網絡剪枝,是給我們的網絡瘦身,去除掉一些沒有用的節點。在這裏我使用的是neural pruning的方式,刪除掉不重要的節點。因爲採用Weight pruning的方式,如果直接刪除無用的weight會破壞矩陣導致不能使用GPU進行運算加速,或者將無用的weight置0,這樣並不會實際意義上節省空間。

第一個要解決的問題是,如何衡量節點的重要性。根據李宏毅老師的作業三、有一個簡單的方法:batchnorm layer的𝛾因子來決定neuron的重要性。 (By Paper Network Slimming)

然後被剪枝的網絡是Design Architecture中設計的深度可分離卷積架構,其在驗證集的準確率是0.58左右、將網絡的weight按照rate的比例來剪枝,得到如下結果:

rate train_acc valid_acc epoch
0.9500 0.7110 0.5808 0
0.9500 0.7122 0.5808 1
0.9500 0.7161 0.5828 2
0.9500 0.7121 0.5802 3
0.9500 0.7111 0.5787 4
0.9025 0.6726 0.5545 0
0.9025 0.6713 0.5586 1
0.9025 0.6627 0.5464 2
0.9025 0.6690 0.5516 3
0.9025 0.6745 0.5560 4
0.8574 0.6200 0.5105 0
0.8574 0.6216 0.5117 1
0.8574 0.6199 0.5163 2
0.8574 0.6134 0.5137 3
0.8574 0.6223 0.5015 4
0.8145 0.5771 0.4895 0
0.8145 0.5735 0.4825 1
0.8145 0.5749 0.4825 2
0.8145 0.5781 0.4863 3
0.8145 0.5813 0.4831 4
0.7738 0.5461 0.4685 0
0.7738 0.5446 0.4691 1
0.7738 0.5481 0.4636 2
0.7738 0.5426 0.4633 3
0.7738 0.5478 0.4589 4

在去除掉網絡中20%的節點後、準確率下降了10%。雖然準確率看起來蠻低的,但要是被剪枝的網絡是經過上文知識蒸餾出來的網絡,準確率應該還能提告0.2個百分點。

經過剪枝後的參數量和浮點運算量如下:

Solutions Flops Param Train Acc Val Acc epochs
Standard CNN 1100.19M 12.833M 0.9 0.7 40
Design Architecture 31.21M 256.78K 0.85 0.6 100
Knowledge Distillation 31.21M 256.78K 0.82 0.8 80
Network Pruning 50.80M 171.60K 0.55 0.46 5

可是不知道爲什麼、浮點運算居然變多了!?

Weight Quantization

權重量化的方式有很多種,這裏只嘗試了將原本的float32類型用foat16、或者8bit的數據來進行量化,因爲該方式對於FPGA來說實現更加方便。

對於16bit、Pytorch能很方便的將32bit的float轉化成16bit的float,這裏再給出float32量化到8bit的公式。
在這裏插入圖片描述

Solutions Flops Param Train Acc Val Acc epochs
Standard CNN 1100.19M 12.833M 0.9 0.7 40
Design Architecture 31.21M 256.78K 0.85 0.6 100
Knowledge Distillation 31.21M 256.78K 0.82 0.8 80
Network Pruning 50.80M 171.60K 0.55 0.46 5
Quatization Float16 62.43M 256.78K 0.80 0.59 10
Quatization 8bit 31.21M 256.78K 0.80 0.59 10

保存下來8bit的權重參數文件大小比16bit的文件少佔用了一半的存儲,但是準確率卻沒有下降。

以上,就是本次實驗的內容,爲了完成該實驗,查閱了很多資料,跑了很長時間的數據,但對一些知識理解還是比較淺的,希望以後的自己能修飾修飾。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章