pytorch中DataLoader生成器的使用記錄

在深度學習訓練數據集中,採用批量訓練時候基本都要使用生成器一批次一批次地把數據送入網絡,節省內存。在keras中有ImageDataGenerator,使用很方便。所以pytorch也有對應的生成器,這裏記錄一下學習筆記。個人感覺pytorch的生成器並沒有keras的使用方便。

keras中有ImageDataGenerator使用:https://blog.csdn.net/qq_35054151/article/details/101178662

pytorch中數據提取模塊主要有Dataset和DataLoader兩個部分:

1. DataLoader的函數定義如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

dataset:加載的數據集(Dataset對象) 
      batch_size:batch size 
      shuffle::是否將數據打亂 
      sampler: 樣本抽樣,後續會詳細介紹 
      num_workers:使用多進程加載的進程數,0代表不使用多進程 
      collate_fn: 如何將多個樣本數據拼接成一個batch,一般使用默認的拼接方式即可 
      pin_memory:是否將數據保存在pin memory區,pin memory中的數據轉到GPU會快一些 
      drop_last:dataset中的數據個數可能不是batch_size的整數倍,drop_last爲True會將多出來不足一個batch的數據丟棄

2. dataset

PyTorch讀取圖片,主要是通過Dataset類

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
	raise NotImplementedError
def __len__(self):
	raise NotImplementedError
def __add__(self, other):
	return ConcatDataset([self, other])

這裏重點是getitem字典函數,getitem接收一個index,然後返回圖片數據和標籤,這個index通常指的是一個list的index,這個list的每個元素就包含了圖片數據的路徑和標籤信息。一般的方法是將圖片的路徑和標籤信息存儲在一個txt中,然後從該txt中讀取。(這設計也不嫌麻煩!!!!)

搬運別人的代碼:

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
	fh = open(txt_path, 'r')
	imgs = []
	for line in fh:
		line = line.rstrip()
		words = line.split()
		imgs.append((words[0], int(words[1])))
		self.imgs = imgs 
		self.transform = transform
		self.target_transform = target_transform
def __getitem__(self, index):
	fn, label = self.imgs[index]
	img = Image.open(fn).convert('RGB') 
	if self.transform is not None:
		img = self.transform(img) 
	return img, label
def __len__(self):
	return len(self.imgs)

第一行:self.imgs 是一個list,也就是一開始提到的list,self.imgs的一個元素是一個str,包含圖片路徑,圖片標籤,這些信息是從txt文件中讀取

第二行:利用Image.open對圖片進行讀取,img類型爲 Image ,mode=‘RGB’

第三行與第四行: 對圖片進行處理,這個transform裏邊可以實現 減均值,除標準差,隨機裁剪,旋轉,翻轉,放射變換,等等操作,這個放在後面會詳細講解。

當Mydataset構建好,剩下的操作就交給DataLoder,在DataLoder中,會觸發Mydataset中的getiterm函數讀取一張圖片的數據和標籤,並拼接成一個batch返回,作爲模型真正的輸入。下一小節將會通過一個小例子,介紹DataLoder是如何獲取一個batch,以及一張圖片是如何被PyTorch讀取,最終變爲模型的輸入的。

參考鏈接:https://blog.csdn.net/weixin_40766438/article/details/100750633

               https://blog.csdn.net/u011995719/article/details/85102770

             https://blog.csdn.net/wwwww_bw/article/details/102911957

https://www.cnblogs.com/leokale-zz/p/11275800.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章