解析tensor.expand()爲什麼不會分配新的內存而只是在存在的張量上創建新的視圖
關於tensor.expand()函數的介紹見前文pytorch中的expand()和expand_as()函數。
在faster_rcnn源碼的rpn模塊中反覆出現了view().expand().contiguous()的組合,expand()函數的功能是用來擴展張量中某維數據的尺寸,這並不難理解,但在expand()函數的解析中指出“擴展張量不會分配新的內存,只是在存在的張量上創建一個新的視圖”,這句話使我有一些不理解,理解這句話的含義首先需要了解張量的存儲機制。
1.張量的存儲機制:
參考資料:由淺入深地瞭解張量
張量可以理解爲一段內存的視圖,多個張量可以對相同的存儲進行索引,索引可以不同,但是底層內存完全相同。
張量通過使用形狀(shape)/步長(stride)和存儲偏移量來對相應的內存進行索引。
1)形狀:表示各個維度上的元素個數,如(3, 2, 2);
2)步長:當索引在每個維度增加1時,必須跳過的內存中元素個數;
3)存儲偏移量:存儲中對應於張量第一個元素的index。
舉例,如一段內存(1, 2, 3, 4, 5, 6),以shape=(2, 3),stride = (3, 1),index = 0來索引得到的張量爲:
[[1, 2, 3],
[4, 5, 6]]
而以shape=(3, 2),stride = (1, 3),index = 0來索引同一段內存,得到的張量爲:
[[1, 4],
[2, 5],
[3, 6]]
2.tensor.expand()爲什麼不會分配新的內存而只是在存在的張量上創建新的視圖?
瞭解了張量的存儲機制,已經可以理解了轉置矩陣或者view()函數這些不改變元素個數的操作可以通過在原內存上創建新的視圖實現,可是expand()擴展了張量的維度尺寸,也就是使張量的元素個數出現了增加,那麼是如何通過不分配新的內存只是創建新的視圖實現呢?
是通過將新的張量視圖中的步長設爲0實現的,如以下的例子可知,只要將stride設爲0,將反覆索引內存中的同一個元素,也就實現了expand()函數將張量的某個“1”維在不分配新內存情況下擴展爲任意數值的更多維。
import torch
a = torch.tensor([[1], [2], [3]])
print(a.size())
c = a.expand(3, 3)
print(a)
print(c)
# 輸出信息:
torch.Size([3, 1])
tensor([[1],
[2],
[3]])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])