Pytorch學習(四)---Pytorch數據處理工具箱

一. 數據處理工具箱概述

Pytorch設計數據處理(數據裝載、數據預處理、數據增強等)主要工具包及相互關係如圖所示。
左邊的是torch.utils.data工具包,它包括以下四個類
1.Dataset:是一個抽象類,其他數據集需要繼承這個類,並且覆寫其中的兩個方法(_ getitem_、_ len_)。
2.Dataloader:定義一個新的迭代器,實現批量(batch)讀取,到alciuju並提供並行加速等功能。
3.random_split:把數據集隨即拆分未給定長度的非重疊的新數據集。
4.*Sampler:多種採樣函數
圖中間是Pytorch的可視化處理工具Torchvision,其是Pytorch的一個視覺處理工具包,獨立於Pytorch,需要另外安裝。
它包括4各類,主要是:
1.datasets:提供常用的數據集加載,設計上都哦繼承自torch.utils.data.Dataset,主要包括MNIST、CIFAR10/100、ImageNet和COCO等。
2.Models:提供深度學習中各種經典的網絡結構以及訓練好的模型(如果選擇pretrained=True)。
3.transforms:常用的數據預處理操作,主演包括對Tensor及PIL Image對象的操作
4.utils:含兩個函數,一個是make_grid,它能將多張圖片集拼接在一個網絡中;另一個是save_img,它能將Tensor保存成圖片。

utils.data簡介

utils.data包括Dataset和Dataloader。torch.utils.data.Dataset未抽象類。自定義數據集需要繼承這個類,並實現兩個函數,一個是_getitem_,一個是 len,前者通過給定索引獲取數據和標籤,後者提供數據的大小(size)。_getitem_一次只能獲取一個數據,所以要通過torch.utils.data.DataLoader來定義一個新的迭代器,實現batch讀取。

1.導入需要的模塊

import torch
from torch.utils import data
import numpy as np

2.定義獲取數據記得類。
該類繼承類Dataset,自定義一個數據集及對應的標籤。

class TestDataset(data.Dataset):
	def __init__()self:
		self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])
		self.Lable= ap,asarray([0,1,0,1,2])
	def __getitem__(self,index):
		#把numpy轉換爲Tensor
		txt = torch.from_numpy(self.Data(index))
		label = torch.tensor(self.Label(index))
		retuen txt,label
	def __len__(self):
		return len(self.Data)
		

3獲取數據集中的數據

Test= TestDataset()
print(Test[2]) #相當於調用__getitem__(2)
print(Test.__len__())

**以上數據以tuple返回,每次只返回一個樣本。實際上,Dataset只負責數據的抽取,調用一次__getitem__只返回一個樣本。**如果希望批處理,同時還要進行shuffle和並行加速等操作,可選擇DataLoader。
DataLoader的格式定義爲:

data.DataLoader(
	dataset, #加載的數據集
	batach_size=1, #批大小
	shuffle=False, # 是否將數據打亂
	sampler=Nonebatch_sampler=None, # 樣本抽樣
	num_workers=0,# 使用多進程加載的進程數,0表示不使用多進程
	)
test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=2)
for i,traindata in enumerate(test_loader):
	print("i:",i)
	Data,Label=traindata
	print(Data,Label)

.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章