Back Propagation反向傳播
前言:小案例
我們有這麼一個等式
求:e對a的導數 以及 e對b的導數
如果僅僅從數學的角度計算,這是非常簡單的,但在深度學習中,我們會遇到許多更加複雜的計算,純靠數學解析式來計算是十分困難的,我們需要藉助 Back Propagation(反向傳播)來得到答案
剛剛的等式只是一個非常簡單的舉例,我們要做的是把這個等式理解爲一個計算圖
反向傳播的核心 —> 計算圖
在計算圖中,每一步的計算我們只能進行原子計算(不可分割的) 這裏的原子計算包括加減乘除,矩陣乘法,卷積等等
下圖就是 等式轉化成的計算圖
我們來簡單的理解一下這個圖是如何構建出來的
- a和b通過加法得到c
- b和1通過加法得到d
- c和d通過乘法得到e
上面的3個步驟就是 正向傳播的過程,看上去是不是十分簡單
計算圖構建流程
需要我們注意的是,在正向計算的過程中 ,我們不僅僅簡簡單單的得到 c = a + b = 3 ,與此同時, 我們還得到了c對a和b的偏導
在這裏偏導數存放在a,b兩個變量中,在pytorch中我們可以通過grad參數取得每個變量的偏導是多少(這將在下文中介紹)
如上圖所示,在每一步計算的途中,我們都得到了變量的偏導數,那麼想得到結果就變的十分簡單了
如果想得到e對a的導數,那麼我們僅僅需要把e到a路徑上的偏導值相乘即可 鏈式法則
想要得到e對b的導數,我們只需要把不同路徑的偏導相加即可
案例總結
根據這個簡單的案例,我們實現了一個基於圖的基本算法,我們可以通過它來構建十分複雜的計算圖,計算出我們想要的導數,並且它具有很大的彈性,就算圖改了,只要原子計算沒有改變,我們還可以繼續使用
正文: pytorch編程
案例:根據學習時長 推斷成績
如下圖,前三組數據是我們一直的學習時間x與分數y的關係,我們需要推斷出x=4的時候y是多少
這裏爲了方便效果演示,數據湊的很好,但在實際中,數據會有很多偏差,這就是爲什麼要使用深度學習
上圖是我們一直的信息,依舊是一個非常簡單的案例,我們通過它來了解反向傳播的編程練習
通常情況下我們會假設它是一個線性模型 關於線性模型的概念請看 入口:(線性模型的博客暫未更新)
這裏爲了方便計算
我們就假設
因爲根據肉眼我們顯然能得到w=2,但之前也說了,這只是個舉例,所以我們要模擬出訓練的過程
下圖是根據我們訓練的線性模型得到的應該計算圖
我也簡單的推導了一下 模型函數爲 y=x*w+b 的過程
接下來進行代碼構建
1.加載數據集
import torch
import matplotlib.pyplot as plt
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
2.權重w
這裏假設w爲1
w.requires_grad = True 是爲了能夠反向傳播
w = torch.tensor([1.0]) # 假設 w = 1.0的情況
w.requires_grad = True
3.構建計算圖,定義損失函數
下面兩個函數就達到了上問中計算圖的效果,如下
def forward(x): # y^ = wx
return x * w # w是tensor 所以 這個乘法自動轉換爲tensor數乘 , x被轉化成tensor 這裏構建了一個計算圖
def loss(x, y): # 計算單個的誤差 : 損失
'''
每調用一次loss函數,計算圖自動構建一次
:param x:
:param y:
:return:
'''
y_pred = forward(x)
return (y_pred - y) ** 2
4.訓練
這裏要注意的是,每次訓練的時候,我們要情況w的grad,不然將會進行累加,影響結果
更新w 是通過 在這個步驟中我們使用了w.data = w.data -0.01 * w.grad.data
因爲我們只需要對數據進行更新,如果這裏沒有.data
的話,我們就又構建了一個計算圖
最後輸出結果
eli = []
lli = []
print('predict (before training)', 4, forward(4).item())
for epoch in range(100): # 每輪輸出 w的值和損失 loss
for x, y in zip(x_data, y_data):
l = loss(x, y)
l.backward() # 自動求梯度
print('\tgrad:', x, y, w.grad.item())
w.data = w.data - 0.01 * w.grad.data # 權重的數值更新,純數值的修改 如果不用.data會新建計算圖
# 如果這裏想求平均值 其中的累加操作 要寫成sum += l.item()
w.grad.data.zero_() # 清空權重裏梯度的數據,不然梯度會累加
eli.append(epoch)
lli.append(l.item())
print('progress:', epoch, l.item())
print('Predict (after training)', 4, forward(4).item())
這裏就展示最後幾組,可以看到,經過100輪訓練之後,y已經無限接近於正確的解了
5.繪製圖像達
爲了更加好的觀察
# 繪製函數
plt.plot(eli, lli)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()