使用 PyTorch 中的 Dataset、DataLoader 類去加載數據集:
import torch
from torchvision import transforms, datasets
import os,sys
from torch.utils.data import Dataset,DataLoader
from PIL import Image
import numpy as np
# Preprocessing pipeline handed to my_dataset: random 224x224 crop, random
# horizontal flip, conversion to a tensor, then per-channel normalisation.
# RandomResizedCrop replaces RandomSizedCrop, a deprecated alias that newer
# torchvision releases no longer provide.
# NOTE(review): the mean/std values look like the commonly used ImageNet
# channel statistics - confirm they suit this data.
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
class my_dataset(Dataset):
    """Dataset of (image, mask) pairs read from two directories.

    Every entry of ``img_path`` is treated as an image and every entry of
    ``mask_path`` as a mask. Sample ``i`` is the i-th image paired with the
    i-th mask, both taken in sorted filename order.

    Args:
        img_path: Directory containing the input images.
        mask_path: Directory containing the masks.
        data_transform: Optional callable applied to the image array and to
            the mask array (each as a numpy array). When ``None`` the resized
            PIL images are returned unchanged.
    """

    def __init__(self, img_path, mask_path, data_transform=None):
        self.img_path = img_path
        self.mask_path = mask_path
        self.transforms = data_transform
        # os.listdir returns entries in arbitrary order, so both listings are
        # sorted to keep image i and mask i aligned.
        # NOTE(review): this assumes matching images and masks sort to the
        # same position (e.g. identical file names) - confirm for your data.
        self.img_list = [os.path.join(self.img_path, name)
                         for name in sorted(os.listdir(self.img_path))]
        self.mask_list = [os.path.join(self.mask_path, name)
                          for name in sorted(os.listdir(self.mask_path))]

    def __getitem__(self, item):
        """Return the ``item``-th (image, mask) pair, each resized to 224x224."""
        img = Image.open(self.img_list[item]).convert('RGB')
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the filter it
        # was an alias for.
        img = img.resize((224, 224), Image.LANCZOS)
        mask = Image.open(self.mask_list[item])
        # Nearest-neighbour keeps the original mask values; a smoothing
        # filter would create intermediate values along region borders.
        mask = mask.resize((224, 224), Image.NEAREST)
        if self.transforms is not None:
            # NOTE(review): the same transform is applied independently to
            # image and mask, so random augmentations will not be aligned
            # between the two, and the transform receives numpy arrays rather
            # than PIL images - verify the supplied pipeline supports this.
            img = self.transforms(np.array(img))
            mask = self.transforms(np.array(mask))
        return img, mask

    def __len__(self):
        """Return the number of images found in ``img_path``."""
        return len(self.img_list)
# Example locations of the image and mask directories.
img_path = r'G:\Pytorch\data\test_data'
mask_path = r'G:\Pytorch\data\test_gt'
train_dataset = my_dataset(img_path, mask_path, data_transform)
dataset_loader = DataLoader(train_dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=1)

# Test:
# With num_workers > 0 the loader starts worker processes. On platforms that
# spawn workers (e.g. Windows) each worker re-imports this module, so the
# iteration must sit behind the main guard or the workers would re-run it.
if __name__ == '__main__':
    for batch_idx, (inputs, targets) in enumerate(dataset_loader):
        print(inputs.shape)
        print(targets.shape)
    for i, item in enumerate(train_dataset):
        data, label = item
        print('data:', data)
        print('label:', label)
網上大多數教程都是這樣加載數據的,但我發現有時創建好的 DataLoader 實例並不能正常讀取數據,會遇到各種錯誤,
於是自己寫了一組數據加載相關的函數。
讀取指定目錄文件夾下的所有圖片:
def read_directory(directory_name):
    """Load every file in a directory as a 224x224 image and stack them.

    Args:
        directory_name: Path of a directory whose entries are all opened
            with PIL.

    Returns:
        A numpy array holding the resized images stacked along a new first
        axis, in sorted filename order.

    Raises:
        ValueError: If the directory has no entries, or (from ``np.stack``)
            if the loaded images do not all have the same shape.
    """
    img_list = []
    # Sort so the stacking order is reproducible; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(directory_name)):
        # The context manager closes the source file once the resized copy
        # has been produced.
        with Image.open(os.path.join(directory_name, filename)) as img:
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the filter
            # it was an alias for.
            out = img.resize((224, 224), Image.LANCZOS)
        img_list.append(np.array(out))
    if not img_list:
        raise ValueError('no images found in {!r}'.format(directory_name))
    return np.stack(img_list, axis=0)
將 image_data 和 mask_data 按 batch_size 切分成一個個批次:
def get_batch_data(img, mask, batch):
    """Split images and masks into consecutive, equally sized batches.

    Only full batches are produced: trailing items that do not fill a whole
    batch are dropped. No shuffling is performed.

    Args:
        img: Sliceable sequence of images (e.g. a list or numpy array).
        mask: Sliceable sequence of masks, sliced with the same indices as
            ``img``.
        batch: Number of items per batch; must be positive.

    Returns:
        A tuple ``(image, label)`` of two lists, where ``image[k]`` and
        ``label[k]`` are the k-th slices of ``img`` and ``mask``.

    Raises:
        ValueError: If ``batch`` is not positive (a non-positive step would
            otherwise never reach the end of the data).
    """
    if batch <= 0:
        raise ValueError('batch must be a positive integer, got {!r}'.format(batch))
    # Start offsets of every batch that fits completely inside img.
    starts = range(0, len(img) - batch + 1, batch)
    image = [img[start:start + batch] for start in starts]
    label = [mask[start:start + batch] for start in starts]
    return image, label
根據文件路徑加載已保存的 numpy 格式圖片數據(需要時可再轉換成 PyTorch 格式):
def get_data(batch_size, train_path, mask_path):
    """Load saved image and mask arrays and group them into batches.

    Args:
        batch_size: Number of items per batch, forwarded to get_batch_data.
        train_path: Path passed to ``np.load`` for the image data.
        mask_path: Path passed to ``np.load`` for the mask data.

    Returns:
        A two-element list ``[img, mask]`` in which each element is the
        numpy array built from the list of batches.
    """
    batches = get_batch_data(np.load(train_path), np.load(mask_path), batch_size)
    # Turn each list of batches into a single array.
    img_batches, mask_batches = (np.array(part) for part in batches)
    print('load data success!')
    return [img_batches, mask_batches]
對於get_data()處理後獲取的數據,可以這樣訪問:
# Walk over the batches: trainloader[0] holds the inputs and trainloader[1]
# the targets at the same batch index.
for batch_idx, inputs in enumerate(trainloader[0]):
    targets = trainloader[1][batch_idx]
    # NOTE(review): .to(device) needs torch tensors, while get_data() returns
    # numpy arrays - convert first (e.g. with np_to_tensor) if required.
    inputs = inputs.to(device)
    targets = targets.to(device)
numpy格式圖片、torch格式圖片相互轉換
# np to Tensor(torch)
def np_to_tensor(img):
    """Convert a numpy image (or batch of images) to a normalised float tensor.

    The array is divided by 255, normalised with the module-level ``mean``
    and ``std``, moved from channels-last to channels-first layout, and
    placed on the module-level ``device``.

    Args:
        img: numpy array, e.g. a single image of shape (224, 224, 3). A
            3-dimensional input gets a leading axis of size 1 added.

    Returns:
        A float32 torch tensor on ``device``.
    """
    # img numpy type (224,224,3)
    if len(img.shape) == 3:
        img = np.expand_dims(img, 0)
    img = img / 255.
    img = (img - mean) / std
    if len(img.shape) == 4:
        # Reorder axes (0, 1, 2, 3) -> (0, 3, 1, 2): channels-last to
        # channels-first.
        img = img.transpose(0, 3, 1, 2)
    # Tensors carry autograd state themselves since PyTorch 0.4, so the
    # Variable wrapper is unnecessary. Casting to float32 before the transfer
    # avoids moving the larger float64 array to the device.
    return torch.from_numpy(img).float().to(device)
# Tensor(torch) to np
def tensor_to_np(img_tensor):
    """Convert a 4-D tensor back to a uint8 numpy array with the channel axis last.

    Args:
        img_tensor: torch tensor whose axes are reordered (0, 2, 3, 1).

    Returns:
        A ``np.uint8`` array. When the last axis has size 3 the values are
        first de-normalised with the module-level ``std`` and ``mean``; in
        all cases they are multiplied by 255 and clipped to [0, 255].
    """
    channels_last = img_tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)
    if channels_last.shape[3] == 3:
        # Undo the (x - mean) / std normalisation.
        channels_last = channels_last * std + mean
    scaled = np.clip(channels_last * 255., 0, 255.)
    return scaled.astype(np.uint8)