torch.max學習記錄:
首先定義數據:
import torch
a = torch.randn(2,3)
print("a:",a)
結果如下:
a: tensor([[-0.5658, -0.9736, -1.1753],
[ 1.2006, 0.4078, -2.0542]])
torch.
max
(input, dim)
按維度dim 返回最大值
torch.max(a,dim=0) 返回值爲一個元組,元組裏包含兩個值,第一個值爲一個每一列中最大元素,第二個值爲最大元素在這一列的行索引
返回的又是列又是行的,這句話怎麼理解呢?
可以理解成:dim=0,第0個維度表示行,可以想象是你的手,從上往下擠壓(對應dim=0,第一行,第二行...從上往下的一個行方向),直到壓扁的一個過程(這個要理解清楚)。在此過程中,只保存每一列的最大值,同時記錄下這個最大值是第幾行(即行索引)。
[ [-0.5658, -0.9736, -1.1753],
[ 1.2006, 0.4078, -2.0542 ] ]
以這個數據來說,Size[2,3] ,使用torch.max(a,dim=0),
結果爲:
(tensor([ 1.2006, 0.4078, -1.1753]), tensor([1, 1, 0]))
想象有隻手從上往下壓,把它壓扁,保留每一列最大值[ 1.2006, 0.4078, -1.1753] , 同時保留行索引。1.2006在第一行,0.4078在第一行,-1.1753在第0行,對應後面的 tensor([1, 1, 0]。
也可以通過索引,取出元組中的結果:
print("a.max(dim=0)[0]:",a.max(dim=0)[0])
print("a.max(dim=0)[1]:",a.max(dim=0)[1])
輸出:
a.max(dim=0)[0]: tensor([ 1.2006, 0.4078, -1.1753])
a.max(dim=0)[1]: tensor([1, 1, 0])
torch.max(a,diim=1) 返回每一行中最大值的那個元素,且返回其索引(返回最大元素在這一行的列索引)
同理,可以想象爲是有隻手從左往右擠壓(對應dim=1,第一列,第二列...從左往右的一個列方向),直到壓扁的一個過程
貼上代碼:
print("a.max(dim=1)[0]:",a.max(dim=1)[0])
print("a.max(dim=1)[1]:",a.max(dim=1)[1])
輸出結果:
a.max(dim=1)[0]: tensor([-0.5658, 1.2006])
a.max(dim=1)[1]: tensor([0, 0])
torch.max與numpy.max的用法類似。