最近在做一個分類項目,發現很多“難樣本”比較不好處理,想試試FocalLoss。沒找到pytorch相關實現,就研究起cross_entropy源碼,想手動改一下。
def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean'):
"""This criterion combines `log_softmax` and `nll_loss` in a single
function
(省略部分註釋)
"""
if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce)
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
可以看到最後一行:
nll_loss(log_softmax(input, 1),***)
nll_loss就是 "negative log likelihood loss"負對數似然損失函數
我們知道,
該函數是把nll_loss和log_softmax函數結合在一起,代替 。
感覺這裏完全可以用softmax函數計算完,再帶入啊,爲什麼要用log_softmax呢?
softmax公式:
這裏可以看到是個指數計算,其實是不穩定的,當很大的時候,會造成數值溢出的情況,那麼我們可以對所有x減去一個C = Max(x)(分子分母同除以)
softmax公式變成:
恆小於0,這樣就不會數值溢出。
問題不就解決了嗎?那爲什麼還是要換成log_softmax呢???
其實仔細看,還是可能造成一種情況,當很大很大,那麼可能很小很小,分子就是無限接近於0的一個負數,那麼在計算交叉熵的時候會出現的情況,得到一個的數,同樣造成數值溢出。
所以我們的pytorch是怎麼計算的呢?
可以看到,在最後的公式中,中,至少有一項是等於0的,所以大於1,就不會出現數值溢出的情況了。