Pytorch & Numpy 對照表

PyTorch 裏面處理的最基本的操作對象就是 Tensor(張量),表示的是一個多維的矩陣,比如零維就是一個點,一維就是向量,二維就是一般的矩陣,多維就相當於一個多維的數組,這和 numpy 是對應的,而且 PyTorch 的 tensor 可以和 numpy 的 ndarray 相互轉換,

唯一不同的是 PyTorch 可以在 GPU 上運行
而 numpy 的 ndarray 只能在 CPU 上運行

前言

我們先介紹一下一些常用的不同數據類型的Tensor( torch.Tensor默認的是 torch.FloatTensor數據類型):

  • 16位整型 torch.ShortTensor
  • 32位整型 torch.IntTensor
  • 64位整型 torch.LongTensor
  • 16位浮點型 torch.HalfTensor
  • 32位浮點型 torch.FloatTensor
  • 64位浮點型 torch.DoubleTensor
  • torch.Tensor 默認 torch.FloatTensor

類型(Types)

Numpy PyTorch
np.ndarray torch.Tensor
np.float32 torch.float32; torch.float
np.float64 torch.float64; torch.double
np.float16 torch.float16; torch.half
np.int8 torch.int8
np.uint8 torch.uint8
np.int16 torch.int16; torch.short
np.int32 torch.int32; torch.int
np.int64 torch.int64; torch.long

索引

Numpy PyTorch
x[0] x[0]
x[:, 0] x[:, 0]
x[indices] x[indices]
np.take(x, indices) torch.take(x, torch.LongTensor(indices))
x[x != 0] x[x != 0]

ones and zeros

Numpy PyTorch
np.empty((2, 3)) torch.empty(2, 3)
np.empty_like(x) torch.empty_like(x)
np.eye torch.eye
np.identity torch.eye
np.ones torch.ones
np.ones_like torch.ones_like
np.zeros torch.zeros
np.zeros_like torch.zeros_like

從已知數據構造

Numpy PyTorch
np.array([[1, 2], [3, 4]]) torch.tensor([[1, 2], [3, 4]])
np.array([3.2, 4.3], dtype=np.float16) , np.float16([3.2, 4.3]) torch.tensor([3.2, 4.3], dtype=torch.float16)
x.copy() x.clone()
np.fromfile(file) torch.tensor(torch.Storage(file))
np.frombuffer
np.fromfunction
np.fromiter
np.fromstring
np.load torch.load
np.loadtxt
np.concatenate torch.cat

數值範圍

Numpy PyTorch
np.arange(10) torch.arange(10)
np.arange(2, 3, 0.1) torch.arange(2, 3, 0.1)
np.linspace torch.linspace
np.logspace torch.logspace

構造矩陣

Numpy PyTorch
np.diag torch.diag
np.tril torch.tril
np.triu torch.triu

參數

Numpy PyTorch
x.shape x.shape
x.strides x.stride()
x.ndim x.dim()
x.data x.data
x.size x.nelement()
x.dtype x.dtype

形狀(Shape)變換

Numpy PyTorch
x.reshape x.reshape; x.view
x.resize() x.resize_
x.resize_as_
x.transpose x.transpose or x.permute
x.flatten x.view(-1)
x.squeeze() x.squeeze()
x[:, np.newaxis]; np.expand_dims(x, 1) x.unsqueeze(1)

數據選擇

Numpy PyTorch
np.put
x.put x.put_
x = np.array([1, 2, 3]) x.repeat(2) # [1, 1, 2, 2, 3, 3] x = torch.tensor([1, 2, 3]) x.repeat(2) # [1, 2, 3, 1, 2, 3] x.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1) # [1, 1, 2, 2, 3, 3]
np.tile(x, (3, 2)) x.repeat(3, 2)
np.choose
np.sort sorted, indices = torch.sort(x, [dim])
np.argsort sorted, indices = torch.sort(x, [dim])
np.nonzero torch.nonzero
np.where torch.where
x[::-1]

數值計算

Numpy PyTorch
x.min x.min
x.argmin x.argmin
x.max x.max
x.argmax x.argmax
x.clip x.clamp
x.round x.round
np.floor(x) torch.floor(x); x.floor()
np.ceil(x) torch.ceil(x); x.ceil()
x.trace x.trace
x.sum x.sum
x.cumsum x.cumsum
x.mean x.mean
x.std x.std
x.prod x.prod
x.cumprod x.cumprod
x.all (x == 1).sum() == x.nelement()
x.any (x == 1).sum() > 0

數值比較

Numpy PyTorch
np.less x.lt
np.less_equal x.le
np.greater x.gt
np.greater_equal x.ge
np.equal x.eq
np.not_equal x.ne
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章