說到擴展維度,可能第一想法是調用sequeeze()
函數,但是實際上會有更簡單的方式:
x = torch.Tensor(3, 2)
print(x.size())
print(x[:, :, None].size())
print(x[:, None, :, None].size())
打印結果:
torch.Size([3, 2])
torch.Size([3, 2, 1])
torch.Size([3, 1, 2, 1])
- 很簡單,看代碼就能理解,不再贅述。
另外,在PyTorch裏,repeat
和expend
函數的區別在於,對於dim爲1的某個維度,後者不會申請更多的內存。這個機制有利於節省空間。
x = torch.Tensor(3, 1)
print(x.repeat(1, 4).size())
print(x.expand(3, 4).size())
print(x.repeat(1, 4))
print(x.expand(3, 4)) # 等價於 x.expand(-1,4)
print:
torch.Size([3, 4])
torch.Size([3, 4])
tensor([[-7.6648e+06, -7.6648e+06, -7.6648e+06, -7.6648e+06],
[ 3.0844e-41, 3.0844e-41, 3.0844e-41, 3.0844e-41],
[-7.7269e+05, -7.7269e+05, -7.7269e+05, -7.7269e+05]])
tensor([[-7.6648e+06, -7.6648e+06, -7.6648e+06, -7.6648e+06],
[ 3.0844e-41, 3.0844e-41, 3.0844e-41, 3.0844e-41],
[-7.7269e+05, -7.7269e+05, -7.7269e+05, -7.7269e+05]])
實際上,這是numpy.array的一種特性,PyTorch中的Tensor繼承了Numpy中array的很多特性,因此你在PyTorch中也可以這麼用。