torch.stack()函數

torch.stack()函數:

torch.stack(sequence, dim=0)

1.函數功能: 

           沿一個新維度對輸入張量序列進行連接,序列中所有張量應爲相同形狀;stack 函數返回的結果會新增一個維度,而stack()函數指定的dim參數,就是新增維度的(下標)位置。

2.參數列表:

           sequence:參與創建新張量的幾個張量;

           dim:新增維度的(下標)位置,當dim = -1時默認最後一個維度;

           返回值:輸出張量。

# 輸入張量信息:
# a=[i][j]
# b=[i][j]

c = stack((a,b), dim=0)

# 輸出張量信息:
# c[0][i][j] = a[i][j]
# c[1][i][j] = b[i][j]

3.應用舉例:

     1)

import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 0)
print(a)
print(b)
print(c)

# 輸出信息:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]])

     2)

import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 1)
print(a)
print(b)
print(c)

# 輸出信息:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1,  2,  3],
         [11, 22, 33]],

        [[ 4,  5,  6],
         [44, 55, 66]],

        [[ 7,  8,  9],
         [77, 88, 99]]])

     3)

import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 2)
print(a)
print(b)
print(c)

# 輸出信息:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1, 11],
         [ 2, 22],
         [ 3, 33]],

        [[ 4, 44],
         [ 5, 55],
         [ 6, 66]],

        [[ 7, 77],
         [ 8, 88],
         [ 9, 99]]])

     4)

import torch
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[11, 22, 33], [44, 55, 66], [77, 88, 99]]])
c = torch.stack([a, b], 3)
print(a)
print(b)
print(c)

# 輸出信息:
tensor([[[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]],

        [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]])
tensor([[[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]],

        [[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]])
tensor([[[[ 1, 11],
          [ 2, 22],
          [ 3, 33]],

         [[ 4, 44],
          [ 5, 55],
          [ 6, 66]],

         [[ 7, 77],
          [ 8, 88],
          [ 9, 99]]],


        [[[ 1, 11],
          [ 2, 22],
          [ 3, 33]],

         [[ 4, 44],
          [ 5, 55],
          [ 6, 66]],

         [[ 7, 77],
          [ 8, 88],
          [ 9, 99]]]])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章