Numpy的一個函數:np.lib.stride_tricks.as_strided()
這個函數可以高效地切分一個數組爲不同的shape塊。可以參考官方文檔或者這篇寫得很好的博客。
描述:
numpy.lib.stride_tricks.as_strided(x, shape=None, strides=None, subok=False, writeable=True)
其中:x爲輸入數組;
shape爲目標形狀;
strides 需要一個例子來理解:
a = np.arange(9, dtype=np.int32).reshape(3,3)
print(a)
'''
[[0 1 2]
[3 4 5]
[6 7 8]]
'''
print(a.strides)
'''
(12, 4)
'''
這裏(12, 4)中的12表示在內存中a[n, 0]到a[n+1, 0]跨過多少byte,4表示在內存中a[n, 0]到a[n, 1]跨過多少byte。不同數據類型佔據的內存大小不同,總的來說一個32位的類型需要4byte,62位的類型需要8byte。 可以參考這裏,我整理成下面的表格。
數值 \ 數值類型 (單位:Byte) | int32 | int64 | float32 | float64 |
---|---|---|---|---|
[] | 96 | 96 | 96 | 96 |
[1] | 100 | 104 | 100 | 104 |
[5] | 116 | 136 | 116 | 136 |
這個表格的意思就是,一個空的數組[]就佔用內存96byte,因爲它存放了關於數組大小等信息。
只有一個32位整數的數組[1]佔用內存100byte,減去[]佔用的96byte,1就佔用4byte。
同理對於64整數的數組[1],一個元素佔用8byte。
上面的例子中,使用nbytes和itemsize可以直接查看元素佔據的內存。
# 查看數組a所有元素佔用內存大小,單位byte
a.nbytes
"""
36
"""
# 查看數組a每個元素佔用內存大小,單位byte
a.itemsize
"""
4
"""
如果我想將其切分爲一個(3, 3, 3)矩陣,我需要先padding其爲
[[0 0 0]
[0 1 2]
[3 4 5]
[6 7 8]
[0 0 0]]
結果如下:
[[[0 0 0]
[0 1 2]
[3 4 5]]
[[0 1 2]
[3 4 5]
[6 7 8]]
[[3 4 5]
[6 7 8]
[0 0 0]]]
需要設定的strides爲(36, 12, 4). 第三維最小的維度考慮最小單元內兩個相鄰元素的跨度,0->1, 3->4等都是4byte;第二維度考慮兩行之間的跨度,如[0 1 2]->[3 4 5]中首元素的跨度爲 byte。這裏的最高維跨度和第二維跨度一樣,如[[0 0 0] [0 1 2] [3 4 5]]-> [[0 1 2] [3 4 5] [6 7 8]]首元素跨度在原數組中只跨越了[0 0 0]三個元素,所以跨度也是 byte。