數據類
數據集主要是
torch.utils.data類
要實現加載和預處理數據可分爲以下兩個步驟:
1.加載數據集(Dateset)
1.1 自帶數據集(Mnist/FashionMnist等)
加載時需要完成數據格式的轉換(transform).
一種加載方法是用自帶的數據集,來自torchvision大類:
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.FashionMNIST('./data',
download=True,
train=True,
transform=transform)
testset = torchvision.datasets.FashionMNIST('./data',
download=True,
train=False,
transform=transform)
1.2 自備圖片
若要實現自有文件圖片,需要實現一個繼承torch.utils.data.Dataset的類.這裏dataset有兩種實現方式:
- map-style(類似數組)
需要實現兩個數組函數__getitem__()和__len__()。 - Iterable-style(類似指針)
這裏需要實現迭代函數__iter()__。
下例實現map-style()函數,在該函數中可以通過索引把圖像數據轉換,返回爲tensor數據.
import torch.utils.data as data
class DatasetFromFolder(data.Dataset):
def __init__(self):
super().__init__()
self.path = 'data/pose'#指定自己的路徑
self.image_filenames = [x for x in listdir(self.path)]
def __getitem__(self, index):
a = Image.open(join(self.path, self.image_filenames[index])).convert('L')
a = a.resize((64, 64), Image.BICUBIC)
a = transforms.ToTensor()(a)
return a
def __len__(self):
return len(self.image_filenames)
2.預處理數據
就是加載數據,這裏需要定義一個DataLoader類並設置必要參數,如一批數據batch的數量,是否隨機,
pose = DatasetFromFolder()
train_loader = torch.utils.data.DataLoader(
dataset=pose,
batch_size=25,
shuffle=False,
num_workers=0,
pin_memory=True,#用Nvidia GPU時生效
drop_last=True
)
3.測試
通過迭代train_loader類,來每次輸出一個batch,如:
for i, x in enumerate(train_loader):
print(i)
print(x.shape)
#torchvision.utils.save_image(x, './pose-img/%d.jpg'%(i), nrow=5)
4.其他
如果在使用datasets.ImageFolder(path)時,出現 'Found 0 files in subfolders of: xxx’這個錯誤,還是乖乖用繼承上述類實現加載自身數據吧。