PyTorch CNN Image Preparation Code Project: Learn to Extract, Transform, Load (ETL)

The project (Bird’s-eye view)

There are four general steps that we’ll be following as we move through this project:
1. Prepare the data
2. Build the model
3. Train the model
4. Analyze the model's results

The ETL process

  • Extract data from a data source
  • Transform data into a desirable format
  • Load data into a suitable structure

PyTorch imports

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

The next imports are standard packages used for data science in Python:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
#from plotcm import plot_confusion_matrix

import pdb

torch.set_printoptions(linewidth=120)

Note that pdb is the Python debugger, and the commented-out import is a local file we'll introduce in future posts for plotting the confusion matrix. The last line sets the print options for PyTorch print statements.
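To see what the print option does, here is a quick illustration (not part of the project code); the tensor and linewidth values are chosen only for demonstration:

# Quick illustration: linewidth controls how many characters a printed tensor
# may use per line before wrapping.
t = torch.arange(20)
torch.set_printoptions(linewidth=40)   # narrow setting: the printout wraps
print(t)
torch.set_printoptions(linewidth=120)  # our setting: the printout fits on one line
print(t)
# pdb.set_trace()  # uncommenting this would pause execution here for inspection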

Preparing our data using PyTorch

Extract – Get the Fashion-MNIST image data from the source.
Transform – Put our data into tensor form.
Load – Put our data into an object to make it easily accessible.


PyTorch Dataset class

train_set = torchvision.datasets.FashionMNIST(
    root='./data'  # download into a data folder under the current directory; it is created if it doesn't exist
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

Note that the root argument used to be './data/FashionMNIST'; however, it has since changed due to torchvision updates.
Since we want our images to be transformed into tensors, we use the built-in transforms.ToTensor() transformation, and since this dataset is going to be used for training, we’ll name the instance train_set.
When we run this code for the first time, the Fashion-MNIST dataset will be downloaded locally. Subsequent calls check for the data before downloading it. Thus, we don’t have to worry about double downloads or repeated network calls.
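If we later want additional preprocessing, more transforms can simply be chained inside the same Compose call. Here is a minimal sketch; the Normalize mean and std values are illustrative placeholders, not values computed in this post:

normalized_train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
        # Normalize takes per-channel (mean,) and (std,) tuples; the numbers
        # below are placeholders for illustration only
        ,transforms.Normalize((0.5,), (0.5,))
    ])
)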

PyTorch DataLoader class

train_loader = torch.utils.data.DataLoader(train_set
    ,batch_size=1000
    ,shuffle=True
)

A few of the DataLoader constructor parameters are worth noting (a short sketch using num_workers follows this list):
  • batch_size (1000 in our case)
  • shuffle (True in our case)
  • num_workers (the default is 0, which means the main process will be used to load the data)
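For instance, a loader that also uses background worker processes might look like this; the num_workers=2 value is just an illustrative choice:

worker_loader = torch.utils.data.DataLoader(train_set
    ,batch_size=1000
    ,shuffle=True
    ,num_workers=2  # 0, the default, would load the data in the main process
)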

Exploring the data

To see how many images are in our training set, we can check the length of the dataset using the Python len() function:

> len(train_set)
60000

This 60000 number makes sense based on what we learned in the post on the Fashion-MNIST dataset.
Suppose we want to see the labels for each image. This can be done like so:

> train_set.targets
tensor([9, 0, 0, ..., 3, 0, 5])

The first image's label is a 9 and the next two are zeros. Remember from past posts that these values encode the actual class name or label: the 9, for example, is an ankle boot, while the 0 is a t-shirt.
If we want to see how many of each label exists in the dataset, we can use the PyTorch bincount() function like so:

> train_set.targets.bincount()
tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000])

Class imbalance: Balanced and unbalanced datasets

This shows us that the Fashion-MNIST dataset is uniform with respect to the number of samples in each class: there are 6000 samples for every class. As a result, this dataset is said to be balanced. If the classes had a varying number of samples, we would call the set an unbalanced dataset.
Class imbalance is a common problem, but in our case, we have just seen that the Fashion-MNIST dataset is indeed balanced, so we need not worry about that for our project.
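If we did run into an unbalanced dataset, one common remedy is to oversample the under-represented classes. Below is a minimal sketch using PyTorch's WeightedRandomSampler; it is not needed for Fashion-MNIST and is shown only for reference:

# Hypothetical re-balancing sketch (not needed for the balanced Fashion-MNIST):
class_counts = train_set.targets.bincount()        # number of samples per class
class_weights = 1.0 / class_counts.float()         # rarer classes get larger weights
sample_weights = class_weights[train_set.targets]  # one weight per training sample

sampler = torch.utils.data.WeightedRandomSampler(
    weights=sample_weights, num_samples=len(sample_weights), replacement=True
)

balanced_loader = torch.utils.data.DataLoader(
    train_set, batch_size=1000, sampler=sampler    # sampler replaces shuffle=True
)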

Accessing data in the training set

To access an individual element from the training set, we first pass the train_set object to Python's iter() built-in function, which returns an object representing a stream of data.
With the stream of data, we can use the Python built-in next() function to get the next element in the stream. Since we expect to get a single sample, we'll name the result accordingly:

> sample = next(iter(train_set))
> len(sample)
2

After passing the sample to the len() function, we can see that the sample contains two items, and this is because the dataset contains image-label pairs.
Each sample we retrieve from the training set contains the image data as a tensor and the corresponding label as a tensor.
Since the sample is a sequence type, we can use sequence unpacking to assign the image and the label. We can now check the types of the image and the label; the image is a torch.Tensor, while the type of the label depends on the torchvision version:

> image, label = sample

> type(image)
torch.Tensor

# Before torchvision 0.2.2
> type(label)
torch.Tensor
# Starting at torchvision 0.2.2
> type(label)
int

We'll check the shapes to see that the image is a 1 x 28 x 28 tensor, while the label (once wrapped in a tensor) has a scalar shape:

> image.shape
torch.Size([1, 28, 28]) 

> torch.tensor(label).shape
torch.Size([])

> image.squeeze().shape
torch.Size([28, 28])

Note the difference between a one-element tensor shape, torch.Size([1]), and a scalar shape, torch.Size([]).
A tensor of size torch.Size([0]) is 1-dimensional but has no elements.
Contrast this with a tensor of size torch.Size([1]), which is 1-dimensional and has one element.
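To make the distinction concrete:

> torch.tensor(9).shape    # a scalar: zero dimensions
torch.Size([])

> torch.tensor([9]).shape  # 1-dimensional with one element
torch.Size([1])

> torch.tensor([]).shape   # 1-dimensional with no elements
torch.Size([0])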

Let's plot the image now, and we'll see why we squeezed the tensor in the first place: we first squeeze the tensor and then pass it to the imshow() function. Without the squeeze, imshow() would receive a 1 x 28 x 28 array, which it cannot interpret as an image, and it would raise an error.

> plt.imshow(image.squeeze(), cmap="gray")
> torch.tensor(label)
tensor(9)

PyTorch DataLoader: Working with batches of data

We’ll start by creating a new data loader with a smaller batch size of 10 so it’s easy to demonstrate what’s going on:

> display_loader = torch.utils.data.DataLoader(
    train_set, batch_size=10
)

There is one thing to notice when working with the data loader. If shuffle=True, then the batch will be different each time a call to next occurs. With shuffle=False, the first samples in the training set will be returned on the first call to next. The shuffle functionality is turned off by default.
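We can verify this with the display_loader defined above, which uses the default shuffle=False. Creating two fresh iterators returns the same first batch both times:

> first_images, first_labels = next(iter(display_loader))
> second_images, second_labels = next(iter(display_loader))
> torch.equal(first_labels, second_labels)
True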

# note that each batch will be different when shuffle=True
> batch = next(iter(display_loader))
> print('len:', len(batch))
len: 2

Let’s unpack the batch and take a look at the two tensors and their shapes:

> images, labels = batch

> print('types:', type(images), type(labels))
> print('shapes:', images.shape, labels.shape)
types: <class 'torch.Tensor'> <class 'torch.Tensor'>
shapes: torch.Size([10, 1, 28, 28]) torch.Size([10])

Since batch_size=10, we know we are dealing with a batch of 10 images and 10 corresponding labels.
Each dimension of the tensor that contains the image data corresponds to one of the following values: (batch size, number of color channels, image height, image width).

> images[0].shape
torch.Size([1, 28, 28])

> labels[0]
tensor(9)

To plot a batch of images, we can use the torchvision.utils.make_grid() function to create a grid that can be plotted like so:

> grid = torchvision.utils.make_grid(images, nrow=10)

> plt.figure(figsize=(15,15))
> plt.imshow(np.transpose(grid, (1,2,0)))
# imshow expects images in (height, width, channel) order, so we move the channel dimension to the end
> print('labels:', labels)
labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

Another way to do this:

> grid = torchvision.utils.make_grid(images, nrow=10)

> plt.figure(figsize=(15,15))
> plt.imshow(grid.permute(1,2,0))

> print('labels:', labels)
labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])


How to Plot Images Using PyTorch DataLoader

Here is another way to plot the images using the PyTorch DataLoader.

how_many_to_plot = 20

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=1, shuffle=True
)

mapping = {
    0:'Top', 1:'Trousers', 2:'Pullover', 3:'Dress', 4:'Coat'
    ,5:'Sandal', 6:'Shirt', 7:'Sneaker', 8:'Bag', 9:'Ankle Boot'
}

plt.figure(figsize=(50,50))
for i, batch in enumerate(train_loader, start=1):
    image, label = batch
    plt.subplot(10,10,i)
    fig = plt.imshow(image.reshape(28,28), cmap='gray')
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(mapping[label.item()], fontsize=28)
    if (i >= how_many_to_plot): break
plt.show()
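As a side note, recent torchvision versions also expose the class names directly on the dataset via train_set.classes, which could replace the hand-written mapping dictionary above; note that the exact strings differ slightly from the labels used here:

> train_set.classes
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']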