pytorch中index_select() 用法案例與解析

pytorch中index_select() 用法案例與解析

index_select(input, dim, index)
功能:在指定的維度dim上選取數據,不如選取某些行,列

參數介紹

  • 第一個參數input是要索引查找的對象
  • 第二個參數dim是要查找的維度,因爲通常情況下我們使用的都是二維張量,所以可以簡單的記憶: 0代表,1代表
  • 第三個參數index是你要索引的序列,它是一個tensor對象

下面簡單的看幾個案例
首先簡單的創建一個矩陣

x = torch.rand(5,4)
print(x)

tensor([[0.6198, 0.4874, 0.2826, 0.1908],
        [0.3884, 0.1720, 0.8688, 0.1023],
        [0.3972, 0.6469, 0.4800, 0.9155],
        [0.7255, 0.8646, 0.4741, 0.2681],
        [0.6407, 0.3080, 0.5546, 0.7326]])

如果我們想要查看矩陣的第一列信息
最簡單的方法就是 直接用切片取值

print(x[:,1]) 
tensor([0.4874, 0.1720, 0.6469, 0.8646, 0.3080])

如果使用index_select()方法則如下,三種得到的結果是一樣的
注意:
這裏的 dim 參數爲 1 代表列,
input參數根據具體情況來寫:
如果是torch.index_select那麼就像下面的第三條語句,需要寫上查找的對象x

print(x.index_select(1,torch.tensor([1]))) # 第 1 列
print(x.index_select(1,torch.tensor(1))) # 第 1 列
print(torch.index_select(x, 1,torch.tensor([1]))) # 第 1 列
tensor([[0.4874],
        [0.1720],
        [0.6469],
        [0.8646],
        [0.3080]])
tensor([[0.4874],
        [0.1720],
        [0.6469],
        [0.8646],
        [0.3080]])
tensor([[0.4874],
        [0.1720],
        [0.6469],
        [0.8646],
        [0.3080]])

再比如查找0,1列

print(x.index_select(0,torch.tensor([0,1]))) # 0,1行
print(x.index_select(1,torch.tensor([0,1]))) # 0,1列
tensor([[0.6198, 0.4874, 0.2826, 0.1908],
        [0.3884, 0.1720, 0.8688, 0.1023]])
tensor([[0.6198, 0.4874],
        [0.3884, 0.1720],
        [0.3972, 0.6469],
        [0.7255, 0.8646],
        [0.6407, 0.3080]])

我們可以創建一個多維的數據來試試
首先創建一個三維的矩陣 dim爲 0, 1, 2

x = torch.linspace(1,24,24).view(2,3,4)
print(x)
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]],

        [[19., 20.],
         [21., 22.],
         [23., 24.]]])

0維的 0,1,2

print(x.index_select(0,torch.tensor([0,1,2])))
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]],

        [[19., 20.],
         [21., 22.],
         [23., 24.]]])

1維的0 1

print(x.index_select(1,torch.tensor([0,1])))
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]]])
tensor([[[ 1.,  2.],
         [ 3.,  4.]],

        [[ 7.,  8.],
         [ 9., 10.]],

        [[13., 14.],
         [15., 16.]],

        [[19., 20.],
         [21., 22.]]])

2維的 0 1

print(x.index_select(2,torch.tensor([0,1])))
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]],

        [[19., 20.],
         [21., 22.],
         [23., 24.]]])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章