深入BBN,如何解決長尾數據分佈的同時兼顧表示學習

1. 問題引入

  • 本次要記錄的論文是,CVPR2020 的 " BBN: Bilateral-Branch Network with Cumulative Learning for Long-Tailed Visual Recognition "。該文旨在解決長尾數據分佈的同時兼顧表示學習。
  • 長尾數據是視覺認知任務如:圖像分類、目標檢測中影響實驗結果的主要問題之一。長尾數據分佈的意思是:數據集中某幾個類別佔據了大部分的數據,而剩餘的類別各自的數據很少。
  • 舉個例子,想用一個1000張圖像的數據集訓練一個分類模型,數據集包含三個類別:人、狗、貓。其中,人900張,狗80張,貓20張。那麼狗、貓就屬於長尾數據了。對於深度學習這種見多識廣的技術來說,見得少了的東西識別起來當然困難,最後訓練出來的分類模型幾乎看到什麼都判別爲人,錯誤率特別高。

2. 相關工作

當然,長尾數據分佈的問題並不是第一次提出來,也有很多方法對其進行了緩解,這些方法統稱爲Re-balancing Strategies,具體如下:

  • Class Re-balancing Strategies。這裏面又分爲兩種方法,Over-Sampling和Under-sampling。
    • Over-Sampling,在訓練的過程中多次採樣數據集中數據量佔比量較小的數據,使得這些數據在訓練時被多次用到,從而緩解長尾數據分佈的問題。
    • Under-Sampling,在訓練中拋棄數據量佔比較高的數據,從而達到各個類別數據量的平衡,以緩解長尾數據分佈的問題。
  • Re-weighting Strategies。在訓練模型的過程中,增加損失中長尾數據的權重,有點類似於Boost的方法。但該方法無法處理實際生活中的數據,一旦長尾數據分佈很嚴重,該方法還容易引起優化的問題。
  • Two-stage Fine-tuning Strategies。這個方法將訓練分爲兩個階段。第一階段像往常一樣正常訓練,第二個階段使用較小的學習率以Re-balancing的方式微調網絡。
  • Mix-up。這是數據增強領域的一種方法,通過在融合多張圖像到同一個圖像中增廣數據集用於訓練;還有的方法在圖像的流形特徵上進行增廣。

3. 進一步探索

那麼,爲什麼Re-balancing Strategies的方法能夠處理長尾數據分佈的問題呢?作者做了實驗,作者假設:

  • 基於深度學習的分類模型分爲2部分,特徵抽取器(feature extractor) 和 分類層(the classifier)。
  • 那麼,分類模型學的過程也可以分爲2部分,表示學習(Representation Learning) 和 分類學習(Classifier Learning)。

作者認爲:Class re-balancing通過改變原始訓練集的分佈,使之接近測試集的分佈,從而使得分類模型將更多的關注放在長尾數據身上,提升了分類的準確性。但作者又認爲由於Class re-balancing改變了原始訓練集的分佈,導致Representation Learning的過程受到破壞,故對圖像中原始特徵的提取也會收到破壞。爲此,作者做了如下實驗,如下圖所示:
在這裏插入圖片描述

  • 作者將訓練分爲2個階段,former stage和latter stage,兩個階段分別採取3種訓練方式,一共訓練出9個模型。上圖中左右兩個矩陣是在兩個數據集上訓練出來模型的精度,我們以左邊的矩陣爲例解釋。
  • 橫座標代表Representation Learning採用的學習策略,縱座標代表Classifier Learning採用的學習策略。CE代表傳統的交叉熵方法訓練,RW即上述的Re-weighting Strategies,而RS代表上述的Re-sampling方法。
  • 當垂直的看矩陣的每一列時,表明該列所有分類模型先用(CE, RW, RS)三個方法之一正常訓練出一個模型,模型訓練完畢之後,固定feature extractor參數不變,重頭開始訓練the Classifier。通過這個方法固定Representation Learning,可以判別出哪種方法訓練出來的分類器具有較好的效果。根據結果,可以看出RS,RW的方法明顯好於CE。
  • 當水平看矩陣的每一行時,表明該行所有分類模型先用(CE, RW, RS)三個方法之一正常訓練出一個模型,模型訓練完畢之後,固定the Classifier參數不變,重頭開始訓練feature extractor。通過這個方法固定the Classifier,可以判別出哪種方法訓練出來的特徵抽取器具有較好的效果。根據結果,可以看出CE的方法具有最低的錯誤率。

最後,作者就可以得到結論:傳統的訓練方式有助於Feature Extractor的學習,而RW、RS有助於分類器的學習。很自然,是否可以結合兩者的優點呢?上述的Two-stage方法就是幹這個事兒的,但是需要2個階段,是否可以設計一個One-stage的端到端網絡實現這一點呢?

4. BBN

作者提出的模型稱之爲BBN,模型由三個部分組成:Conventional Learning Branch,Re-Balancing Branch和Cumulative Learning,具體結構如下圖:
在這裏插入圖片描述

  • Conventional Learning Branch,在這個分支中,每個訓練epoch的數據都是等比例的從原始訓練集中採樣,從而保持原始數據的分佈,有利於模型的Representation Learning。

  • Re-balancing Branch,這個分支旨在緩解長尾數據分佈並提升分類的準確性,這個分支的數據通過Reversed Sampler獲取,後面會詳細解釋。通過這個Reversed Sampler,數據集中類別樣本數量越多的,被採樣的機率越小。

  • 上述兩個分支的模型結構一摸一樣,且共享了除最後一個殘差塊的所有參數。爲什麼要共享參數了?作者給出了兩個原因:

    • 有利於Representation Learning的學習,但並沒有相關的消融實驗。而且一旦這個模型沒有參數共享,那麼它其實就是散裝了兩個網絡而已,不知道這裏的共享是爲了有效而共享,還是爲了共享而共享。當然這並不能說明該模型就不好了,因爲模型的兩點應該在於Cumulative Learning模塊。
    • 降低運算資源,這個很直接,作者在文中表示所有的模型都在一塊1080-ti上完成,我表示很理解。畢竟,作爲同行,我也缺運算資源。
  • 在講Cumulative Learning模塊之前,我們先具體解釋一下Reversed Sampler的構建。

    • 定義NiN_i爲類別ii所包含的樣本個數,NmaxN_{max}爲所有類別中包含最多樣本類別的樣本個數

    • 構建Reversed Sampler有3個子過程

      • 首先,根據樣本個數計算類別ii的採樣概率PiP_i

        Pi=wij=1cwjP_i = \frac{w_i}{\sum_{j=1}^cw_j},其中wi=NmaxNiw_i = \frac{N_{max}}{N_i}

        通過計算wiw_i可以發現,類別包含樣本個數越小,wiw_i值越大。爲什麼要計算PiP_i呢?那肯定是爲了歸一化啦,這樣所有樣本採樣的概率和爲1。

      • 其次,根據PiP_i隨機選擇一個類別

      • 最後,均勻的從被選擇到的類別採集圖像數據獲得mini-batch

  • 最後,我們來深入理解一下Cumulative Learning模塊。其主要目的是爲了:通過權衡兩個分支提取到的特徵的權重來控制分類損失,從而控制模型在學習中的關注點從Representation Learning逐漸轉移到長尾數據分佈問題上。

    • 在模型的前面兩個分支中,分別可得到Conventional Learning Branch的特徵向量fcf_c,以及Re-balancing Branch的特徵向量frf_r。那麼如何去合理的融合這兩個特徵呢?

    • 定義α\alpha爲平衡因子,WcW_c爲Conventional Learning Branch的分類器,WrW_r爲Re-balancing Branch的分類器,那麼兩個分支的特徵可融合爲:

      z=αWcTfc+(1α)WrTfrz = \alpha W_c^Tf_c + (1 - \alpha)W_r^Tf_r

    • 那麼,這個α\alpha如何計算呢?定義整個總訓練epoch個數爲TmaxT_{max},當前訓練的epoch爲TT

      α=1(TTmax)2\alpha = 1 - (\frac{T}{T_{max}})^2

      通過式子可以看出,隨着TT逐漸增大,α\alpha逐漸減少。再結合上面zz的計算,可以發現總的融合特徵zz首先將注意力放在fcf_c特徵,強調了Representation Learning的重要性;隨着epoch的提升,注意力逐漸轉移到frf_r上,模型開始處理長尾數據分佈的問題,從而做到Representation和長尾數據分佈兩者兼顧。

    • 在得到了zz之後,和平常的方法類似,需要先做一個softmaxsoftmax多分類的平滑:

      pi^=ezij=1Cezj\hat{p_i} = \frac{e^{z_i}}{\sum_{j=1}^Ce^{z_j}}

      再通過交叉熵結合兩個特徵向量各自的損失,即兩個特徵向量有各自的Ground Truth標籤,所以損失也分別做並相加

      L=αE(p^,yc)+(1α)E(p^,yr)L = \alpha E(\hat{p}, y_c) + (1 - \alpha) E(\hat{p}, y_r)

至此,本文所提出的模型講解結束,總結一下模型的用意:

  • Conventional Learning Branch正常採樣,保證Representation Learning
  • Re-balancing Branch配備創新點Reverse Sampler,針對長尾數據分佈問題
  • 此外,還設計了α\alpha參數,來自動平衡兩個分支的輸出特徵,旨在先保證Representation Learning,再逐漸將注意力放到長尾數據分佈的身上

5. 相關實驗

實驗配置

  • 作者上來先定義了長尾數據分佈的嚴重程度指標β=NmaxNmin\beta = \frac{N_{max}}{N_{min}},其中NmaxN_{max}指的是所有類別中最大的數據量,而NminN_{min}指的是所有類別中最小的數據量,可想而知β>=1\beta >= 1。文中採用β=10,50,100\beta = 10, 50, 100,增廣CIFAR數據集爲長尾數據分佈的數據集。
  • 本文主要對比了以下幾類方法
    • Baseline,focal loss
    • Two-stage的方法,CE-DRW, CE-DRS
    • state-of-art的在不平衡數據集上取得高分的方法,如LDAM, CB-Focal
  • 在測試的時候,α=0.5\alpha = 0.5,因爲2種特徵都很重要。

實驗一
在這裏插入圖片描述

  • 上述主要三行方法分別爲:MixUp、Two-Stage、State-of-art
  • 可以看出,本文提出的方法均取得了最佳;但還有一個現象,Two-stage也去了較好的方法。這也很容易理解,因爲本文的方法是Two-stage的進一步提升。

實驗二

在這裏插入圖片描述

  • 該實驗爲了證明Reversed Sample的有效性。其中Uniform sample就是傳統的採樣方式,根據原始數據分佈的比例採樣;Balanced Sampler是等概率採樣,類似於Re-balanced方法。
  • 可以看出,本文提出的方法最好,Balanced Sampler的方法比Uniform sampler好也能體現一個隱藏信息:深度學習網絡在大量的數據下,對於頻率出現較高的圖像識別能力早已飽和。所以即使是等概率採樣,減少了高頻圖像出現的概率,又一定程度破壞了Representing Learning,效果都比傳統訓練方法的好。

實驗三
在這裏插入圖片描述

  • 該實驗爲了證明當前α\alpha策略的有效性。爲此,採用了其他兩類α\alpha選擇策略
    • Progress-relevant strategies,即和在訓練過程中修改α\alpha的方法,如:Linear Decay, Cosine Decay
    • Progress-irrelevant strategies,即和訓練過程無關的方法,如:Equal weight, β\beta-distribution中採樣
  • 首先,可以發現在訓練過程相關的動態修改方法明顯優於不相關的。
  • 其次,Parabolic increment方法是效果最差的,這也證明了作者的觀點:應該先將關注點放在Representation Learning,再逐漸遷移到原始數據的分佈。而Parabolic increment是反着來,所以效果是最差的。

實驗四
在這裏插入圖片描述

  • 該實驗主要爲了證明BBN有着接近原始方法訓練的模型的Representation Learning的能力。實驗方法和之前9宮格矩陣的方法類似,首先正常訓練網絡,再固定backbone不變,重頭訓練分類器。可以看見BBN的兩個分支都有着中肯的錯誤率,至少比RW、RS要好。
  • 此外,這也說明了共享策略的有效性。

實驗五
在這裏插入圖片描述

  • 該實驗是爲了證明BBN訓練出來的分類器不存在偏愛性,對所有的類別一視同仁從而解決了長尾數據分佈的問題。橫座標是不同的類別,縱座標是l2norml2-norm可以反映分類器對某個類別的偏愛程度。
  • 可以看到,在所有的折線中,BBN是最爲平坦的。雖然RW、RS也都較爲平坦,但是作者還計算了不同方法不同類別l2norml2-norm的標準差,標準差越小的差別越小,而BBN是最小的。

至此,本文的論文理解分享結束,感謝業界前輩的貢獻,Respect!

寫在後面:一直以來都認爲自己讀論文的姿勢不對,故我開始在博客中記錄中我對論文的理解。但又一直感覺記錄論文博客的寫法也不對,思考良久,我覺得如何去理解論文、如何去記錄對論文的理解,也應該符合一篇論文創作的順序,故我一直在調整撰寫的方式。祝願有一天,我能做到更加輕鬆的閱讀論文,並把自己對論文的理解做到如數家珍的侃侃而談。如果大家有很好的論文閱讀姿勢,歡迎在評論區留言分享~~~

本文爲作者原創,轉載需註明出處!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章