Pytorch的gather()和scatter()
1.gather()
gather是取的意思,意爲把某一tensor矩陣按照一個索引序列index取出,組成一個新的矩陣。
gather(input,dim,index)
參數:
- input是要取值的矩陣
- dim指操作的維度,0爲豎向操作即按行操作,1爲橫向操作即按列操作
- index爲索引序列
下面這個例子是按行取出第一行的’0號元素’,'0行元素’組成新的第一行;
再取出第二行的‘1號元素’,‘0號元素’組成新的第二行
a = torch.Tensor([[1,2],[3,4]])
b = torch.gather(a, 1, torch.LongTensor([[0,0],[1,0]]))
print(a)
1 2
3 4
print(b)
1 1
4 3
2.scatter_()
這個是‘放’的意思,即把原tensor矩陣的元素按照新索引index的序列位置,放到新的矩陣中。
scatter_(dim,index,src)
參數:
- src 是要取出元素的矩陣
注意要放置的矩陣不在參數中,其直接調用這個函數。
下例就是按索引[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]位置,把隨機矩陣a中元素放置到全0矩陣torch.zeros(3,5)中。
a = torch.rand(2, 5)
print(a)
b = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), a)
print(b)
其中dim=0
3.參考:
https://zhuanlan.zhihu.com/p/59346637