近日研讀了一篇發表在ICLR 2018上的文章:《LEARNING LATENT PERMUTATIONS WITH GUMBEL- SINKHORN NETWORKS》, 其介紹了一種能夠將二維張量以可微分的形式轉變爲轉置矩陣的方法。使得指派、重排等不可微分操作能夠以可微分的形式結合到神經網絡當中。由此,我們便可使BP算法學習這些操作,以實現神經網絡的數字排序、拼圖等算法。
BP之痛
直面評價指標?
其實我在最初使用神經網絡分類時有一個很幼稚的想法,對於最後的分類。能否設計這樣一個損失函數:
最後,我們取所有樣本損失的平均爲最終的loss。這樣我們就可以直接優化最終的指標:準確率,不是很美好嗎?實現見以下代碼:
import torch
x = torch.randn(5, requires_grad=True)
_, predict = torch.max(x, 0)
y = torch.LongTensor([1])
loss = (predict != y).int()
print("x:{}\nidx:{}\nloss:{}\n".format(x, idx, loss))
> x:tensor([-0.7181, -0.2303, -1.4065, 2.0853, -0.9006], requires_grad=True)
> idx:3
> loss:tensor([1], dtype=torch.int32)
不可導!
上面的邏輯粗略來看是沒問題的,但是,有一個很重要的漏洞。我們調用了torch.max
函數,返回了預測結果predict
,然後去和比較計算損失。
但是很遺憾:選取最高概率類別這個操作,即函數是不可導的。我們沒有辦法記錄這一個操作的梯度。也就無法使用BP算法更新網絡(可以看到上方輸出中loss
並沒有記錄到梯度信息).
近似之法
既然上述方法失敗在:$argmax$
這個函數不可導上,那我們能不能進行解決呢,答案自然是可以的。簡單來說,我們可以通過以下可導函數近似argmax
函數(準確來說,是近似onehot(argmax)
函數:
如果需要具體解釋,參考《函數光滑化雜談:不可導函數的可導逼近》。
排列問題
如果我們希望求得一個最優排列,常見的,比如使用匈牙利算法解決最優指派問題,同樣,這個選取最優指派的操作是不可導的,那麼,我們也就不能使用神經網絡去學習這個問題。因此,類比分類問題:我們能不能也使用一個可導的操作去近似選取最優指派這個操作呢,從而使得可以被學習呢?答案是可以的
Sinkhorn operator
我們知道,一個指派,實際上可以等價爲一個置換矩陣,如下所示:
所以,我們能否可微地去近似置換矩陣呢,從而通過學習去學習指派這個操作呢?答案是可以,方法就是Sinkhorn operator
。
給定一個方陣. 我們可以通過以下變幻將其變爲雙線性矩陣。(所謂雙線性矩陣,就是其每一行每一列的和都爲1).
當然,對於指派問題,僅僅是雙線性矩陣還是不夠的,因爲我們要保證$S(x)$
中的元素是非0即1的。而這個限制,我們可以通過增加一個超參數$\tau$
實現:
其中,爲對應收益矩陣爲
的最優置換矩陣,
`
這樣,我們通過神經網絡去將原始數據編碼爲矩陣, 再通過可微操作近似對應的指派。最後就可以實現梯度更新從而訓練網絡了。
下面是一個實現拼圖的示意圖:
實驗
個人使用Pytorch復現了一遍原文給出的數字排序實驗: