[PyTorch][numpy]擴展維度的一種Trick

說到擴展維度,可能第一想法是調用sequeeze()函數,但是實際上會有更簡單的方式:

x = torch.Tensor(3, 2)
print(x.size())
print(x[:, :, None].size())
print(x[:, None, :, None].size())

打印結果:

torch.Size([3, 2])
torch.Size([3, 2, 1])
torch.Size([3, 1, 2, 1])
  • 很簡單,看代碼就能理解,不再贅述。

另外,在PyTorch裏,repeatexpend函數的區別在於,對於dim爲1的某個維度,後者不會申請更多的內存。這個機制有利於節省空間。

x = torch.Tensor(3, 1)
print(x.repeat(1, 4).size())
print(x.expand(3, 4).size())

print(x.repeat(1, 4))
print(x.expand(3, 4)) # 等價於 x.expand(-1,4)

print:

torch.Size([3, 4])
torch.Size([3, 4])
tensor([[-7.6648e+06, -7.6648e+06, -7.6648e+06, -7.6648e+06],
        [ 3.0844e-41,  3.0844e-41,  3.0844e-41,  3.0844e-41],
        [-7.7269e+05, -7.7269e+05, -7.7269e+05, -7.7269e+05]])
tensor([[-7.6648e+06, -7.6648e+06, -7.6648e+06, -7.6648e+06],
        [ 3.0844e-41,  3.0844e-41,  3.0844e-41,  3.0844e-41],
        [-7.7269e+05, -7.7269e+05, -7.7269e+05, -7.7269e+05]])

實際上,這是numpy.array的一種特性,PyTorch中的Tensor繼承了Numpy中array的很多特性,因此你在PyTorch中也可以這麼用。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章