pytorch學習——torch.max

torch.max學習記錄:

首先定義數據:

import torch
a = torch.randn(2,3)
print("a:",a)

結果如下:

a: tensor([[-0.5658, -0.9736, -1.1753],
        [ 1.2006,  0.4078, -2.0542]])

 

torch.max(inputdim

按維度dim 返回最大值

torch.max(a,dim=0) 返回值爲一個元組,元組裏包含兩個,第一個值爲一個每一列中最大元素,第二個值爲最大元素在這一列的行索引

返回的又是列又是行的,這句話怎麼理解呢?

可以理解成:dim=0,第0個維度表示行,可以想象是你的手,從上往下擠壓(對應dim=0,第一行,第二行...從上往下的一個行方向),直到壓扁的一個過程(這個要理解清楚)。在此過程中,只保存每一列的最大值,同時記錄下這個最大值是第幾行(即行索引)。

       [ [-0.5658, -0.9736, -1.1753],
        [ 1.2006,  0.4078, -2.0542 ] ]

以這個數據來說,Size[2,3] ,使用torch.max(a,dim=0), 

結果爲:

(tensor([ 1.2006,  0.4078, -1.1753]), tensor([1, 1, 0]))

想象有隻手從上往下壓,把它壓扁,保留每一列最大值[ 1.2006, 0.4078, -1.1753] , 同時保留行索引。1.2006在第一行,0.4078在第一行,-1.1753在第0行,對應後面的 tensor([1, 1, 0]。

也可以通過索引,取出元組中的結果:

print("a.max(dim=0)[0]:",a.max(dim=0)[0])
print("a.max(dim=0)[1]:",a.max(dim=0)[1])

輸出:

a.max(dim=0)[0]: tensor([ 1.2006,  0.4078, -1.1753])
a.max(dim=0)[1]: tensor([1, 1, 0])

 

torch.max(a,diim=1) 返回每一行中最大值的那個元素,且返回其索引(返回最大元素在這一行的列索引

同理,可以想象爲是有隻手從左往右擠壓(對應dim=1,第一列,第二列...從左往右的一個列方向),直到壓扁的一個過程

貼上代碼:

print("a.max(dim=1)[0]:",a.max(dim=1)[0])
print("a.max(dim=1)[1]:",a.max(dim=1)[1])

輸出結果:

a.max(dim=1)[0]: tensor([-0.5658,  1.2006])
a.max(dim=1)[1]: tensor([0, 0])

torch.max與numpy.max的用法類似。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章