【天池—街景字符編碼識別】Task 2 數據讀取與數據擴增

1 簡介

  本章主要內容爲數據讀取、數據擴增方法和Pytorch讀取賽題數據三個部分組成。

2 學習目標

  • 學習Python和Pytorch中圖像讀取
  • 學會擴增方法和Pytorch讀取賽題數據

3 圖像讀取

  在Python中有很多庫可以完成數據讀取的操作,比較常見的有PillowOpenCV

3.1 Pillow

  Pillow是Python圖像處理函式庫(Python Imaging Library,PIL)的一個分支。Pillow提供了常見的圖像讀取和處理的操作,而且可以與ipython notebook無縫集成,是應用比較廣泛的庫。其官方文檔是:https://pillow.readthedocs.io/en/stable/

3.1.1 安裝

  如果安裝了Anaconda,Pillow就已經可用了。否則,需要在命令行下通過pip安裝:

$ pip install pillow

  注意:如果遇到Permission denied安裝失敗,請加上sudo重試。

  此外,Anaconda的安裝見:Anaconda安裝、環境的配置以及Jupyter和Spyder的打開

3.1.2 基本操作

讀取圖片:

from PIL import Image

# 打開一個jpg圖像文件,注意是當前路徑:
im = Image.open('test.jpg')

更多操作可以見官方文檔:https://pillow.readthedocs.io/en/stable/

3.2 OpenCV

OpenCV的安裝見64位系統下 python3.7安裝OpenCV、OpenGL64位系統的同學要特別其安裝方式。
  OpenCV是一個跨平臺的計算機視覺庫,最早由Intel開源得來。OpenCV發展的非常早,擁有衆多的計算機視覺、數字圖像處理和機器視覺等功能。OpenCV在功能上比Pillow更加強大很多,學習成本也高很多。
  OpenCV包含了衆多的圖像處理的功能,OpenCV包含了你能想得到的只要與圖像相關的操作。此外OpenCV還內置了很多的圖像特徵處理算法,如關鍵點檢測、邊緣檢測和直線檢測等。
  OpenCV官網:https://opencv.org/
  OpenCV Github:https://github.com/opencv/opencv
  OpenCV 擴展算法庫:https://github.com/opencv/opencv_contrib

4 數據擴增方法

4.1 數據擴增介紹

  數據擴增(Data Augmentation),在深度學習中非常重要,數據擴增可以增加訓練集的樣本,同時也可以有效緩解模型過擬合的情況,也可以提高泛化能力

  • 數據擴增爲什麼有用?
      增加訓練集樣本的數量。在深度學習模型的訓練過程中,數據擴增是必不可少的環節。現有深度學習的參數非常多,一般的模型可訓練的參數量基本上都是萬到百萬級別,而訓練集樣本的數量很難有這麼多。
      其次數據擴增可以擴展樣本空間,假設現在的分類模型需要對汽車進行分類,左邊的是汽車A,右邊爲汽車B。如果不使用任何數據擴增方法,深度學習模型會從汽車車頭的角度來進行判別,而不是汽車具體的區別。
    在這裏插入圖片描述
  • 有哪些數據擴增方法?
      數據擴增方法有很多:從顏色空間尺度空間到樣本空間,同時根據不同任務數據擴增都有相應的區別。
      對於圖像分類,數據擴增一般不會改變標籤;
      對於物體檢測,數據擴增會改變物體座標位置;
      對於圖像分割,數據擴增會改變像素標籤。

4.2 常見的數據擴增方法

  在常見的數據擴增方法中,一般會從圖像顏色尺寸形態空間和像素等角度進行變換。當然不同的數據擴增方法可以自由進行組合,得到更加豐富的數據擴增方法。
  以torchvision爲例,常見的數據擴增方法包括:

  • transforms.CenterCrop 對圖片中心進行裁剪
  • transforms.ColorJitter 對圖像顏色的對比度、飽和度和零度進行變
  • transforms.FiveCrop 對圖像四個角和中心進行裁剪得到五分圖
  • transforms.Grayscale 對圖像進行灰度變換
  • transforms.Pad 使用固定值進行像素填充
  • transforms.RandomAffine 隨機仿射變換
  • transforms.RandomCrop 隨機區域裁剪
  • transforms.RandomHorizontalFlip 隨機水平翻轉
  • transforms.RandomRotation 隨機旋轉
  • transforms.RandomVerticalFlip 隨機垂直翻轉

4.3 常用的數據擴增庫

4.3.1 torchvision

  https://github.com/pytorch/vision
  pytorch官方提供的數據擴增庫,提供了基本的數據數據擴增方法,可以無縫與torch進行集成;但數據擴增方法種類較少,且速度中等

4.3.2 imgaug

  https://github.com/aleju/imgaug
  imgaug是常用的第三方數據擴增庫,提供了多樣的數據擴增方法,且組合起來非常方便,速度較快

4.3.3 albumentations

  https://albumentations.readthedocs.io
  是常用的第三方數據擴增庫,提供了多樣的數據擴增方法,對圖像分類、語義分割、物體檢測和關鍵點檢測都支持,速度較快

5 PyTorch讀取數據(Dataset、DataLoder)

  在Pytorch中數據是通過Dataset進行封裝,並通過DataLoder進行並行讀取。所以我們只需要重載一下數據讀取的邏輯就可以完成數據的讀取。

import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中類別10爲數字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

data = SVHNDataset(train_path, train_label,
          transforms.Compose([
              # 縮放到固定尺寸
              transforms.Resize((64, 128)),

              # 隨機顏色變換
              transforms.ColorJitter(0.2, 0.2, 0.2),

              # 加入隨機旋轉
              transforms.RandomRotation(5),

              # 將圖片轉換爲pytorch 的tesntor
              # transforms.ToTensor(),

              # 對圖像像素進行歸一化
              # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ]))

  接下來我們將在定義好的Dataset基礎上構建DataLoder,你可以會問有了Dataset爲什麼還要有DataLoder?其實這兩個是兩個不同的概念,是爲了實現不同的功能

  • Dataset:對數據集的封裝,提供索引方式的對數據樣本進行讀取
  • DataLoder:對Dataset進行封裝,提供批量讀取的迭代讀取
    加入DataLoder後,數據讀取代碼改爲如下:
import os, sys, glob, shutil, json
import cv2

from PIL import Image
import numpy as np

import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 原始SVHN中類別10爲數字0
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

# 對SVHNDataset進行封裝
train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批樣本個數
    shuffle=False, # 是否打亂順序
    num_workers=10, # 讀取的線程個數
)

for data in train_loader:
    break

在加入DataLoder後,數據按照批次獲取,每批次調用Dataset讀取單個樣本進行拼接。此時data的格式爲:
torch.Size([10, 3, 64, 128]), torch.Size([10, 6])
前者爲圖像文件,爲batchsize * chanel * height * width次序;後者爲字符標籤。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章