torch.max總結

torch.max被廣泛應用在評價model的最終預測性能中,其實這個問題大家已經總結得挺詳細了,例如:

https://blog.csdn.net/liuweiyuxiang/article/details/84668269

https://www.cnblogs.com/Archer-Fang/p/10651029.html

但是正如前面一個博文裏網友評論的那樣,似乎拿行、列來區分不太妥當。當然我也沒有想到更好的辦法來總結,根據例子就很快可以掌握了:

1. torch.max(a)是返回a中的最大值:

a=torch.tensor([[-2.1456, -0.6380,  1.3625],
                [-1.0394, -0.9641, -0.3667]])

print(torch.max(a))將得到:

tensor(1.3625) 

(另外,怎麼把這個轉成int呢?加上.item()即可)

另外torch.max(a)==a.max()

2. torch.max(a,1)返回的是每一行的最大值,還有最大值所在的索引:

torch.return_types.max(
values=tensor([ 1.3625, -0.3667]),
indices=tensor([2, 2]))

當然我們很多情況下只關心index(例如計算accuracy的時候),那麼這時候用

torch.max(a,1)[1] 或者 a.max(1)[1] 取出來即可:

tensor([2, 2])

再加上.numpy()就可以轉成array:

print(a.max(1)[1].numpy())得到:

[2 2]

基本的使用方法就是這些,但是有一個問題,爲什麼

torch.max(a,1)是每一行的最大值而torch.max(a,0)是每一列的最大值呢?

例如上面這個例子,print(torch.max(a,0))的輸出是:

torch.return_types.max(
values=tensor([-1.0394, -0.6380,  1.3625]),
indices=tensor([1, 0, 0]))

實際上可以這樣理解:0指的是在dimension 0中,各個vector之間比較,取到vector每一維的最大值。1指的是dimension 1中,每個逗號之間的元素進行比較,例如在[-2.1456, -0.6380,  1.3625]幾個數中間進行比較,得到的最大值就是一個標量了,然後這些最大值拼接成一個vector。

所以可以思考一下這種情況(雖然遇到的很少,所以其實我們可以按照0,1分別對應行和列來理解):

a=torch.tensor([[[-0.2389, -0.8487, -1.5907,  0.0732],
                 [-0.2159,  1.1064, -1.1317,  0.6457],
                 [ 0.8191,  1.0146,  1.0241,  0.7042]],
                 [[-0.8285,  0.3628,  1.4678,  0.7984],
                  [ 0.1009, -0.3307, -0.8245,  0.0044],
                  [-1.5041,  0.5067,  0.4085,  0.2126]]])

在這個時候,print(torch.max(a,0))得到的結果是:

values=tensor([[-0.2389,  0.3628,  1.4678,  0.7984],
        [ 0.1009,  1.1064, -0.8245,  0.6457],
        [ 0.8191,  1.0146,  1.0241,  0.7042]]),
indices=tensor([[0, 1, 1, 1],
        [1, 0, 1, 0],
        [0, 0, 0, 0]]))

大家可以看看是不是這麼回事。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章