(本文章適用於 pytorch0.4.0 版本, 既然 Variable 和 Tensor merge 到一塊了, 那就叫 Tensor吧)
在編寫 pytorch 代碼的時候, 如果模型很複雜, 代碼寫的很隨意, 那麼很有可能就會碰到由 inplace operation 導致的問題. 所以本文將對 pytorch 的 inplace operation 做一個簡單的總結。
在 pytorch 中, 有兩種情況不能使用 inplace operation:
- 對於 requires_grad=True 的 葉子張量(leaf tensor) 不能使用 inplace operation
- 對於在 求梯度階段需要用到的張量 不能使用 inplace operation
下面將通過代碼來說明以上兩種情況:
第一種情況: requires_grad=True 的 leaf tensor
import torch
w = torch.FloatTensor(10) # w 是個 leaf tensor
w.requires_grad = True # 將 requires_grad 設置爲 True
w.normal_() # 在執行這句話就會報錯
# 報錯信息爲
# RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
很多人可能會有疑問, 模型的參數就是 requires_grad=true 的 leaf tensor, 那麼模型參數的初始化應該怎麼執行呢? 如果看一下 nn.Module._apply() 的代碼, 這問題就會很清楚了
w.data = w.data.normal() # 可以使用曲線救國的方法來初始化參數
第二種情況: 求梯度階段需要用到的張量
import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True
d = torch.matmul(x, w1)
f = torch.matmul(d, w2)
d[:] = 1 # 因爲這句, 代碼報錯了 RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
f.backward()
爲什麼呢?
因爲 , 對於 的導數是關於的函數:
- 在計算 的時候, 是等於某個值的, 對於 的導數是和這時候的 值相關的
- 但是計算完 之後, 的值變了, 這就會導致 f.backward() 對於 w2 的導數計算出錯誤, 爲了防止這種錯誤, pytorch 選擇了報錯的形式.
- 造成這個問題的主要原因是因爲 在執行 f = torch.matmul(d, w2) 這句的時候, pytorch 的反向求導機制 保存了 d 的引用爲了之後的 反向求導計算.
import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True
d = torch.matmul(x, w1)
d[:] = 1 # 稍微調換一下位置, 就沒有問題了
f = torch.matmul(d, w2)
f.backward()
最後再提一下 .data 與 .detach(), (這部分翻譯自 pytorch0.4.0 的 release note):https://github.com/pytorch/pytorch/releases
在 0.4.0 版本之前, .data 的語義是 獲取 Variable 的 內部 Tensor, 在 0.4.0 版本將 Variable 和 Tensor merge 之後, .data 和之前有類似的 語義, 也是 內部的 Tensor 的概念。
x.data 與 x.detach() 返回的 tensor 有相同的地方, 也有不同的地方:
相同:
- 都和 x 共享同一塊數據
- 都和 x 的 計算曆史無關
- requires_grad = False
不同:
- y=x.data 在某些情況下不安全, 某些情況, 指的就是 上述 inplace operation 的第二種情況
import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True
d = torch.matmul(x, w1)
d_ = d.data
f = torch.matmul(d, w2)
d_[:] = 1
f.backward()
# 這段代碼沒有報錯, 但是計算上的確錯了
# 如果 打印 w2.grad 結果看一下的話, 得到 是 1, 但是正確的結果應該是 4.
上述代碼應該報錯, 因爲:
- d_ 和 d 共享同一塊數據,
- 改 d_ 就相當於 改 d 了
但是, 代碼並沒有報錯 , 但是計算上的確錯了
所以, release note 中指出, 如果想要 detach 的效果的話, 還是 detach() 安全一些.
import torch
x = torch.FloatTensor([[1., 2.]])
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True
d = torch.matmul(x, w1)
d_ = d.detach() # 換成 .detach(), 就可以看到 程序報錯了...
f = torch.matmul(d, w2)
d_[:] = 1
f.backward()