torch.stack()函數:
torch.stack(sequence, dim=0)
1.函數功能:
沿一個新維度對輸入張量序列進行連接,序列中所有張量應爲相同形狀;stack 函數返回的結果會新增一個維度,而stack()函數指定的dim參數,就是新增維度的(下標)位置。
2.參數列表:
sequence:參與創建新張量的幾個張量;
dim:新增維度的(下標)位置,當dim = -1時默認最後一個維度;
返回值:輸出張量。
# 輸入張量信息:
# a=[i][j]
# b=[i][j]
c = stack((a,b), dim=0)
# 輸出張量信息:
# c[0][i][j] = a[i][j]
# c[1][i][j] = b[i][j]
3.應用舉例:
1)
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 0)
print(a)
print(b)
print(c)
# 輸出信息:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[11, 22, 33],
[44, 55, 66],
[77, 88, 99]]])
2)
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 1)
print(a)
print(b)
print(c)
# 輸出信息:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 2, 3],
[11, 22, 33]],
[[ 4, 5, 6],
[44, 55, 66]],
[[ 7, 8, 9],
[77, 88, 99]]])
3)
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 2)
print(a)
print(b)
print(c)
# 輸出信息:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 11],
[ 2, 22],
[ 3, 33]],
[[ 4, 44],
[ 5, 55],
[ 6, 66]],
[[ 7, 77],
[ 8, 88],
[ 9, 99]]])
4)
import torch
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[11, 22, 33], [44, 55, 66], [77, 88, 99]]])
c = torch.stack([a, b], 3)
print(a)
print(b)
print(c)
# 輸出信息:
tensor([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]])
tensor([[[11, 22, 33],
[44, 55, 66],
[77, 88, 99]],
[[11, 22, 33],
[44, 55, 66],
[77, 88, 99]]])
tensor([[[[ 1, 11],
[ 2, 22],
[ 3, 33]],
[[ 4, 44],
[ 5, 55],
[ 6, 66]],
[[ 7, 77],
[ 8, 88],
[ 9, 99]]],
[[[ 1, 11],
[ 2, 22],
[ 3, 33]],
[[ 4, 44],
[ 5, 55],
[ 6, 66]],
[[ 7, 77],
[ 8, 88],
[ 9, 99]]]])