CVPR 2019 | VID_最大化互信息知識蒸餾

CVPR 2019 | Variational Information Distillation for Knowledge Transfer
https://github.com/qiu931110/RepDistiller

1.互信息

在這篇論文中,作者提出了一種新的知識蒸餾形式,該方法將知識蒸餾的最優性能定義爲最大化教師和學生網絡之間的互信息。那麼爲什麼通過最大化互信息可以使得蒸餾學習變得有效呢?首先作者對互信息做了如下定義:

如上述公式所述,互信息爲[教師模型的熵值] - [已知學生模型的條件下的教師模型熵值]。而我們又有如下常識:當學生模型已知,能夠使得教師模型的熵很小,這說明學生模型以及獲得了能夠恢復教師模型所需要的“壓縮”知識,間接說明了此時學生模型已經學習的很好了。而這種情況下也就是說明上述公式中的H(t|s)很小,從而使得互信息I(t;s)會很大。作者從這個角度解釋了爲什麼可以通過最大化互信息的方式來進行蒸餾學習。

2.蒸餾過程詳解

如下圖所示,由於p(t|s)難以計算,作者根據文獻The IM algorithm: a variational approach to information maximization. 2004.提出的IM算法,利用一個可變高斯q(t|s)來模擬p(t|s),下述公式中的大於等於操作用到了KL散度的非負性。由於蒸餾過程中H(t)和需要學習的學生模型參數無關,因此最大化互信息就轉換爲最大化可變高斯分佈的問題。

如下公式所示,作者利用一個均值,方差可學習的高斯分佈來模擬上述的q(t|s)。

如下代碼所示,作者通過一個卷積小網絡來模擬可變均值,並加上relu操作增強可變均值的非線性能力。

self.regressor = nn.Sequential(
            conv1x1(num_input_channels, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_target_channels),
                            )
pred_mean = self.regressor(input)

並利用如下公式構建可學習的方差,其中阿爾法c是可學習參數。

self.log_scale = torch.nn.Parameter(
            np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
            )
pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps

最終整個蒸餾過程如下圖所示,學生網絡除了學習自身任務的交叉熵損失外,同時與教師網絡保持高互信息(MI),通過學習並估計教師網絡中的分佈,激發知識的傳遞,使相互信息最大化。

3.結果展示

在這篇文章中,作者提出了通過最大化兩個神經網絡之間相互信息的變分下界來實現有效知識轉移的VID框架。算法是基於高斯觀測模型實現的,如下結果表明,在蒸餾學習方面,算法性能優於其他基準。實話說這個算法的數學性太強了!雖然讀了兩遍,也把代碼復現到業務中了,但對內部的細節還是沒有摸得太透,後續需要把IM算法精度一遍,纔有可能真正理解變分分佈的概念。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章