https://blog.csdn.net/a2806005024/article/details/96462827
3)Torch矩陣乘法。
print(a_tensor)
tensor([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]])
print(b_tensor)
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]])
# 'ik, kj -> ij'語義解釋如下:
# 輸入a_tensor: 2維數組,下標爲ik,
# 輸入b_tensor: 2維數組,下標爲kj,
# 輸出output:2維數組,下標爲ij。
# 隱含語義:輸入a,b下標中相同的k,是矩陣乘法的下標,對應上面的例子2的公式
output = torch.einsum('ik, kj -> ij', a_tensor, b_tensor)
print(output)
tensor([[130, 130, 130, 130],
[230, 230, 230, 230],
[330, 330, 330, 330],
[430, 430, 430, 430]])