pytorch 分割一個tensor並求平均

pytorch的torch.add()torch.split()函數

import torch
# outputs是一個[batch, seq, 40]維的tensor,把outputs分割成兩個[batch, seq, 20]的tensor,並每個元素求平均值
add = torch.add(*torch.split(outputs, 20, dim=2)) / 2
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章