【PyTorch】Splitting and Concatenating Tensors (split, chunk, cat, stack)

Overview

In PyTorch, splitting a tensor is usually done with one of two functions:

  • torch.split [splits a tensor by chunk size]
  • torch.chunk [splits a tensor by number of chunks]

Concatenating tensors is usually done with another two functions:

  • torch.cat [concatenates tensors along an existing dimension]
  • torch.stack [concatenates tensors along a new dimension]

Their purposes are similar, but their actual behavior is not identical; the official documentation and example code below illustrate the differences.

Splitting Tensors

The torch.split Function

torch.split(tensor, split_size_or_sections, dim=0)

Splits a tensor by chunk size.
tensor is the tensor to be split.
dim specifies the dimension along which the split is performed.
split_size_or_sections is either the size of every chunk along dim (int) or a list of the individual chunk sizes (list).
When a single chunk size is given and the dimension is not evenly divisible by it, the last chunk takes the remainder and is smaller.
For example, a tensor of length 10 split with chunk size 3 yields three chunks of length 3 and a final chunk of length 1.
Returns: a tuple of all the resulting tensors.
The function does not modify the original tensor.

torch.split Official Documentation

Splits the tensor into chunks. Each chunk is a view of the original tensor.

If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.

If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.

Parameters:

  • tensor (Tensor) – tensor to split
  • split_size_or_sections (int) or (list(int)) – size of a single chunk or list of sizes for each chunk
  • dim (int) – dimension along which to split the tensor

Example code:

In [1]: X = torch.randn(6, 2)

In [2]: X
Out[2]:
tensor([[-0.3711,  0.7372],
        [ 0.2608, -0.1129],
        [-0.2785,  0.1560],
        [-0.7589, -0.8927],
        [ 0.1480, -0.0371],
        [-0.8387,  0.6233]])

In [3]: torch.split(X, 2, dim = 0)
Out[3]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129]]),
 tensor([[-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [4]: torch.split(X, 3, dim = 0)
Out[4]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129],
         [-0.2785,  0.1560]]),
 tensor([[-0.7589, -0.8927],
         [ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [5]: torch.split(X, 4, dim = 0)
Out[5]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129],
         [-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [6]: torch.split(X, 1, dim = 1)
Out[6]:
(tensor([[-0.3711],
         [ 0.2608],
         [-0.2785],
         [-0.7589],
         [ 0.1480],
         [-0.8387]]),
 tensor([[ 0.7372],
         [-0.1129],
         [ 0.1560],
         [-0.8927],
         [-0.0371],
         [ 0.6233]]))

The torch.chunk Function

torch.chunk(input, chunks, dim=0)

Splits a tensor by number of chunks.
input is the tensor to be split.
dim specifies the dimension along which the split is performed.
chunks is the total number of chunks to produce along dim (int); the tensor is divided as evenly as possible.
If the size along dim is not evenly divisible by chunks, each chunk's size is rounded up to the ceiling and the last chunk takes the remainder, so it is smaller; when that remainder happens to be 0, only chunks - 1 chunks are returned.
For example:

  • A tensor of length 6 split into 4 chunks yields only three chunks, each of length 2 (6 / 4 = 1.5 → 2)
  • A tensor of length 10 split into 4 chunks yields three chunks of length 3 (10 / 4 = 2.5 → 3) and a final chunk of length 1

Returns: a tuple of all the resulting tensors.
The function does not modify the original input.

torch.chunk Official Documentation

Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.

Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.

Parameters:

  • input (Tensor) – the tensor to split
  • chunks (int) – number of chunks to return
  • dim (int) – dimension along which to split the tensor

Example code:

In [1]: X = torch.randn(6, 2)

In [2]: X
Out[2]:
tensor([[-0.3711,  0.7372],
        [ 0.2608, -0.1129],
        [-0.2785,  0.1560],
        [-0.7589, -0.8927],
        [ 0.1480, -0.0371],
        [-0.8387,  0.6233]])

In [3]: torch.chunk(X, 2, dim = 0)
Out[3]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129],
         [-0.2785,  0.1560]]),
 tensor([[-0.7589, -0.8927],
         [ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [4]: torch.chunk(X, 3, dim = 0)
Out[4]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129]]),
 tensor([[-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [5]: torch.chunk(X, 4, dim = 0)
Out[5]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129]]),
 tensor([[-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [6]: Y = torch.randn(10, 2)

In [7]: Y
Out[7]:
tensor([[-0.9749,  1.3103],
        [-0.4138, -0.8369],
        [-0.1138, -1.6984],
        [ 0.7512, -0.3417],
        [-1.4575, -0.4392],
        [-0.2035, -0.2962],
        [-0.7533, -0.8294],
        [ 0.0104, -1.3582],
        [-1.5781,  0.8594],
        [ 0.0286,  0.7611]])

In [8]: torch.chunk(Y, 4, dim = 0)
Out[8]:
(tensor([[-0.9749,  1.3103],
         [-0.4138, -0.8369],
         [-0.1138, -1.6984]]),
 tensor([[ 0.7512, -0.3417],
         [-1.4575, -0.4392],
         [-0.2035, -0.2962]]),
 tensor([[-0.7533, -0.8294],
         [ 0.0104, -1.3582],
         [-1.5781,  0.8594]]),
 tensor([[0.0286, 0.7611]]))
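As both official docs note, every chunk returned by torch.chunk (and torch.split) is a view of the original tensor, so writing into a chunk writes into the source. A quick check (variable names are illustrative):

```python
import torch

X = torch.zeros(4, 2)
first, second = torch.chunk(X, 2, dim=0)  # two (2, 2) views of X

# The chunks share storage with X: modifying a chunk modifies X itself.
first[0, 0] = 99.0
print(X[0, 0])  # tensor(99.)
```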

Concatenating Tensors

The torch.cat Function

torch.cat(tensors, dim=0, out=None)

Concatenates tensors along an existing dimension.
tensors is a sequence of tensors to concatenate, usually a tuple.
dim specifies the dimension along which the concatenation is performed; apart from this dimension, the tensors must have the same size in every other dimension.
out optionally receives the concatenated output; the return value can be used directly instead.
Returns: the concatenated tensor.
The function does not modify the original tensors.

torch.cat Official Documentation

Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().

torch.cat() can be best understood via examples.

Parameters:

  • tensors (sequence of Tensors) – any python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension
  • dim (int, optional) – the dimension over which the tensors are concatenated
  • out (Tensor, optional) – the output tensor

Example code: (quoted from the official documentation)

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), dim = 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), dim = 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497]])
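The documentation's claim that torch.cat() inverts torch.split() and torch.chunk() can be checked directly with a round trip:

```python
import torch

X = torch.randn(6, 2)

# Split into chunks, then concatenate back along the same dimension:
parts = torch.split(X, 2, dim=0)  # three (2, 2) chunks
Y = torch.cat(parts, dim=0)
print(torch.equal(X, Y))  # True
```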

The torch.stack Function

torch.stack(tensors, dim=0, out=None)

Concatenates tensors along a new dimension.
tensors is a sequence of tensors to concatenate, usually a tuple.
dim specifies the index at which the new dimension is inserted among the existing ones, i.e. the position where the tensors are joined along a fresh axis; the tensors must have exactly the same shape in all existing dimensions.
out optionally receives the concatenated output; the return value can be used directly instead.
Returns: the concatenated tensor.
The function does not modify the original tensors.

torch.stack Official Documentation

Concatenates sequence of tensors along a new dimension.

All tensors need to be of the same size.

Parameters:

  • tensors (sequence of Tensors) – sequence of tensors to concatenate
  • dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated tensors (inclusive)
  • out (Tensor, optional) – the output tensor.

Example code:

In [1]: x = torch.randn(2, 3)

In [2]: x
Out[2]:
tensor([[-0.0288,  0.6936, -0.6222],
        [ 0.8786, -1.1464, -0.6486]])

In [3]: torch.stack((x, x, x), dim = 0)
Out[3]:
tensor([[[-0.0288,  0.6936, -0.6222],
         [ 0.8786, -1.1464, -0.6486]],

        [[-0.0288,  0.6936, -0.6222],
         [ 0.8786, -1.1464, -0.6486]],

        [[-0.0288,  0.6936, -0.6222],
         [ 0.8786, -1.1464, -0.6486]]])

In [4]: torch.stack((x, x, x), dim = 0).shape
Out[4]: torch.Size([3, 2, 3])

In [5]: torch.stack((x, x, x), dim = 1)
Out[5]:
tensor([[[-0.0288,  0.6936, -0.6222],
         [-0.0288,  0.6936, -0.6222],
         [-0.0288,  0.6936, -0.6222]],

        [[ 0.8786, -1.1464, -0.6486],
         [ 0.8786, -1.1464, -0.6486],
         [ 0.8786, -1.1464, -0.6486]]])

In [6]: torch.stack((x, x, x), dim = 1).shape
Out[6]: torch.Size([2, 3, 3])
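The shapes above follow because stack inserts a fresh dimension at index dim: it behaves like unsqueezing each tensor there and then concatenating with cat. A sketch of that equivalence:

```python
import torch

x = torch.randn(2, 3)

stacked = torch.stack((x, x, x), dim=1)           # shape (2, 3, 3)
catted = torch.cat((x.unsqueeze(1),) * 3, dim=1)  # same result via cat
print(stacked.shape, torch.equal(stacked, catted))  # torch.Size([2, 3, 3]) True
```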

Hope this helps ~
