pytorch中index_select() 用法案例與解析
index_select(input, dim, index)
功能:在指定的維度dim上選取數據,不如選取某些行,列
參數介紹
- 第一個參數
input
是要索引查找的對象 - 第二個參數
dim
是要查找的維度,因爲通常情況下我們使用的都是二維張量,所以可以簡單的記憶:0
代表行,1
代表列 - 第三個參數
index
是你要索引的序列,它是一個tensor對象
下面簡單的看幾個案例
首先簡單的創建一個矩陣
x = torch.rand(5,4)
print(x)
tensor([[0.6198, 0.4874, 0.2826, 0.1908],
[0.3884, 0.1720, 0.8688, 0.1023],
[0.3972, 0.6469, 0.4800, 0.9155],
[0.7255, 0.8646, 0.4741, 0.2681],
[0.6407, 0.3080, 0.5546, 0.7326]])
如果我們想要查看矩陣的第一列信息
最簡單的方法就是 直接用切片取值
print(x[:,1])
tensor([0.4874, 0.1720, 0.6469, 0.8646, 0.3080])
如果使用index_select()
方法則如下,三種得到的結果是一樣的
注意:
這裏的 dim
參數爲 1 代表列,
input
參數根據具體情況來寫:
如果是torch.index_select
那麼就像下面的第三條語句,需要寫上查找的對象x
print(x.index_select(1,torch.tensor([1]))) # 第 1 列
print(x.index_select(1,torch.tensor(1))) # 第 1 列
print(torch.index_select(x, 1,torch.tensor([1]))) # 第 1 列
tensor([[0.4874],
[0.1720],
[0.6469],
[0.8646],
[0.3080]])
tensor([[0.4874],
[0.1720],
[0.6469],
[0.8646],
[0.3080]])
tensor([[0.4874],
[0.1720],
[0.6469],
[0.8646],
[0.3080]])
再比如查找0,1列
print(x.index_select(0,torch.tensor([0,1]))) # 0,1行
print(x.index_select(1,torch.tensor([0,1]))) # 0,1列
tensor([[0.6198, 0.4874, 0.2826, 0.1908],
[0.3884, 0.1720, 0.8688, 0.1023]])
tensor([[0.6198, 0.4874],
[0.3884, 0.1720],
[0.3972, 0.6469],
[0.7255, 0.8646],
[0.6407, 0.3080]])
我們可以創建一個多維的數據來試試
首先創建一個三維的矩陣 dim爲 0, 1, 2
x = torch.linspace(1,24,24).view(2,3,4)
print(x)
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]],
[[19., 20.],
[21., 22.],
[23., 24.]]])
0維的 0,1,2
print(x.index_select(0,torch.tensor([0,1,2])))
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]],
[[19., 20.],
[21., 22.],
[23., 24.]]])
1維的0 1
print(x.index_select(1,torch.tensor([0,1])))
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]]])
tensor([[[ 1., 2.],
[ 3., 4.]],
[[ 7., 8.],
[ 9., 10.]],
[[13., 14.],
[15., 16.]],
[[19., 20.],
[21., 22.]]])
2維的 0 1
print(x.index_select(2,torch.tensor([0,1])))
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]],
[[19., 20.],
[21., 22.],
[23., 24.]]])