Pytorch 學習(二):Pytorch 梯度簡單操作

Pytorch 中梯度簡單理解

Pytorch 的 tensor 帶有梯度屬性,tensor.grad_fn 存儲梯度信息,利用 backward 函數可進行梯度回傳

import torch
x = torch.randn(2, 2, requires_grad=True)  # requires_gard 打開梯度

同樣可以這樣打開梯度

x = torch.randn(2, 2)
x.requires_gard_(True)

繼續構建計算圖

print(x, x.grad_fn)
y = x * x
z = y + 1
out = z.mean()
out.backward()  # 梯度回傳

backward 函數可以附帶參數 w,默認是一個值爲 1 的標量,由於 out 本身爲標量,因此不影響。

out.backward(torch.tensor(1.0))  # 等價於 out.backward()
out.backward(torch.tensor([1.0]))  # 等價於 out.backward()

當 w 不爲 1 時,backward 的計算帶有了權值,一般 w 具有值用於中間梯度的直接回傳,不經過後續的梯度計算

out2 = out * 2
out2.backward()

等價於

out.backward(torch.tensor(2.0)) # 梯度帶有了額外的權值


此時的回傳不需要通過 out2 計算
當需要回傳的 tensor 不是一個標量時,w 的 size 應與當前 tensor 一致

# z.size() = 2 x 2
w = torch.tensor([10, 1, 0.1, 0.01], dtype=torch.float32) # w.size() = 1 x 4
z.backward(w.view(2, 2)) # 等價於 (z * w.view(2, 2)).backward()

在實際中,有時候還需要凍結部分梯度

x = torch.ones(2, 2, requires_gard=True)
y1 = x ** 2
with torch.no_grad():  # 凍結梯度
  y2 = x ** 3
y3 = y1 + y2

正常情況下,dy3/dx = dy3/dy1 + dy3/dy2 = 2x + 3x^2
此時 dy3/dx = dy3/dy1 = 2x,梯度被凍結

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章