文章目錄
不知道大家有沒有類似的問題,處理數據的時候很多時候會被各種數組的 shape 的變化搞暈,但是這方面的資料又不太好找,這裏記錄一點我遇到的這方面的知識點。
1.數組中數據存儲的結構
首先從數據的排列來看,從一個簡單的例子來看:
a = np.linspace(1, 24, 24).reshape(2, 3, 4)
print(a)
# 輸出結果:
[[[ 1. 2. 3. 4.]
[ 5. 6. 7. 8.]
[ 9. 10. 11. 12.]]
[[13. 14. 15. 16.]
[17. 18. 19. 20.]
[21. 22. 23. 24.]]]
首先 1 - 24 的24個數字,排列成 shape 爲 (2, 3, 4) 的數組,可以看到這種數據排列的方式顯然是按照從最低維(這裏是一個三維數組,0,1,2)2開始排列數據,這樣至少確定了在進行reshape的時候數據是如何填充的,我的猜想是reshape可以看作是兩步:
- 1.首先將數組整個拉平(既然填充的時候是現在高維度上進行填充,那麼拉平的過程就是反向的了)
- 然後按照新的 shape 按照上面的方式進行填充
2.數組的座標問題
舉個例子來看非常明顯:
a = np.linspace(1, 24, 24).reshape(2, 3, 4)
print(a[1][0][0])
print(a[0][1][0])
print(a[0][0][1])
# 可以看到輸出分別爲:
13.0
5.0
2.0
也就是最低維是行(向下增加),第二維是列,第三維是深。
3.對於Pytorch 的shape相關問題
Numpy 中的數組 array 也就是對應 Pytorch 的 tensor ,那麼上面的方式應該在 pytorch 中也是適用的,測試:
b = torch.from_numpy(a)
print(b)
print(b[1][0][0])
print(b[0][1][0])
print(b[0][0][1])
# 輸出爲:
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]], dtype=torch.float64)
tensor(13., dtype=torch.float64)
tensor(5., dtype=torch.float64)
tensor(2., dtype=torch.float64)
測試一下 Pytorch 的 reshape 的函數,Pytorch 中經常使用 view()函數來進行改變數組的 size:
c = b.view(4, 3, 2)
print(c)
# 輸出結果:
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]],
[[19., 20.],
[21., 22.],
[23., 24.]]], dtype=torch.float64)
確實與 Numpy 的 reshape 方式一致
4. Pytorch 中幾個常見的有關維度的函數
4.1 squeeze() 和 unsqueeze()
這是一對操作相反的函數,分別用於 降維/升維 操作,例如:
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(a.size())
print(a)
b = a.unsqueeze(0)
print(b.size())
print(b)
# 輸出結果爲:
torch.Size([2, 3])
tensor([[1, 2, 3],
[4, 5, 6]])
torch.Size([1, 2, 3])
tensor([[[1, 2, 3],
[4, 5, 6]]])
兩個函數都需要傳入 dim
參數,指定在哪一個維度上進行 壓縮 或者 升維,上面的例子中可以看到在第0維上增加了一個維度。當然這對數據本身並沒有什麼改變,但是 Pytorch 中對輸入的數據格式都做了要求,往往要求是一個四元組,所以很多時候都要對原始數據進行一些改變,所以就用得到這個函數了。
與之對應:
c = b.squeeze(0)
print(c.size())
print(c)
# 輸出:
torch.Size([2, 3])
tensor([[1, 2, 3],
[4, 5, 6]])
不過因爲是壓縮,所以只有在維度爲1的時候才生效。
4.2 permute() 函數
從函數名就可以看出這個函數用於重新排列,這與 reshape 之類的函數區別就在於這是是直接進行重新排列,例如:
對於 [[1, 2],
[3, 4],
[5, 6]]
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = a.permute(1, 0)
print(b)
# 輸出:
tensor([[1, 3, 5],
[2, 4, 6]])
可以看到這是相當於進行了一個轉置,也就是維度上的重新排列,如果得到 (2, 3) 的數組,使用view()函數則是:
c = a.view(2, 3)
print(c)
# 輸出:
tensor([[1, 2, 3],
[4, 5, 6]])
可以很明顯地看到區別。
對於二維的數據,這個函數就是簡單地轉置,但是對於三維數據,有點抽象,舉個例子來說:
a = torch.arange(1, 25)
a = a.view(2, 3, 4)
print(a)
b = a.permute(2, 1, 0)
print(b)
# 輸出結果:
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]])
tensor([[[ 1, 13],
[ 5, 17],
[ 9, 21]],
[[ 2, 14],
[ 6, 18],
[10, 22]],
[[ 3, 15],
[ 7, 19],
[11, 23]],
[[ 4, 16],
[ 8, 20],
[12, 24]]])
這裏相當於將所有的維度進行了相反的排列,看數據的排列就會發現這就像是矩陣的立體結構進行了旋轉,但是內部的數據排列又要改變位置,總之數據要符合原來的排列順序,這個也稱它爲高維矩陣的轉置,也就是熟知的二維數據轉置的推廣。(這已經超越我語言的極限了…╮(╯_╰)╭),基本就是這個意思。
更詳細的可以參考:Pytorch之permute函數