關於 Numpy 以及 Pytorch 的數組shape的一點總結


不知道大家有沒有類似的問題,處理數據的時候很多時候會被各種數組的 shape 的變化搞暈,但是這方面的資料又不太好找,這裏記錄一點我遇到的這方面的知識點。

1.數組中數據存儲的結構

首先從數據的排列來看,從一個簡單的例子來看:

a = np.linspace(1, 24, 24).reshape(2, 3, 4)
print(a)

# 輸出結果:
[[[ 1.  2.  3.  4.]
  [ 5.  6.  7.  8.]
  [ 9. 10. 11. 12.]]

 [[13. 14. 15. 16.]
  [17. 18. 19. 20.]
  [21. 22. 23. 24.]]]

首先 1 - 24 的24個數字,排列成 shape 爲 (2, 3, 4) 的數組,可以看到這種數據排列的方式顯然是按照從最低維(這裏是一個三維數組,0,1,2)2開始排列數據,這樣至少確定了在進行reshape的時候數據是如何填充的,我的猜想是reshape可以看作是兩步:

  • 1.首先將數組整個拉平(既然填充的時候是現在高維度上進行填充,那麼拉平的過程就是反向的了)
  • 然後按照新的 shape 按照上面的方式進行填充

2.數組的座標問題

舉個例子來看非常明顯:

a = np.linspace(1, 24, 24).reshape(2, 3, 4)

print(a[1][0][0])
print(a[0][1][0])
print(a[0][0][1])

# 可以看到輸出分別爲:
13.0
5.0
2.0

也就是最低維是行(向下增加),第二維是列,第三維是深。

3.對於Pytorch 的shape相關問題

Numpy 中的數組 array 也就是對應 Pytorch 的 tensor ,那麼上面的方式應該在 pytorch 中也是適用的,測試:

b = torch.from_numpy(a)
print(b)
print(b[1][0][0])
print(b[0][1][0])
print(b[0][0][1])

# 輸出爲:
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]],

        [[13., 14., 15., 16.],
         [17., 18., 19., 20.],
         [21., 22., 23., 24.]]], dtype=torch.float64)
tensor(13., dtype=torch.float64)
tensor(5., dtype=torch.float64)
tensor(2., dtype=torch.float64)

測試一下 Pytorch 的 reshape 的函數,Pytorch 中經常使用 view()函數來進行改變數組的 size:

c = b.view(4, 3, 2)
print(c)

# 輸出結果:
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]],

        [[19., 20.],
         [21., 22.],
         [23., 24.]]], dtype=torch.float64)

確實與 Numpy 的 reshape 方式一致

4. Pytorch 中幾個常見的有關維度的函數

4.1 squeeze() 和 unsqueeze()

這是一對操作相反的函數,分別用於 降維/升維 操作,例如:

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(a.size())
print(a)

b = a.unsqueeze(0)
print(b.size())
print(b)


# 輸出結果爲:
torch.Size([2, 3])
tensor([[1, 2, 3],
        [4, 5, 6]])
        
torch.Size([1, 2, 3])
tensor([[[1, 2, 3],
         [4, 5, 6]]])

兩個函數都需要傳入 dim 參數,指定在哪一個維度上進行 壓縮 或者 升維,上面的例子中可以看到在第0維上增加了一個維度。當然這對數據本身並沒有什麼改變,但是 Pytorch 中對輸入的數據格式都做了要求,往往要求是一個四元組,所以很多時候都要對原始數據進行一些改變,所以就用得到這個函數了。

與之對應:

c = b.squeeze(0)
print(c.size())
print(c)

# 輸出:
torch.Size([2, 3])
tensor([[1, 2, 3],
        [4, 5, 6]])

不過因爲是壓縮,所以只有在維度爲1的時候才生效。

4.2 permute() 函數

從函數名就可以看出這個函數用於重新排列,這與 reshape 之類的函數區別就在於這是是直接進行重新排列,例如:

對於 [[1, 2],
      [3, 4],
      [5, 6]]
      
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = a.permute(1, 0)

print(b)
# 輸出:
tensor([[1, 3, 5],
        [2, 4, 6]])

可以看到這是相當於進行了一個轉置,也就是維度上的重新排列,如果得到 (2, 3) 的數組,使用view()函數則是:

c = a.view(2, 3)
print(c)

# 輸出:
tensor([[1, 2, 3],
        [4, 5, 6]])

可以很明顯地看到區別。

對於二維的數據,這個函數就是簡單地轉置,但是對於三維數據,有點抽象,舉個例子來說:

a = torch.arange(1, 25)
a = a.view(2, 3, 4)

print(a)
b = a.permute(2, 1, 0)

print(b)

# 輸出結果:

tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])
tensor([[[ 1, 13],
         [ 5, 17],
         [ 9, 21]],

        [[ 2, 14],
         [ 6, 18],
         [10, 22]],

        [[ 3, 15],
         [ 7, 19],
         [11, 23]],

        [[ 4, 16],
         [ 8, 20],
         [12, 24]]])

這裏相當於將所有的維度進行了相反的排列,看數據的排列就會發現這就像是矩陣的立體結構進行了旋轉,但是內部的數據排列又要改變位置,總之數據要符合原來的排列順序,這個也稱它爲高維矩陣的轉置,也就是熟知的二維數據轉置的推廣。(這已經超越我語言的極限了…╮(╯_╰)╭),基本就是這個意思。

更詳細的可以參考:Pytorch之permute函數

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章