pytorch中的expand()和expand_as()函數

pytorch中的expand()和expand_as()函數

1.expand()函數:

     (1)函數功能: 

              expand()函數的功能是用來擴展張量中某維數據的尺寸,它返回輸入張量在某維擴展爲更大尺寸後的張量。

              擴展張量不會分配新的內存,只是在存在的張量上創建一個新的視圖(關於張量的視圖可以參考博文:由淺入深地分析張量),而且原始tensor和處理後的tensor是不共享內存的。

              expand()函數括號中的輸入參數爲指定經過維度尺寸擴展後的張量的size。

     (2)應用舉例:

1)
import torch
a = torch.tensor([1, 2, 3])
c = a.expand(2, 3)
print(a)
print(c)

# 輸出信息:
tensor([1, 2, 3])
tensor([[1, 2, 3],
        [1, 2, 3]]



2)
import torch
a = torch.tensor([1, 2, 3])
c = a.expand(3, 3)
print(a)
print(c)

# 輸出信息:
tensor([1, 2, 3])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])


3)
import torch
a = torch.tensor([[1], [2], [3]])
print(a.size())
c = a.expand(3, 3)
print(a)
print(c)

# 輸出信息:
torch.Size([3, 1])
tensor([[1],
        [2],
        [3]])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])


4)
import torch
a = torch.tensor([[1], [2], [3]])
print(a.size())
c = a.expand(3, 4)
print(a)
print(c)

# 輸出信息:
torch.Size([3, 1])
tensor([[1],
        [2],
        [3]])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])

     (3)注意事項:

               expand()函數只能將size=1的維度擴展到更大的尺寸,如果擴展其他size()的維度會報錯。 

 

2.expand_as()函數:

     (1)函數功能:

              expand_as()函數與expand()函數類似,功能都是用來擴展張量中某維數據的尺寸,區別是它括號內的輸入參數是另一個張量,作用是將輸入tensor的維度擴展爲與指定tensor相同的size。

     (2)應用舉例:

1)
import torch
a = torch.tensor([[2], [3], [4]])
print(a)
b = torch.tensor([[2, 2], [3, 3], [5, 5]])
print(b.size())
c = a.expand_as(b)
print(c)
print(c.size())

# 輸出信息:
tensor([[2],
        [3],
        [4]])
torch.Size([3, 2])
tensor([[2, 2],
        [3, 3],
        [4, 4]])
torch.Size([3, 2])


2)
import torch
a = torch.tensor([1, 2, 3])
print(a)
b = torch.tensor([[2, 2, 2], [3, 3, 3]])
print(b.size())
c = a.expand_as(b)
print(c)
print(c.size())

# 輸出信息:
tensor([1, 2, 3])
torch.Size([2, 3])
tensor([[1, 2, 3],
        [1, 2, 3]])
torch.Size([2, 3])

 

3.參考資料:

      其他關於張量維度操作的函數參見博文:pytorch張量維度操作

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章