PyTorch: Flatten, Reshape, and Squeeze Explained

Tensor operation types

1. Reshaping operations
2. Element-wise operations
3. Reduction operations
4. Access operations

Reshaping operations for tensors

Tensors are the primary ingredient that neural network programmers use to produce their product: intelligence. Our networks operate on tensors, after all, and this is why understanding a tensor’s shape and the available reshaping operations is so important.
Here is an example:

> t = torch.tensor([
    [1,1,1,1],
    [2,2,2,2],
    [3,3,3,3]
], dtype=torch.float32)

To determine the shape of this tensor, we look first at the rows 3 and then the columns 4, and so this tensor is a 3 x 4 rank 2 tensor. Remember, rank is a word that is commonly used and just means the number of dimensions present within the tensor.

  • We have two ways to get the shape:
    In PyTorch the size and shape of a tensor mean the same thing.
> t.size()
torch.Size([3, 4])

> t.shape
torch.Size([3, 4])
  • The rank of a tensor is equal to the length of the tensor’s shape.
> len(t.shape)
2
  • The number of elements inside a tensor (12 in our case) is equal to the product of the shape’s component values.
> torch.tensor(t.shape).prod()
tensor(12)

In PyTorch, there is a dedicated function for this:

> t.numel()
12

The number of elements contained within a tensor is important for reshaping, because any reshaping must account for exactly that many elements. Our tensor has 12 elements, so every new shape’s components must multiply to 12. Reshaping changes the tensor’s shape but not the underlying data.
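To see this constraint in action, here is a minimal sketch: a shape whose components multiply to 12 works, while one that does not raises an error.

```python
import torch

t = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3]
], dtype=torch.float32)

ok = t.reshape(2, 6)  # valid: 2 * 6 == 12

# Invalid: 5 * 3 != 12, so PyTorch raises a RuntimeError
try:
    t.reshape(5, 3)
    raised = False
except RuntimeError:
    raised = True
```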

Reshaping a tensor in PyTorch

> t.reshape([1,12])
tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])

> t.reshape([2,6])
tensor([[1., 1., 1., 1., 2., 2.],
        [2., 2., 3., 3., 3., 3.]])

> t.reshape([3,4])
tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.]])

> t.reshape([4,3])
tensor([[1., 1., 1.],
        [1., 2., 2.],
        [2., 2., 3.],
        [3., 3., 3.]])

> t.reshape(6,2)
tensor([[1., 1.],
        [1., 1.],
        [2., 2.],
        [2., 2.],
        [3., 3.],
        [3., 3.]])

> t.reshape(12,1)
tensor([[1.],
        [1.],
        [1.],
        [1.],
        [2.],
        [2.],
        [2.],
        [2.],
        [3.],
        [3.],
        [3.],
        [3.]])

Using the reshape() function, we can specify the row x column shape that we are seeking. Notice how all of the shapes have to account for the number of elements in the tensor. In our example this is: rows * columns = 12 elements

We can use the intuitive words rows and columns when we are dealing with a rank 2 tensor. The underlying logic is the same for higher dimensional tensors, even though we may not be able to use the intuition of rows and columns in higher dimensional spaces. For example:

> t.reshape(2,2,3)
tensor(
[
    [
        [1., 1., 1.],
        [1., 2., 2.]
    ],

    [
        [2., 2., 3.],
        [3., 3., 3.]
    ]
])

In this example, we increase the rank to 3, and so we lose the rows and columns concept. However, the product of the shape’s components (2,2,3) still has to be equal to the number of elements in the original tensor (12).
Note that PyTorch has another function, view(), that does the same thing as the reshape() function, but don’t let these names throw you off. (One difference worth knowing: view() requires the tensor’s data to be contiguous in memory, while reshape() will copy the data if needed.) No matter which deep learning framework we are using, these concepts will be the same.
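A small sketch of the practical difference: both functions give the same result on a contiguous tensor, but view() fails on a non-contiguous one (for example, after a transpose), while reshape() copies the data when it has to.

```python
import torch

t = torch.arange(12, dtype=torch.float32)

a = t.reshape(3, 4)
b = t.view(3, 4)   # same result as reshape() here

# A transpose makes the data non-contiguous, so view() fails
# where reshape() still succeeds (by copying the data).
try:
    a.t().view(12)
    view_failed = False
except RuntimeError:
    view_failed = True

c = a.t().reshape(12)  # works: reshape copies when it must
```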

Changing shape by squeezing and unsqueezing

The next way we can change the shape of our tensors is by squeezing and unsqueezing them.

  • Squeezing a tensor removes the dimensions or axes that have a length of one.
  • Unsqueezing a tensor adds a dimension with a length of one.
    These functions allow us to expand or shrink the rank (number of dimensions) of our tensor.
> print(t.reshape([1,12]))
> print(t.reshape([1,12]).shape)
tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
torch.Size([1, 12])

> print(t.reshape([1,12]).squeeze())
> print(t.reshape([1,12]).squeeze().shape)
tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
torch.Size([12])

> print(t.reshape([1,12]).squeeze().unsqueeze(dim=0))
> print(t.reshape([1,12]).squeeze().unsqueeze(dim=0).shape)
tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
torch.Size([1, 12])
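Note that squeeze() removes every axis of length one unless we pass a dim argument, in which case only that axis is removed. A quick sketch:

```python
import torch

t = torch.ones(1, 3, 1, 2)

s_all = t.squeeze()          # all length-1 axes removed -> [3, 2]
s_one = t.squeeze(dim=0)     # only axis 0 removed -> [3, 1, 2]
u = t.unsqueeze(dim=0)       # new length-1 axis at the front -> [1, 1, 3, 1, 2]
```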

Flatten a tensor

A flatten operation on a tensor reshapes the tensor to have a single axis whose length is equal to the number of elements contained in the tensor. This is the same thing as a 1d-array of elements.
Flattening a tensor means removing all of the dimensions except for one.
Let’s create a Python function called flatten():
The flatten() function takes in a tensor t as an argument.

def flatten(t):
    t = t.reshape(1, -1)
    t = t.squeeze()
    return t

Since the argument t can be any tensor, we pass -1 as the second argument to the reshape() function.
In PyTorch, the -1 tells the reshape() function to figure out what the value should be based on the number of elements contained within the tensor: given the first value and the total element count, it solves for the second value automatically.
Remember, the number of elements must equal the product of the shape’s component values. This is how PyTorch can figure out what the value should be, given a 1 as the first argument.

Since our tensor t has 12 elements, the reshape() function is able to figure out that a 12 is required for the length of the second axis.
After squeezing, the first axis (axis-0) is removed, and we obtain our desired result, a 1d-array of length 12.
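The -1 inference works for any first argument, not just 1. A minimal sketch:

```python
import torch

t = torch.ones(4, 3)  # 12 elements

a = t.reshape(1, -1)  # PyTorch infers 12 -> shape [1, 12]
b = t.reshape(2, -1)  # PyTorch infers 6  -> shape [2, 6]
c = t.reshape(-1)     # a single inferred axis -> shape [12]
```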

> t = torch.ones(4, 3)
> t
tensor([[1., 1., 1.],
    [1., 1., 1.],
    [1., 1., 1.],
    [1., 1., 1.]])

> flatten(t)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
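For reference, PyTorch also ships a built-in flatten; our helper produces the same result as t.flatten():

```python
import torch

def flatten(t):
    t = t.reshape(1, -1)
    t = t.squeeze()
    return t

t = torch.ones(4, 3)
custom = flatten(t)
builtin = t.flatten()  # PyTorch's built-in equivalent
```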

We’ll see that flatten operations are required when passing an output tensor from a convolutional layer to a linear layer.

It is possible to flatten only specific parts of a tensor. For example, suppose we have a tensor of shape [2,1,28,28] for a CNN. This means that we have a batch of 2 grayscale images, each with a height and width of 28 x 28.
Here, we can flatten just the two images to get the shape [2,1,784]. We could also squeeze off the channel axis to get the shape [2,784].
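Such partial flattening can be done with the start_dim argument of flatten(). A sketch using a dummy batch of ones in place of real images:

```python
import torch

# dummy batch: 2 grayscale 28 x 28 images, shape [2, 1, 28, 28]
batch = torch.ones(2, 1, 28, 28)

flat_images = batch.flatten(start_dim=2)  # flatten each image -> [2, 1, 784]
flat_batch = batch.flatten(start_dim=1)   # flatten channel too -> [2, 784]
no_channel = flat_images.squeeze(dim=1)   # squeeze off the channel axis -> [2, 784]
```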

Concatenating tensors

We combine tensors using the cat() function, and the resulting tensor will have a shape that depends on the shape of the two input tensors.

> t1 = torch.tensor([
    [1,2],
    [3,4]
])
> t2 = torch.tensor([
    [5,6],
    [7,8]
])

# We can combine t1 and t2 row-wise (axis-0) in the following way:
> torch.cat((t1, t2), dim=0)
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])

> torch.cat((t1, t2), dim=0).shape
torch.Size([4, 2])

# We can combine them column-wise (axis-1) like this:
> torch.cat((t1, t2), dim=1)
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])

> torch.cat((t1, t2), dim=1).shape
torch.Size([2, 4])

When we concatenate tensors, we increase the number of elements contained within the resulting tensor. This causes the component values within the shape (lengths of the axes) to adjust to account for the additional elements.
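We can verify this element-count behavior directly; a quick sketch with the tensors from above:

```python
import torch

t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6], [7, 8]])

combined = torch.cat((t1, t2), dim=0)

# the result holds the elements of both inputs: 4 + 4 = 8
total = t1.numel() + t2.numel()
```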
