Pytorch的gather()和scatter()

Pytorch的gather()和scatter()

1.gather()

gather是取的意思,意爲把某一tensor矩陣按照一個索引序列index取出,組成一個新的矩陣。

gather(input,dim,index)
參數:

  • input是要取值的矩陣
  • dim指操作的維度,0爲豎向操作即按行操作,1爲橫向操作即按列操作
  • index爲索引序列

下面這個例子是按行取出第一行的’0號元素’,'0行元素’組成新的第一行;
再取出第二行的‘1號元素’,‘0號元素’組成新的第二行

a = torch.Tensor([[1,2],[3,4]])
b = torch.gather(a, 1, torch.LongTensor([[0,0],[1,0]]))
print(a)
1 2 
3 4
print(b)
1 1
4 3

2.scatter_()

這個是‘放’的意思,即把原tensor矩陣的元素按照新索引index的序列位置,放到新的矩陣中。

scatter_(dim,index,src)
參數:

  • src 是要取出元素的矩陣

注意要放置的矩陣不在參數中,其直接調用這個函數。

下例就是按索引[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]位置,把隨機矩陣a中元素放置到全0矩陣torch.zeros(3,5)中。

a = torch.rand(2, 5)
print(a)
b = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), a)
print(b) 

其中dim=0
dim=0

3.參考:

https://zhuanlan.zhihu.com/p/59346637

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章