ECCV2018 | PKT_概率知識蒸餾

ECCV2018 | Learning Deep Representations with Probabilistic Knowledge Transfer

https://github.com/passalis/probabilistic_kt

1.傳統知識蒸餾

最早的知識蒸餾方法專門針對分類任務進行設計,它們不能有效地用於其他特徵學習的任務。 在本文中,作者提出了一種通過匹配數據在特徵空間中的概率分佈進行知識蒸餾(PKL)。該方法除了性能超越現有的蒸餾技術外, 還可以克服它們的一些侷限性。包括:(1)可以實現直接轉移不同架構/維度層之間的知識。(2)現有的蒸餾技術通常會忽略教師特徵空間的幾何形狀,因爲它們僅使學生網絡學習教師網絡的輸出結果。而PKL算法能夠有效地將教師模型的特徵空間結構映射到學生的特徵空間中,從而提高學生模型的準確性。PKL算法示意圖如下所示。PKT技術克服了現有蒸餾方法的一些侷限性,通過匹配特徵空間中數據的概率分佈,從而實現知識蒸餾。

2.基於概率的知識蒸餾(PKT)

爲了使得學生模型能夠有效的學習教師模型的概率分佈。作者在訓練網絡的過程中,對每個batch中的數據樣本之間的成對交互進行建模,使得其可以描述相應特徵空間的幾何形狀。利用特徵空間中任意兩個數據點的聯合概率密度,對兩個數據點之間的距離進行概率分佈建模。通過最小化教師模型與學生模型的聯合密度概率估計的差異,實現概率分佈學習。
聯合概率密度函數公式:

從上述公式可以發現,最小化概率分佈並不需要用到標籤數據,因此PKT甚至可以用到無監督學習中。利用上述所說的聯合概率分佈進行知識蒸餾可以避免很多傳統蒸餾方法的缺點。但是,由於實際訓練中我們每個batch都是所有數據的隨機抽樣,使用全局數據是不現實的,基於此作者使用樣本的條件概率分佈代替聯合概率密度函數。
條件概率密度函數公式:

計算當前batch中數據兩兩之間的條件概率密度後,通過最小化教師模型的條件概率分佈和學生模型的條件概率分佈的KL散度,實現概率知識蒸餾。

3.計算概率分佈

如上述所示的條件概率分佈函數公式可知,要求數據間的條件概率分佈需要定義對應的核函數。常見的核函數有高斯核,具體公式如下所示,但由於高斯核中需要定義一個超參數,且該超參數對最終蒸餾結果會參數極大的影響。因此本文並沒有採用這種常見的核函數。

本文嘗試通過餘弦核函數進行條件概率估計。其公式如下所示,根據餘弦函數的定義可以更好的解釋本文提出的PKL蒸餾法體現出的架構和維度無關性。

def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
	# Normalize each vector by its norm
	output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
	output_net = output_net / (output_net_norm + eps)
	output_net[output_net != output_net] = 0

	target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
	target_net = target_net / (target_net_norm + eps)
	target_net[target_net != target_net] = 0

	# Calculate the cosine similarity
	model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
	target_similarity = torch.mm(target_net, target_net.transpose(0, 1))

	# Scale cosine similarity to 0..1
	model_similarity = (model_similarity + 1.0) / 2.0
	target_similarity = (target_similarity + 1.0) / 2.0

	# Transform them into probabilities
	model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
	target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)

	# Calculate the KL-divergence
	loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))
	
	return loss

這段代碼就是對上述公式的翻譯,代碼中的output_net代表了當前數據的學生模型輸出特徵圖,而target_net代表了當前數據的教師模型輸出特徵圖。正常情況下該特徵圖維度一般都爲:NCHW。根據上述代碼不論兩者的C的維度是多少,有或者HW的維度是多少,最終經過矩陣轉置相乘,都會變成一個N*N大小的相似性矩陣。通過相似性矩陣經過一系列計算,最終求得兩者的概率分佈,並進行概率學習。

4.結果展示

PKT基於概率的知識蒸餾應用到分類和目標檢測任務中,從下表的結果可以看出該方法的通用和有效性。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章