什麼是softmax
softmax不同於sigmoid函數,用softmax後能夠將有正有負的輸出化爲和爲1的正數輸出,這些輸出相互影響,可以認爲是概率分佈,這就給很多問題的求解提供了便利。softmax函數在神經網絡常常當做多分類神經網絡輸出的激活函數。softmax是將每個神經元輸出通通進行e指數變換,並分別除以這些變換後結果的和,從而得到0~1之間的分值。softmax公式如下:
其中, 是第j個輸出神經元激勵。是softmax層前一層第j個神經元的輸出。
softmax具體過程可以參考下圖,其中的softmax層輸出爲y(j),softmax層輸入爲z(j)。注意所有激勵輸出和爲1
很明顯,y1,y2,y3的和就爲1,是不是很像概率問題。softmax常常用來作爲多分類問題分類器的輸出。
關於softmax的解釋,這裏有些閱讀資源比較有價值:
英文版:http://neuralnetworksanddeeplearning.com/chap3.html
翻譯版:https://hit-scir.gitbooks.io/neural-networks-and-deep-learning-zh_cn/content/chap3/c3s4.html
softmax的損失函數
softmax損失函數如下:
特別注意,一個樣本的輸出損失就是這麼個公式計算得到的,沒有累加和。是正確分類的神經元輸出,比如這要分類動物。樣本1的實際圖片是狗子,但是呢分類器不行輸出的的三個神經元分數分別爲0.1 、0.2、0.7,他們分別代表 小貓、狗子和小豬。顯然小豬分數最高,但是呢我們要盯着狗子神經元的輸出看,納尼才0.2。這個時候這個樣本計算得到的損失就是 -ln(0.2)= 1.609。我們訓練的目的就是通過調整參數來降低這個loss。
有的地方說softmax的損失函數是交叉熵損失,有的地方用log似然函數,實際上在多分類問題中這些公式是等價相通的。我在另一篇博客比較具體的說明了這個問題,傳送門:
https://blog.csdn.net/weixin_39704651/article/details/97392322
softmax反向傳播推導以及代碼實現
在反向傳播中,我們經常用鏈式法則來串聯整個反向傳播的過程從而計算參數的梯度值。所以這裏單獨研究softmax層。我們都知道softmax的公式,根據高中知識也很容易知道它對於輸入的求導是多少。溫習下softmax公式:
我們假設樣本的正確分類在第 i 個神經元。當 時
當 i 不等於 j 時
我們順着這個推導再前進一點點,當 時
當 i 不等於 j 時
是不是非常簡潔,這也是softmax交叉熵(或者說log似然)損失函數結合的一大好處,在反向傳播的時候非常容易計算。筆者做cs231n的作業的時候,被這部分內容饒了很久,下面配合着代碼對softmax做簡要的分析
softmax代碼實現如下;
def softmax(x):
x -= np.max(x)
a = np.exp(x)
b = a / np.sum(a, axis=1, keepdims = True)
return b
因爲在跑代碼的過程中發現經常會出現參數爲NAN的情況,cs231n筆記解釋是因爲指數使得輸出非常大,大數值可能導致計算不穩定,也就是計算機不行,數要小些,效果還要一樣,因此可以用歸一化技巧。即在分式的分子和分母都乘以一個常數C,可以得到數學上等價的公式:
其中的值可自由選擇,不會影響計算結果,通過使用這個技巧可以提高計算中的數值穩定性。通常將設爲。該技巧簡單地說,就是應該將向量中的數值進行平移,使得最大值爲0。也就是上述 x -= np.max(x)
接下里看softmax的loss和反向傳播代碼,這裏給出cs231n作業2第一部分的代碼
def softmax_loss(x, y):
shifted_logits = x - np.max(x, axis=1, keepdims=True) # 確保數值穩定性,輸入減去最大值
Z = np.sum(np.exp(shifted_logits), axis=1, keepdims=True ) # 求softmax公式分母部分
log_probs = shifted_logits - np.log(Z) # 見下文解釋1
loss = -np.sum(log_probs[np.arange(N), y]) / N # 從所有神經元的softmax輸出中找到正確輸出的神經元的分數求loss
probs = np.exp(log_probs) # 見下文解釋2,即softmax輸出分值
N = x.shape[0]
dx = probs.copy()
dx[np.arange(N), y] -= 1 # 見下文解釋3
dx /= N
return loss, dx
解釋1:注意第3行代碼,是一個等價公式
解釋2: 由上面的公式,當然 。
解釋3:上文已經推導了對的偏導數,對於,即對應 dx[np.arange(N), y] -= 1
望與童鞋們一起進步
以上