一. 數據處理工具箱概述
Pytorch設計數據處理(數據裝載、數據預處理、數據增強等)主要工具包及相互關係如圖所示。
左邊的是torch.utils.data工具包,它包括以下四個類
1.Dataset:是一個抽象類,其他數據集需要繼承這個類,並且覆寫其中的兩個方法(_ getitem_、_ len_)。
2.Dataloader:定義一個新的迭代器,實現批量(batch)讀取,到alciuju並提供並行加速等功能。
3.random_split:把數據集隨即拆分未給定長度的非重疊的新數據集。
4.*Sampler:多種採樣函數
圖中間是Pytorch的可視化處理工具Torchvision,其是Pytorch的一個視覺處理工具包,獨立於Pytorch,需要另外安裝。
它包括4各類,主要是:
1.datasets:提供常用的數據集加載,設計上都哦繼承自torch.utils.data.Dataset,主要包括MNIST、CIFAR10/100、ImageNet和COCO等。
2.Models:提供深度學習中各種經典的網絡結構以及訓練好的模型(如果選擇pretrained=True)。
3.transforms:常用的數據預處理操作,主演包括對Tensor及PIL Image對象的操作
4.utils:含兩個函數,一個是make_grid,它能將多張圖片集拼接在一個網絡中;另一個是save_img,它能將Tensor保存成圖片。
utils.data簡介
utils.data包括Dataset和Dataloader。torch.utils.data.Dataset未抽象類。自定義數據集需要繼承這個類,並實現兩個函數,一個是_getitem_,一個是 len,前者通過給定索引獲取數據和標籤,後者提供數據的大小(size)。_getitem_一次只能獲取一個數據,所以要通過torch.utils.data.DataLoader來定義一個新的迭代器,實現batch讀取。
1.導入需要的模塊
import torch
from torch.utils import data
import numpy as np
2.定義獲取數據記得類。
該類繼承類Dataset,自定義一個數據集及對應的標籤。
class TestDataset(data.Dataset):
def __init__()self:
self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])
self.Lable= ap,asarray([0,1,0,1,2])
def __getitem__(self,index):
#把numpy轉換爲Tensor
txt = torch.from_numpy(self.Data(index))
label = torch.tensor(self.Label(index))
retuen txt,label
def __len__(self):
return len(self.Data)
3獲取數據集中的數據
Test= TestDataset()
print(Test[2]) #相當於調用__getitem__(2)
print(Test.__len__())
**以上數據以tuple返回,每次只返回一個樣本。實際上,Dataset只負責數據的抽取,調用一次__getitem__只返回一個樣本。**如果希望批處理,同時還要進行shuffle和並行加速等操作,可選擇DataLoader。
DataLoader的格式定義爲:
data.DataLoader(
dataset, #加載的數據集
batach_size=1, #批大小
shuffle=False, # 是否將數據打亂
sampler=Nonebatch_sampler=None, # 樣本抽樣
num_workers=0,# 使用多進程加載的進程數,0表示不使用多進程
)
test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=2)
for i,traindata in enumerate(test_loader):
print("i:",i)
Data,Label=traindata
print(Data,Label)
.