torch.max被廣泛應用在評價model的最終預測性能中,其實這個問題大家已經總結得挺詳細了,例如:
https://blog.csdn.net/liuweiyuxiang/article/details/84668269
https://www.cnblogs.com/Archer-Fang/p/10651029.html
但是正如前面一個博文裏網友評論的那樣,似乎拿行、列來區分不太妥當。當然我也沒有想到更好的辦法來總結,根據例子就很快可以掌握了:
1. torch.max(a)是返回a中的最大值:
a=torch.tensor([[-2.1456, -0.6380, 1.3625],
[-1.0394, -0.9641, -0.3667]])
print(torch.max(a))將得到:
tensor(1.3625)
(另外,怎麼把這個轉成int呢?加上.item()即可)
另外torch.max(a)==a.max()
2. torch.max(a,1)返回的是每一行的最大值,還有最大值所在的索引:
torch.return_types.max(
values=tensor([ 1.3625, -0.3667]),
indices=tensor([2, 2]))
當然我們很多情況下只關心index(例如計算accuracy的時候),那麼這時候用
torch.max(a,1)[1] 或者 a.max(1)[1] 取出來即可:
tensor([2, 2])
再加上.numpy()就可以轉成array:
print(a.max(1)[1].numpy())得到:
[2 2]
基本的使用方法就是這些,但是有一個問題,爲什麼
torch.max(a,1)是每一行的最大值而torch.max(a,0)是每一列的最大值呢?
例如上面這個例子,print(torch.max(a,0))的輸出是:
torch.return_types.max(
values=tensor([-1.0394, -0.6380, 1.3625]),
indices=tensor([1, 0, 0]))
實際上可以這樣理解:0指的是在dimension 0中,各個vector之間比較,取到vector每一維的最大值。1指的是dimension 1中,每個逗號之間的元素進行比較,例如在[-2.1456, -0.6380, 1.3625]幾個數中間進行比較,得到的最大值就是一個標量了,然後這些最大值拼接成一個vector。
所以可以思考一下這種情況(雖然遇到的很少,所以其實我們可以按照0,1分別對應行和列來理解):
a=torch.tensor([[[-0.2389, -0.8487, -1.5907, 0.0732],
[-0.2159, 1.1064, -1.1317, 0.6457],
[ 0.8191, 1.0146, 1.0241, 0.7042]],
[[-0.8285, 0.3628, 1.4678, 0.7984],
[ 0.1009, -0.3307, -0.8245, 0.0044],
[-1.5041, 0.5067, 0.4085, 0.2126]]])
在這個時候,print(torch.max(a,0))得到的結果是:
values=tensor([[-0.2389, 0.3628, 1.4678, 0.7984],
[ 0.1009, 1.1064, -0.8245, 0.6457],
[ 0.8191, 1.0146, 1.0241, 0.7042]]),
indices=tensor([[0, 1, 1, 1],
[1, 0, 1, 0],
[0, 0, 0, 0]]))
大家可以看看是不是這麼回事。