關於 pytorch inplace operation需要注意的問題(data和detach方法的區別)

(本文章適用於 pytorch0.4.0 版本, 既然 Variable 和 Tensor merge 到一塊了, 那就叫 Tensor吧)

在編寫 pytorch 代碼的時候, 如果模型很複雜, 代碼寫的很隨意, 那麼很有可能就會碰到由 inplace operation 導致的問題. 所以本文將對 pytorch 的 inplace operation 做一個簡單的總結。

在 pytorch 中, 有兩種情況不能使用 inplace operation:

  • 對於 requires_grad=True 的 葉子張量(leaf tensor) 不能使用 inplace operation
  • 對於在 求梯度階段需要用到的張量 不能使用 inplace operation

下面將通過代碼來說明以上兩種情況:

第一種情況: requires_grad=True 的 leaf tensor
import torch

w = torch.FloatTensor(10) # w 是個 leaf tensor
w.requires_grad = True    # 將 requires_grad 設置爲 True
w.normal_()               # 在執行這句話就會報錯
# 報錯信息爲
#  RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

很多人可能會有疑問, 模型的參數就是 requires_grad=true 的 leaf tensor, 那麼模型參數的初始化應該怎麼執行呢? 如果看一下 nn.Module._apply() 的代碼, 這問題就會很清楚了

w.data = w.data.normal() # 可以使用曲線救國的方法來初始化參數
第二種情況: 求梯度階段需要用到的張量
import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1)
f = torch.matmul(d, w2)
d[:] = 1 # 因爲這句, 代碼報錯了 RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

f.backward()

爲什麼呢?

因爲f=matmul(d,w2),fw2=g(d)f=matmul(d, w2) , \frac{\partial f}{\partial w2} = g(d) , ff 對於 w2w2的導數是關於dd的函數:

  • 在計算 ff 的時候, dd 是等於某個值的, ff 對於 w2w2 的導數是和這時候的 dd 值相關的
  • 但是計算完 ff 之後, dd 的值變了, 這就會導致 f.backward() 對於 w2 的導數計算出錯誤, 爲了防止這種錯誤, pytorch 選擇了報錯的形式.
  • 造成這個問題的主要原因是因爲 在執行 f = torch.matmul(d, w2) 這句的時候, pytorch 的反向求導機制 保存了 d 的引用爲了之後的 反向求導計算.
import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1)
d[:] = 1   # 稍微調換一下位置, 就沒有問題了
f = torch.matmul(d, w2)
f.backward()

最後再提一下 .data 與 .detach(), (這部分翻譯自 pytorch0.4.0 的 release note):https://github.com/pytorch/pytorch/releases

在 0.4.0 版本之前, .data 的語義是 獲取 Variable 的 內部 Tensor, 在 0.4.0 版本將 Variable 和 Tensor merge 之後, .data 和之前有類似的 語義, 也是 內部的 Tensor 的概念。
x.data 與 x.detach() 返回的 tensor 有相同的地方, 也有不同的地方:
相同:

  • 都和 x 共享同一塊數據
  • 都和 x 的 計算曆史無關
  • requires_grad = False

不同:

  • y=x.data 在某些情況下不安全, 某些情況, 指的就是 上述 inplace operation 的第二種情況
import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1)

d_ = d.data

f = torch.matmul(d, w2)
d_[:] = 1

f.backward()

# 這段代碼沒有報錯, 但是計算上的確錯了
# 如果 打印 w2.grad 結果看一下的話, 得到 是 1, 但是正確的結果應該是 4.

上述代碼應該報錯, 因爲:

  • d_ 和 d 共享同一塊數據,
  • 改 d_ 就相當於 改 d 了

但是, 代碼並沒有報錯 , 但是計算上的確錯了

所以, release note 中指出, 如果想要 detach 的效果的話, 還是 detach() 安全一些.

import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1)

d_ = d.detach() # 換成 .detach(), 就可以看到 程序報錯了...

f = torch.matmul(d, w2)
d_[:] = 1
f.backward()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章