Background
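Today I got tripped up by the requires_grad attribute of PyTorch tensors. Specifically, the value of requires_grad, both when a tensor is created and in how that tensor is used afterward, determines whether the Python variable that ends up holding the result is a leaf node. That sounds abstract, so let's go straight to the code.
(For more on leaf nodes, see my other blog post: https://blog.csdn.net/huyaoyu/article/details/81059315)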
Test code
The following code was tested on PyTorch 1.3.1.
```python
import torch

if __name__ == "__main__":
    # Each line creates a tensor, optionally applies .clone(), .detach() or
    # .cuda(), and prints whether the result is a leaf node.
    # The .cuda() cases (e, h, k) need a CUDA-capable GPU.
    a = torch.tensor([1.0], requires_grad=False)
    print("a.is_leaf = {}. ".format(a.is_leaf))
    b = torch.tensor([1.0], requires_grad=True)
    print("b.is_leaf = {}. ".format(b.is_leaf))
    c = torch.tensor([1.0], requires_grad=False).clone()
    print("c.is_leaf = {}. ".format(c.is_leaf))
    d = torch.tensor([1.0], requires_grad=False).detach()
    print("d.is_leaf = {}. ".format(d.is_leaf))
    e = torch.tensor([1.0], requires_grad=False).cuda()
    print("e.is_leaf = {}. ".format(e.is_leaf))
    f = torch.tensor([1.0], requires_grad=True).clone()
    print("f.is_leaf = {}. ".format(f.is_leaf))
    g = torch.tensor([1.0], requires_grad=True).detach()
    print("g.is_leaf = {}. ".format(g.is_leaf))
    h = torch.tensor([1.0], requires_grad=True).cuda()
    print("h.is_leaf = {}. ".format(h.is_leaf))
    i = torch.tensor([1.0], requires_grad=True).clone().detach()
    print("i.is_leaf = {}. ".format(i.is_leaf))
    j = torch.tensor([1.0], requires_grad=True).detach().clone()
    print("j.is_leaf = {}. ".format(j.is_leaf))
    k = torch.tensor([1.0], requires_grad=True).cuda().detach()
    print("k.is_leaf = {}. ".format(k.is_leaf))
```
Can you guess what each line prints?
Here is the output (PyTorch 1.3.1):
```
a.is_leaf = True.
b.is_leaf = True.
c.is_leaf = True.
d.is_leaf = True.
e.is_leaf = True.
f.is_leaf = False.
g.is_leaf = True.
h.is_leaf = False.
i.is_leaf = True.
j.is_leaf = True.
k.is_leaf = True.
```
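You can cross-check this pattern by printing grad_fn and requires_grad next to is_leaf. The sketch below is a CPU-only variant, chosen as an assumption so that it runs without a GPU. It reuses one tensor, src, created with requires_grad=True, and it only tries the .cuda() case when torch.cuda.is_available() reports a GPU. The helper describe() is hypothetical and exists only for this illustration. Leaves show grad_fn=None. f (and h, if a GPU is present) show a backward node, and the exact class name of that node differs between PyTorch versions.

```python
import torch


def describe(name, t):
    # Hypothetical helper: is_leaf, requires_grad and grad_fn side by side
    print("{}: is_leaf={}, requires_grad={}, grad_fn={}".format(
        name, t.is_leaf, t.requires_grad, t.grad_fn))


src = torch.tensor([1.0], requires_grad=True)  # created by the user -> leaf
f = src.clone()           # differentiable op is recorded -> has grad_fn, not a leaf
g = src.detach()          # cut off from the graph -> leaf, requires_grad=False
j = src.detach().clone()  # clone of a tensor that no longer requires grad -> leaf

for name, t in [("src", src), ("f", f), ("g", g), ("j", j)]:
    describe(name, t)

if torch.cuda.is_available():
    h = src.cuda()        # the device copy is recorded as well -> not a leaf
    describe("h", h)
```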
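Only f and h are no longer leaf nodes. The reason is that torch.tensor([1.0], requires_grad=True) returns a tensor with requires_grad = True, and every subsequent .clone() or .cuda() on that tensor is a differentiable operation. Autograd records the operation in the graph (the result gets a grad_fn), so .clone() and .cuda() both return non-leaf nodes. All the other variables are leaves for one of four reasons:
- a, c, d and e never required gradients.
- b was created directly by the user.
- g, i and k were cut off from the graph by .detach().
- j clones a tensor that had already been detached.

So if we want to be sure that the resulting Python variable is a leaf node, the safest approach is to leave requires_grad unspecified when calling torch functions such as tensor() or zeros(). The resulting tensor can then be freely passed through .clone() and .cuda() and assigned to other Python variables. Once we have the final Python variable, we request gradient computation by explicitly assigning to its requires_grad attribute.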
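Here is a minimal sketch of that recipe. The variables x and y and the loss are illustrative only, and the code falls back to the CPU when no GPU is available (an assumption, so it runs anywhere). x is created without requires_grad, cloned, moved to the target device, and only then marked with requires_grad = True. y follows the f/h pattern for comparison. After backward(), only the leaf x has its .grad filled in, which is one practical reason leaf status matters.

```python
import torch

# Assumption: fall back to the CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the tensor without requires_grad, so no operation is recorded ...
x = torch.tensor([1.0]).clone().to(device)
# ... and request gradient tracking only on the final variable.
x.requires_grad = True

# For comparison, the f/h pattern: requires_grad=True first, then an op.
y = torch.tensor([1.0], requires_grad=True).clone().to(device)

# Illustrative loss; its derivative with respect to x is 2.
loss = (2 * x).sum() + (3 * y).sum()
loss.backward()

print("x.is_leaf = {}, x.grad = {}".format(x.is_leaf, x.grad))  # leaf: grad is stored
print("y.is_leaf = {}, y.grad = {}".format(y.is_leaf, y.grad))  # non-leaf: grad is None
```

x.requires_grad_() is an in-place equivalent of the assignment. On recent PyTorch versions, reading y.grad also emits a UserWarning, and y.retain_grad() would be needed before backward() to keep that gradient.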
References
https://discuss.pytorch.org/t/how-to-define-a-leaf-tensor-in-pytorch-0-4-1/28461/5