CVPR 2019 | SP_相似性保存知識蒸餾

CVPR 2019 | Similarity-Preserving Knowledge Distillation

1.保持相似性知識蒸餾(SPKD)

在這篇論文中,作者提出了一種新的知識蒸餾形式,該方法是作者觀察到相似語義的輸入往往會使得神經網絡輸出相似的激活模式這一現象啓發得到的。該知識蒸餾方法被稱爲保持相似性知識蒸餾(SPKD),該方法使得教師網絡中相似(不同)激活的輸入樣本對,能夠在學生網絡中產生相同(不同)的激活,從而指導學生網絡的學習。下圖展示了整個學習的過程,從圖中可以看出,在每個batch的訓練中,給定當前batch大小爲b,我們從輸入樣本中計算成對的相似矩陣,最終通過學生網絡和教師網絡之間的相似矩陣進行損失計算,指導學習。該方法整體上和PKT那種基於概率分佈的網絡有些類似,但我個人認爲PKT的方式要優於本文,基於概率分佈的方式更符合網絡訓練的模式。

2.相似性編碼矩陣

本文提出的蒸餾學習方案的前提是,假設,如果兩個輸入在教師網絡中產生高度相似的激活,那麼引導學生網絡走向一個學習方向,該方向也會導致這兩個輸入在學生中產生高度相似的激活。相反,如果兩個輸入在老師身上產生不同的激活,我們也希望這些輸入在學生身上產生不同的激活。如下圖所示:在CIFAR-10數據集中,由經過訓練的WideResNet-16-1和WideResNet-40-2網絡生成的激活相似矩陣G (Eq. 2)。每一列代表一個單獨的批處理,每個軸上按ground truth類分組輸入(批處理大小= 128)。較亮的顏色表示較高的相似值。從下圖中可以證明上述假設,對於同一類的輸入,激活大部分是相似的,而對於跨不同類的輸入,則是不同的。

如下公式和代碼所示,本文利用L2norm的方式來編碼相似矩陣。這進一步說明了本文的編碼策略其實是粗糙的,整體來說本文的思想是好的,但編碼策略略顯不足。


def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)

        G_s = torch.mm(f_s, torch.t(f_s))
        # G_s = G_s / G_s.norm(2)
        G_s = torch.nn.functional.normalize(G_s)
        G_t = torch.mm(f_t, torch.t(f_t))
        # G_t = G_t / G_t.norm(2)
        G_t = torch.nn.functional.normalize(G_t)

        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss
3.結果及分析

本文的特點:雖然本文的編碼策略有些粗糙。但本文的特點在於,於之前的蒸餾方法鼓勵學生模仿老師的直接輸出不同。因爲它的目的是保存輸入樣本的成對激活相似性,而不是教師的表現空間。在保留相似性知識蒸餾法中,要求學習後的學生空間中很好地保留教師空間中的成對相似性,而不要求學生能夠表達教師的直接輸出。當然這種策略從下方的結果中也可以看出是有效的,PKT算法也是利用類似成對的損失進行蒸餾,只不過將編碼換成了概率分佈的方式~

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章