pytorch gather

y
Out[34]: tensor([0, 2])
y_hat
Out[35]: 
tensor([[0.1000, 0.3000, 0.6000],
        [0.3000, 0.2000, 0.5000]])

y_hat.gather(1, y.view(-1, 1))

聚合方向y_hat的維度1,聚合位置:

a[0][y.view(-1,1)[0]] = 0.1

a[1][y.view(-1,1)[1]] = 0.5

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章