Pytorch 加權交叉熵實現分析

此文不涉及公式。

假設模型對輸入圖像數據(1(batch size) * 1 (channels) * 3 (height) * 3 (width))的分割輸出結果爲1 * 2 (只有兩類,前景,背景)* 3 (height)* 3 (width),ground truth爲 target ( 1 (batch size) * 3 * 3)。

 假設模型的分割結果爲input, ground truth 爲target。

根據文獻[1],計算交叉熵損失有兩中方式,一種是用F.nll_loss();一種是用F.cross_entropy(input, target)。

(1)利用F.cross_entropy()計算未加權的結果:

(2)利用F.cross_entropy()中的weight參數計算加權結果:

對比以上兩個結果可以發現:加權後的交叉熵損失比未加權的交叉熵損失小。(假設加權後的交叉熵損失比未加權的交叉熵損失大,那pytorch的加權交叉熵實現是錯誤的?還需要找一些關於加權加權交叉熵的數學公式)繼續分析,pytorch是如何實現加權交叉熵?

(3)對F.cross_entropy()中的reduce參數設置爲False(關於cross_entropy中參數詳細介紹,請參考文獻[2])。

分析發現,(3)的結果與(1)的結果相同,則reduce=True的作用就是對batch_size * height * width 個像素點的交叉熵損失的均值。

(4)如果按照(3)的理解,對加權的F.cross_entropy的reduce的參數設置False,結果:

分析發現,該結果與(2)結果不一致。那麼這個加權是怎麼實現的?根據文獻【3】, 首先,對batch_size * height * width 個像素點的加權交叉熵損失進行求和得到sum之後;然後,計算出batch_size * height * width 個像素點對應類別權重之和weight_sum,例如:

target中有6個類別爲1的像素點,3個類別爲0的像素點, 這些像素點對應類別權重之和爲:6 * 10 + 3 * 1 = 63。

重新計算損失值:

個人理解:F.cross_entropy()中的weight參數作用:將每個類別的像素點的數量擴大weight倍。

參考文獻:

【1】https://blog.csdn.net/qq_22210253/article/details/85229988

【2】https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss

【3】https://discuss.pytorch.org/t/how-to-use-the-weight-parameter-for-f-cross-entropy-correctly/17786

 代碼:

  https://colab.research.google.com/drive/1VnZ53pwoLB3S0rGb7N-crgErLlxIR_VS#scrollTo=aCvteE5ZAx1Q

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章