CVPR2019 | 關係型知識蒸餾法

CVPR 2019 | Relational Knowledge Distillation
https://github.com/HobbitLong/RepDistiller

1.蒸餾學習

由於大模型的擬合能力強,但計算效率低耗時大,而小模型的擬合能力弱,計算效率高。基於該特徵,蒸餾學習的目的是讓小模型學習大模型的擬合能力,在不改變計算效率的前提下提升小模型的擬合能力。如下圖所示,傳統的蒸餾學習(KD),直接根據小模型和大模型的輸出值進行損失計算,使得小模型的輸出能夠靠近大模型的輸出,以此來模型大模型的擬合能力。但這種方法很顯然存在直觀上的缺點,小模型只能學習大模型的輸出表現,無法真正學習到大模型的結構信息。

傳統的蒸餾學習的損失函數如下,其中ft表示教師模型的輸出,fs表示學生模型的輸出,L表示計算兩者之間的距離。從損失函數中可以直觀的看出,整個蒸餾學習過程中,小模型學習的就是大模型的輸出表現,這種單點學習的方法是粗暴的,不具有結構性的。

2.關係型蒸餾學習

爲了使得小模型能夠更好的學習到大模型的結構信息,本文提出了關係型蒸餾學習法(RKD),如下圖所示,RKD算法的核心是以多個教師模型的輸出爲結構單元,取代傳統蒸餾學習中以單個教師模型輸出爲檢測的方式,利用多輸出組合成結構單元,更能體現出教師模型的結構化特徵,使得學生模型得到更好的指導。

關係型蒸餾學習的損失函數如下,其中t1,t2…tn表示教師模型的多個輸出,s1,s2…sn表示學生模型的多個輸出,L表示計算兩者之間的距離。與傳統的蒸餾學習不同,關係型蒸餾學習的損失函數中還有一個構件結構信息的函數。可以使得學生模型學到教師模型中更加高效的信息表徵能力。本文提出了兩種表徵結構信息的損失:距離蒸餾損失和角度蒸餾損失。

3.距離蒸餾損失(Distance-wise distillation loss)

基於距離的蒸餾損失的公式如下圖所示,本文通過對每個batch中的樣本進行兩兩距離計算,最終形成一個batch*batch大小的關係型結構輸出。最終學生模型通過學習教師模型的結構輸出,實現蒸餾學習。整體的代碼如下所示。

 # RKD distance loss
with torch.no_grad():
    t_d = self.pdist(teacher, squared=False)
    mean_td = t_d[t_d > 0].mean()
    t_d = t_d / mean_td

d = self.pdist(student, squared=False)
mean_d = d[d > 0].mean()
d = d / mean_d
print("d:{},t_d:{}".format(d.size(),t_d.size()))
loss_d = F.smooth_l1_loss(d, t_d)

def pdist(e, squared=False, eps=1e-12):
	e_square = e.pow(2).sum(dim=1)
	prod = e @ e.t()
	res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

	if not squared:
		res = res.sqrt()

	res = res.clone()
	res[range(len(e)), range(len(e))] = 0

	# print("e_square:{}".format(e_square.size()))
	# print("e.t:{},prod:{}".format(e.t().size(),prod.size()))
	# print("unsqueeze(1):{},unsqueeze(0):{}".format(e_square.unsqueeze(1).size(),e_square.unsqueeze(0).size()))
	# print("res:{},len(e):{}".format(res.size(),len(e)))

	return res

4.角度蒸餾損失(Angle-wise distillation loss)

基於角度的蒸餾損失的公式如下圖所示,本文通過對每個batch中的樣本三三樣本,計算兩個角度,最終形成一個batchbatchbatch大小的關係型結構輸出。最終學生模型通過學習教師模型的結構輸出,實現蒸餾學習。整體的代碼如下所示。

# RKD Angle loss
with torch.no_grad():
	td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
	norm_td = F.normalize(td, p=2, dim=2)
	t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
	print("unsqueeze(0):{},unsqueeze(1):{}".format(teacher.unsqueeze(0).size(),teacher.unsqueeze(1).size()))
	print("td:{},norm_td:{},norm_td.transpose(1, 2):{},t_angle:{}".format(td.size(),norm_td.size(),norm_td.transpose(1, 2).size(),t_angle.size()))

sd = (student.unsqueeze(0) - student.unsqueeze(1))
norm_sd = F.normalize(sd, p=2, dim=2)
s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
loss_a = F.smooth_l1_loss(s_angle, t_angle)

5.關係型蒸餾效果

本文提出的關係型蒸餾學習方案在各個公開數據集上都證明了有效性,相較於傳統的蒸餾學習方案,本文通過結構化輸出的監督,獲取了更好的監督學習結果。

RKD_LOSS整體代碼請關注公衆號【CV煉丹猿】,後臺回覆RKD獲取。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章