pytorch 中涉及到矩陣之間的乘法(torch.mul, *, torch.mm, torch.matmul, @)

最近在學習pytorch,過程中遇到一些問題,這裏權當筆記記錄下來,同時也供大家參考。

下面簡單回顧一下矩陣中的乘法:(嚴謹的說,其實應該說是矩陣乘法和矩陣內積)
1、矩陣乘法
  矩陣乘法也就是我們常說的矩陣向量積(也稱矩陣外積矩陣叉乘
      它要求前一個矩陣的行數等於後一個矩陣的列數,其計算方法是計算結果的每一行元素爲前一個矩陣的每一行元素與後一個矩陣的每一列對應元素相乘,之後求和。下面232*3矩陣與353*5矩陣爲例:
[111111]×[123451234512345]=[36912153691215] \begin{gathered} \begin{bmatrix} 1 & 1 & 1\\ 1 & 1 & 1 \end{bmatrix} \times \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 1 & 2 & 3 & 4 & 5 \\ 1 & 2 & 3 & 4 & 5 \end{bmatrix} \end{gathered}=\begin{bmatrix} 3 & 6 & 9 & 12 & 15 \\ 3 & 6 & 9 & 12 & 15 \end{bmatrix}
其計算方法爲:
11+11+11=a11=3,   12+12+12=a12=61*1+1*1+1*1=a11=3, \, \,\,1*2+1*2+1*2=a12=6……
其中a11爲第一行第一個元素,以此類推
2、矩陣內積
  矩陣點法也就是我們常說的矩陣點乘
       即矩陣的對應元素相乘,故它要求兩個矩陣形狀一樣,下面232*3矩陣與232*3矩陣爲例:
[111111].[123456]=[123456] \begin{gathered} \begin{bmatrix} 1 & 1 & 1\\ 1 & 1 & 1 \end{bmatrix} . \begin{bmatrix}1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \end{gathered}=\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}

  在進入正題之前,先扯點兒閒篇——大家應該都知道numpy(至少聽說過,python的一個數值計算庫,pytorch不火的時候,numpy還是很好用的),而pytorch,主要特點是可以使用GPU加速運算,但是計算上和numpy有很多類似之處,那好,介紹pytorch的矩陣乘法之前,先說說numpy中ndarray中矩陣的乘法:
  numpyt中點乘使用*或者np.multiply(),而叉乘使用@, np.dot(), np.matmul()
測試測序如下:

import numpy as np

print("numpy")
A = np.array([[1, 2, 3, 6], [2, 3, 4, 3], [2, 3, 4, 4]])
B = np.array([[1, 0, 1, 4], [2, 1, -1, 0], [2, 1, 5, 0]])
C = np.array([[1, 0, 3], [0, 1, 2], [-1, 0, 1], [-1, 0, 1]])

# 對應位置相乘,點乘
print("矩陣對應元素相乘 點乘")
print("*運算符\n", A*B)
print("np.multiply\n", np.multiply(A, B))

print("矩陣相乘 叉乘")
print("A.dot\n", A.dot(C))  # 矩陣乘法
print("@運算符\n", A@C)
print("np.matmul\n", np.matmul(A, C), '\n')

有人會問這裏的dot和matmul函數有什麼區別

請移步numpy中dot和matmul的區別

  而pytorch中用法略有不同,其中點乘使用*或者np.mul(),而叉乘使用@, torch.mm(), torch.matmul()(注意這裏沒有dot函數,使用torch.mm函數)

import torch
print("pytorch")
a = torch.ones(2, 3)
c = torch.FloatTensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
#b = torch.randint(1, 9, (2, 3))
b = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])

print("矩陣對應元素相乘 點乘")
print("*運算符\n", a * b)
print("torch.mul\n", torch.mul(a, b))

print("矩陣相乘 叉乘")
print("@運算符\n", a@c)
print("torch.mm\n", torch.mm(a, c))
print("torch.matmul\n", torch.matmul(a, c), "\n")

輸出結果:
在這裏插入圖片描述
下面又有一個問題,torch.mm()和torch.matmul()到底有什麼區別?
可以參考官網教程
如果你懶得看,你可以看下面這兩張我從官網上截的圖
torch.mm函數用法
在這裏插入圖片描述
當然了,如果還是難以理解的話,請移步這裏

參考:
[1]https://blog.csdn.net/She_Said/article/details/98034841
[2]https://www.jb51.net/article/177406.htm
[3]https://pytorch.org/docs/stable/torch.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章