在學習cs231n,在做到SVM這一課作業時,被梯度的代碼難住了,再次翻看網上的課程筆記,細緻推導才逐漸清晰。理解能力有限,這裏做個簡單的記錄,防止以後忘記也便於溫習。
https://zhuanlan.zhihu.com/p/20945670?refer=intelligentunit線性分類筆記(中)中詳細介紹了多累支持向量機損失函數,該函數定義如下:
cs231n svm這一課後作業中對於求梯度的代碼如下:
# 估算損失函數的梯度
def svm_loss_naive(W, X, y, reg):
dW = np.zeros(W.shape) # initialize the gradient as zero
# compute the loss and the gradient
num_classes = W.shape[1] # C
num_train = X.shape[0] # N
loss = 0.0
# 對於一個minibatches
for i in xrange(num_train):
scores = X[i].dot(W)
correct_class_score = scores[y[i]]
# 下面這個for循環計算損失函數,loss爲該訓練圖像輸入後的損失,注意這裏的loss有疊加,最後是這個minibatches的loss
for j in xrange(num_classes):
if j == y[i]:
continue
margin = scores[j] - correct_class_score + 1 # note delta = 1
if margin > 0:
loss += margin
dW[:, j] += X[i] # 數據分類錯誤時的梯度
dW[:, y[i]] -= X[i] # 數據分類正確時的梯度
# Right now the loss is a sum over all training examples, but we want it
# to be an average instead so we divide by num_train.
loss /= num_train
dW /= num_train
# Add regularization to the loss.
loss += reg * np.sum(W * W)
dW += reg*np.sum(W) # 這裏正則化項沒有除以num_train是因爲這裏把reg當成 lamda/num_train
return loss, dW
該代碼中需要添加的是 dW[:,j] += X[i], 和 dW[:,y[i]] -= X[i], 這兩行代碼保存了代價函數對W求梯度的值,注意他們所在的條件語句爲if margin > 0, 即對於每個訓練樣本(每幅圖),該代碼很可能執行不止一遍。起初看網上提供的代碼答案並不太懂這樣的做法,後來根據損失函數似乎看出來點什麼,但一直不太明朗。
這裏我們做個假設,假設該分類器有六種分類,分別是 A,B,C,D,E,F,計算得到:
假設正確分類, 那麼上面式子裏的五個數分別對應 j=1,3,4,5,6,(假設j從0開始),也就是說對應的是 W[1]*X,W[3]*X,W[4]*X,W[5]*X,W[6]*X所得到的結果。W爲權值矩陣,具體介紹見課程筆記,其維度爲K*D,K爲分類數目,D爲圖像維度(即像素點數),W[1]表示第1行的向量,維度爲 1*D。
上述Li對W求導數,既然這裏的Li有五個數子相加,我們依次求導相加。
j=1時,max(0,1-2)=0,0的導數肯定爲0;
j=3時,導數只有兩種情況下有值,即 和,其他情況下因爲分子部分沒有對應的W[i],i=1,4,5,6項,因而均不做更新
j=4時,W[4]和W[2]做更新, 和
j=5時,max(0,-2)=0,權值不更新
j=6時,同樣不更新
最後將W[2]的兩個梯度相加就是W[2]的梯度,W[3]和W[4]的梯度也已經求出。上述分析默認△爲0,且是在分類出錯的情況下計算所得。分類正確情況下Li=0,梯度爲0,權值不做更新。
這對應着代碼中以下部分
# 對於一個minibatches
for i in xrange(num_train):
scores = X[i].dot(W)
correct_class_score = scores[y[i]]
for j in xrange(num_classes):
if j == y[i]:
continue
margin = scores[j] - correct_class_score + 1 # note delta = 1
if margin > 0:
loss += margin
dW[:, j] += X[i] # 數據分類錯誤時的梯度
dW[:, y[i]] -= X[i] # 數據分類正確時的梯度
對於每個圖片,分別遍歷分類器輸出向量的每個元素,即對應上文累加和公式的每個項,倘若loss不爲0,則計算該項對應W[i]的偏導和W[yi]的偏導,並在每次循環的時候都做更新。
仍在學習過程中,理解和表述可能都不盡人意,再接再厲,希望能給予幫助。