相關文章:
交叉熵三連(1)——信息熵
交叉熵三連(2)——KL散度(相對熵)
交叉熵三連(3)——交叉熵及其使用
在神經網絡中,我們經常使用交叉熵做多分類問題和二分類的損失函數,在通過前面的兩篇文章我們瞭解了信息熵和相對熵(KL散度)的定義和計算方式以及相關的基礎知識。在這篇文章中,會主要總結一下關於交叉熵的內容。
一來因爲剛開始寫博客的緣故,二來所學的知識有限。博客中有很大的篇幅其實是拼湊得來,可能會跟參考資料有大量的重複之處,自己深入的理解相對會少一些,權且是學習的記錄。
如果博客中能有幫助到大家的地方,我很開心,如果大家有什麼疑問也可以和我溝通討論,有錯誤的地方歡迎大家批評指正我儘量第一時間修改,聯繫方式在簽名中給出。
1 交叉熵的定義
交叉熵:表示當基於一個“非自然”分佈Q對真實分佈P進行編碼時,在事件集合中唯一標識一個事件所需要的平均bit數。
對於給定的兩個概率分佈P和概率分佈Q對應的交叉熵的定義如下:
H(P,Q)=EP[−logq]=H(P)+DLKL(P∣∣Q)
其中H(P)是P的熵,DLKL(P∣∣Q)是相對熵,那麼對於離散分佈而言:
H(P,Q)=−x∑P(x)logQ(x)
2 交叉熵的計算
2.1 二分類交叉熵
在二分類問題中,交叉熵的計算方式如下所示:
L=−[y⋅log(p)+(1−y)⋅log(1−p)]
其中:
y:表示樣本 label 的值,如果 label 正類y=1,負類y=0
p:表示樣本預測爲正的概率
如果上述公式沒有辦法直接理解,那麼根據交叉熵計算公式
H(P,Q)=−x∑P(x)logQ(x)
對於二分類問題,上述公式轉化爲如下形式:
H(P,Q)=−x∈{0,1}∑P(x)logQ(x)=−(P(x0)logQ(x0)+P(x1)logQ(x1))
令P(x0)=y,並且Q(x0)=q,那麼P(x1)=1−y,且Q(x1)=1−q
上述公式簡化爲如下所示
L=−[y⋅log(p)+(1−y)⋅log(1−p)]
如果x0表示正樣本的話,y 的值等於正樣本真實的概率等於 label 的值, q 的值表示爲正樣本的預測概率。
2.2 多分類交叉熵
對於多分類問題的交叉熵在二分類問題上擴展後的具體計算公式如下:
L=−c=1∑Myclog(pc)
在上述公式中:
M:表示類別的數量
yc:當前觀測樣本屬於類 c的時候yc=1,否則 yc=0
pc:當前觀測樣本預測到類別c的概率
3 爲什麼是交叉熵
3.1 爲什麼不是分類錯誤率
很多人一開始理解損失函數就是分類錯誤率,我剛開始的時候也是這麼認爲的,分類錯誤率的計算公式如下:
Re=CaCe
在上述公式中:
Re:表示分類錯誤率
Ce:表示分類錯的樣本個數
Ca:表示所有的樣本的個數
在參考的文章中給出了一張樣例表,樣例表中給出了三組民意選舉預測的結果和對應的標籤。
計算結果 |
標籤 |
是否正確 |
0.3 0.3 0.4 |
0 0 1(民主黨) |
yes |
0.3 0.4 0.3 |
0 1 0 (共和黨) |
yes |
0.1 0.2 0.7 |
1 0 0 (其他) |
no |
在三種結果中選民1和選民2的預測以微弱的優勢獲勝,選民3的民意預測結果徹底錯誤,計算得到的分類錯誤率爲:
Re=31
計算結果 |
標籤 |
是否正確 |
0.1 0.2 0.7 |
0 0 1(民主黨) |
yes |
0.1 0.7 0.2 |
0 1 0 (共和黨) |
yes |
0.3 0.4 0.3 |
1 0 0 (其他) |
no |
在模型2中給出了另外一組假設數據,在這組假設數據中選民1和選民2的判斷非常準確,選民3以輕微的概率優勢判錯,分類錯誤率計算爲:
Re=31
在上面給出的模型實例中,雖然錯誤率相等但是從三組樣本最後預測的分類概率看,模型2具有相對明顯的優勢,但是通過分類錯誤率沒有辦法較準確的評估。
接下來我們分析一下使用交叉熵計算得到loss的值,在計算過程中採用ACE(Average cross-entropy error)來計算平均交叉熵,根據多分類問題計算交叉熵計算公式得:
- 模型1
1:−(ln(0.3)∗0+ln(0.3)∗0+ln(0.4)∗1)=−ln(0.4)
2:−(ln(0.3)∗0+ln(0.4)∗1+ln(0.3)∗0)=−ln(0.4)
3:−(ln(0.1)∗1+ln(0.2)∗0+ln(0.7)∗0)=−ln(0.1)
L=3−(ln(0.4)+ln(0.4)+ln(0.1))=1.38
- 模型2:
L=3−(ln(0.7)+ln(0.7)+ln(0.3))=0.64
結論:
- ACE結果準確體現了模型2的效果優於模型1
- cross-entropy 更清晰的描述了真實分佈的數據和預測數據的距離
3.2 爲什麼不是均方誤差(MSE)
接下來我們看看使用均方誤差作爲損失函數是個什麼樣的效果。首先根據3.1節中給出的數據,我們計算均方誤差如下。
-
模型1
1:(0.3−0)2+(0.3−0)2+(0.4−1)2=0.54
2:(0.3−0)2+(0.4−1)2+(0.3−0)2=0.54
3:(0.1−1)2+(0.2−0)2+(0.7−0)2=1.34
L=3(0.54+0.54+1.34)
-
模型2
L=3(0.14+0.14+0.74)
根據MSE的計算結果可知,使用MSE好像也能很好的評估模型1和模型2的效果,爲什麼不用MSE來作爲損失函數去優化整個模型的訓練結果呢?主要原因有兩個:
- 原因1:函數的單調性 。採用MSE做爲損失函數的情況下,損失函數是非凸函數具有很多極值點,容易陷入局部最優解。
- 原因2:計算的簡潔性。 使用均方誤差作爲loss函數的時候求導結果比較複雜運算量會比較大,使用交叉熵計算結果的時候比較簡單,反向誤差的計算比較簡單。
3.2.1 求導過程簡潔性對比
在分類問題中最後計算每一個類的概率的時候,採用softmax計算映射到每一個類的概率,softmax對應的具體公式如下所示。
softmax(x)i=∑jexp(xj)exp(xi)
採用MSE計算loss,輸出的曲線是波動的,損失函數表現的不是凸函數,有很多局部的極值點,這種情況下難以得到最優解,使用交叉熵函數能夠保證在區間內的單調性。
以下我們以二分類爲例分別證明MSE損失函數和交叉熵損失函數的單調性,具體的證明情況如下:
1. 交叉熵損失函數的求導過程
在二分類問題鍾交叉熵損失的計算過程如下。
我們簡單的假設score是線性函數輸入的結果,假設參數w和得分s的計算公式爲:
s=wi⋅xi+bi
根據上述過程,二分類問題正向的L計算分爲三個階段
- score得分的計算
s=wi⋅xi+bi
- sigmoid計算公式
pi=1+esiesi
- 交叉熵損失的計算
L=−[y⋅logpi+(1−y)⋅log(1−pi)]
根據損失函數的計算過程,我們需要反向地計算偏導∂wi∂L,與上面相反的,我們把求導過程也分爲三個部分。
- ∂pi∂L 導數的計算
∂pi∂L=∂pi∂(−[y⋅logpi+(1−y)⋅log(1−pi)])=−piy+1−pi1−y
-
∂si∂pi 導數的計算
pi=σ(si)=1+esiesi
令:
h(x)=1+ex
g(x)=ex
那麼有:
∂si∂p=(h(si)g(si))′=h2(si)h(si)g′(si)−g(si)h′(si)=(1+esi)2(1+esi)(esi)′−esi(1+esi)′=(1+esi)2(1+esi)⋅esi−esi⋅esi=(1+esi)2esi=1+esiesi⋅1+esi1=1+esiesi⋅[1−1+esiesi]=σ(si)⋅(1−σ(si))
-
∂wi∂si 導數的計算
∂wi∂s=xi
最終,我們計算∂wi∂L的結果得到
∂wi∂L=∂pi∂L⋅∂si∂pi⋅∂wi∂si=[−piy+1−pi1−y]⋅σ(si)⋅[1−σ(si)]⋅xi=[−σ(si)y+1−σ(si)1−y]⋅σ(si)⋅[1−σ(si)]⋅xi=[−y+y⋅σ(si)+σ(si)−y⋅σ(si)]⋅xi=[σ(si)−y]⋅xi
2. MSE損失函數的求導過程
在使用MSE作爲損失函數的時候,計算最終的loss值的大小前兩個步驟與使用交叉熵的計算過程相同,第三個階段計算公式如下:
L=(y−pi)2
反向求導過程:
∂pi∂L=−2⋅y+2⋅pi
我們把常數去掉
∂pi∂L=−y+pi
那麼對應的∂wi∂L計算的過程如下:
∂wi∂L=∂pi∂L⋅∂si∂pi⋅∂wi∂si=(−y+pi)⋅σ(si)⋅[1−σ(si)]⋅xi=[−y⋅σ(si)+σ(si)⋅σ(si)]⋅[1−σ(si)]⋅xi
對比權重更新的導數公式,使用交叉熵作爲損失函數最後得到導數公式更簡潔偏於計算。
3.2.2 損失函數的是否爲凸函數
關於損失函數是否爲凸函數的問題,我理解的還不是很深刻,常讀常新。等我有了進一步理解再來增加更新這一部分的內容,或者是新起一篇博客再具體說一下凸函數優化和非凸優化以及是不是凸函數的問題吧!
如果要進一步探究這個問題,大家可以參考以下幾個資料:
【機器學習基礎】交叉熵(cross entropy)損失函數是凸函數嗎?
二元分類爲什麼不能用MSE做爲損失函數?
不理解爲什麼分類問題的代價函數是交叉熵而不是誤差平方,爲什麼邏輯迴歸要配一個sigmod函數?
感覺越看越覺得搞機器學習跟炒股一樣是玄學不成?剛開始還覺得對交叉熵理解清楚了,越看越迷糊多問幾個爲什麼就徹底迷糊了,只能把自己稍微覺得有點道理,或者是認同並且可以理解的羅列出來。知之爲知之,不知爲不知,上下求索吧~
如果有後續,我後續有新的理解我再寫交叉熵後傳吧!!!
歡迎大家討論。
3.3 爲什麼不是相對熵(KL散度)
交叉熵系列的學習和介紹中,我寫了三篇博客分別是介紹了信息熵、KL散度(相對熵)和交叉熵。根據前面介紹的內容KL散度可以用來衡量數據的真實分佈和預測分佈的距離,那麼爲什麼不用KL散度去衡量真實樣本和預測結果之間的差距,作爲損失函數。
這裏我們回顧一下交叉熵的定義:
H(P,Q)=EP[−logq]=H(P)+DLKL(P∣∣Q)
對於H(P,Q)表示用預測分佈Q去編碼真實分佈P所需要的平均bits數。因爲對於模型的訓練過程來說真實分佈通過樣本的數據分佈來估計,所以分佈P是固定的,那麼H(P)的值是一個常量。那麼:
- 在模型訓練的過程中,通過交叉熵和KL散度去評估預測分佈和真實分佈的距離具有相同的效果。
接下來我們再回顧一下KL散度的計算公式和交叉熵的計算公式。
KL散度計算公式
DLKL=i∑P(xi)log(P(xi))−i∑P(xi)log(Q(xi))
交叉熵的計算公式
Ce=−i∑P(xi)log(Q(xi))
根據上述公式對比,使用交叉熵用作損失函數評估真實分佈和預測分佈距離的時候具有明顯的計算上的簡潔性。
4 TensorFlow提供的交叉熵接口
tf.nn.sigmoid_cross_entropy_with_logits
tf.nn.softmax_cross_entropy_with_logits
tf.nn.sparse_softmax_cross_entropy_with_logits
tf.nn.weighted_cross_entropy_with_logits
5 參考資料
【1】維基百科 · 交叉熵
【2】知乎 · 損失函數 - 交叉熵損失函數
【3】神經網絡的分類模型 LOSS 函數爲什麼要用 CROSS ENTROPY
【4】知乎 · 爲什麼用交叉熵做損失函數
【5】github · 03.2-交叉熵損失函數
【6】Tensorflow 中文社區