核心思想
本文提出一種基於半監督訓練的小樣本分類算法。所謂半監督就是在訓練集中即包括帶有標籤的圖片,也包含不帶有標籤的圖片,作者認爲人類在學習物品分類時,也會觀察到許多非目標類別的物體,這種學習方式更加接近實際使用需求,並且可以提高算法的泛化能力。本文以原型網絡(Prototypical Network)作爲baseline,在此基礎上提出了三種改進型,以滿足無監督訓練的需要。與原型網絡(Prototypical Network)相似,訓練集也劃分爲支持集與查詢集,不同的是在支持集中還包含有不帶有標籤的圖片,這些圖片的類別既有支持集中需要訓練的類別,也有原本支持集中並不存在的類別(作爲干擾項),測試集情況基本類似。
下面主要介紹本文提出的原型網絡的三種改進型,關於原型網絡的介紹可以查看之前的文章(點擊此處查看),此處不再重複介紹了。
- 帶有軟k均值的原型網絡
對於無標籤的圖片分類而言,最直接的方法就是k均值(k-means)算法了,本文選擇軟k均值(soft k-means)算法,是因爲硬k均值算法不可微。首先使用正常的原型網絡得到帶有標籤數據集對應的原型作爲聚類中心,然後根據無標籤樣例與聚類中心之間的歐氏距離得到一個局部位置;最後通過將無標籤的樣例整合到各個原型中,得到更新後的原型,計算過程如下:
這個改進可以理解爲先利用有標籤樣例得到原型,再將無標籤的樣例按照就近原則劃分給各個原型,並修正原型的位置。 - 帶有一個干擾項中心和軟k均值的原型網絡
上個方案存在一個重要的問題,就是在無標籤數據集中包含部分干擾項,其類別與有標籤數據集中的類別不同,如果簡單的將其劃分給現有的原型,會導致數據污染。爲了避免這一情況,作者爲干擾性提供了一個初始值爲0的聚類中心,並引入一個長度尺度參數用於表示與聚類中心之間距離的變化,具體計算過程如下:
整合的方式與第一種改進型相同,這一方案相當於爲所有的干擾項提供一個新的聚類中心,以避免其干擾之前的原型。 - 帶有掩碼和軟k均值的原型網絡
第二個方案只把所有的干擾項分爲一類,這顯然太過簡單了,不符合實際情況。作者提出的第三個改進方案並不是將干擾項都劃分到具備高方差多類別的原型裏去,而是採用一種掩碼的方式,使得與原型距離遠的樣本被掩蓋掉,距離近的則更少被掩蓋。首先計算未標記的樣例與原型之間的規範化距離
然後,通過訓練的方式得到兩個參數,分別是軟閾值和斜率,這兩個參數是通過一個多層感知機得到的,輸入值是關於規範化距離的統計學參數
式中skew表示傾斜度,kurt表示峯值係數。最後利用學習到的參數和實現掩碼
式中表示sigmoid函數,因爲sigmoid函數當輸入越小,輸出越接近於0;輸入越大,輸出越接近於1;輸入爲0時,輸出爲0.5。所以當規範化距離超過閾值時,其掩碼值會接近於0(相當於被掩蓋了);當規範化距離小於閾值時,其掩碼值則更接近於1(相當於被保留了)。
實現過程
網絡結構
與原型網絡相同
損失函數
正確分類概率的負對數均值
創新點
- 提出一種半監督訓練的小樣本分類算法,提高算法的泛化能力
- 在原型網絡的基礎上提出三種改進型
算法評價
本文是原型網絡的改進型,也是看到的第一篇採用半監督學習的元學習算法。整體而言還是具有一些有趣的創新點的,尤其是第三種改進型中應用的掩碼方法,是之前沒有見過的。但由於訓練集中包含無標籤樣例,因此在訓練量總數相同的情況下,其準確率還是低於有監督學習的。而且根據作者的實驗結果,也無法確定三種改進型的優劣(在不同測試條件下,表現並不統一),因此對於小樣本分類任務的改進作用有限。
如果大家對於深度學習與計算機視覺領域感興趣,希望獲得更多的知識分享與最新的論文解讀,歡迎關注我的個人公衆號“深視”。