Pytorch中, torch.einsum

https://blog.csdn.net/a2806005024/article/details/96462827

3)Torch矩陣乘法。

print(a_tensor)
 
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])
 
print(b_tensor)
 
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])
 
# 'ik, kj -> ij'語義解釋如下:
# 輸入a_tensor: 2維數組,下標爲ik,
# 輸入b_tensor: 2維數組,下標爲kj,
# 輸出output:2維數組,下標爲ij。
# 隱含語義:輸入a,b下標中相同的k,是矩陣乘法的下標,對應上面的例子2的公式
output = torch.einsum('ik, kj -> ij', a_tensor, b_tensor)
 
print(output)
 
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章