Pytorch: Dataloader的一些使用心得
這篇博文不講原理,只講一些使用方法和技巧。所有提供的信息僅供參考,不要當作金科玉律。
基本程序框架
首先給出講述的時候使用的基本程序框架。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
class My_Dataset(Dataset):
def __init__(self, list1, array2):
self.len = len(list1)
self.x_data = list1 # something support indexing, like a list, length = 16
self.y_data = array2 # something support indexing, like torch.Tensor, shape = (16, 4, 5)
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
# padding unequal length sequences
def collate_fn(batch_data):
return batch_data
# train dataloader & test dataloader
list1 = [chr(ord('a') + i) for i in range(16)] # 'a'~'p'
array2 = torch.randn((16, 4, 5))
my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
batch_size = 4,
collate_fn = collate_fn)
從dataloader獲取數據
注意這個函數:
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
這代表,如果你用下標索引i從dataloader中取出值,返回值將會是一個長度爲2的元組,下標爲0的是list1[i]
(即第i+1個字母),下標爲1的是array2[i]
(即一個size = (4, 5)
的tensor)。暫且稱這種形式的數據爲data[i]
。
此時如果你運行如下指令:
for batch_data in enumerate(my_dataloader):
# show batch_data
batch_data是一個長度爲2的元組,下標爲0的是這個batch的序號(在以上的程序裏面是0~3),下標爲1的是一個長度爲4(batch_size)的support indexing的對象,這個對象的每個元素就是對應batch中應該包含的幾個data[i]
,比如第0個batch的這個列表中的元素就分別是data[0],..data[3]
。至於data[i]
則是剛纔說的由兩項數據所構成的元組。
在這裏,下標爲1的對象是一個列表。而如果數據本身就是一個tensor的話,這裏會給一個第一維維度爲batch_size,其他維維度數對應的tensor.
此時如果你運行如下指令:
for batch_index, batch_data in enumerate(train_loader):
# show data
這裏的batch_index對應元組的下標爲0的元素,即這個batch的序號(在以上的程序裏面是0~3);batch_data對應上面的列表(support indexing的對象)。顯然這種更細緻的處理是更常用的。
對於以上講的兩點,讀者可以直接跑一下附錄1所示的程序來獲得直觀感受。
collate_fn的使用
在從dataloader中讀取數據時,可以通過collate_fn
做處理,使讀取的數據符合要求。
讓我們審視這個函數:
def collate_fn(batch_data):
return batch_data
這裏輸入的batch_data就是上一節那個以batch_size爲長度,以對應位置的data[i]
爲元素的列表。如果要取得元素之後進行特定處理,可以在這個函數裏面操作;這個函數的返回值會代替原來那個列表的位置。可以運行附錄2的代碼獲得直觀感受。
collate_fn的使用實例
在自然語言處理中,可能要把不等長的tensor padding 成等長,這個步驟可以在collate_fn裏面做。舉個例子,下面的這個函數從不等長Tensor的列表生成一個padding成等長的高維tensor.
def collate_fn(data):
# self.data: list of tensors of different length
# data:[x[0], x[1], ..], x[0].shape = (20, 128), x[1].shape = (30, 128)
# x[2].shape = (28, 128), x[3].shape = (25, 128)
data.sort(key=lambda data: len(data[0]), reverse=True) # 按照序列長度降序排列
seq_len_list = [elem.shape[0] for elem in data]
data = pad_sequence(data, batch_first=True, padding_value=0)
seq_len_list = torch.Tensor(seq_len_list)
return data_batch, seq_len_list
# data_batch.shape = [4, 30, 128], seq_len_list = [20, 30, 28, 25]
函數的返回值包括合併的高維tensor和每個小tensor的實際長度,方便後續處理使用。
附錄
附錄1
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
torch.manual_seed(314)
class My_Dataset(Dataset):
def __init__(self, list1, array2):
self.len = len(list1)
self.x_data = list1 # something support indexing, like a list, length = 16
self.y_data = array2 # something support indexing, like torch.Tensor, shape = (16, 4, 5)
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
# padding unequal length sequences
def collate_fn(batch_data):
return batch_data
# train dataloader & test dataloader
list1 = [chr(ord('a') + i) for i in range(16)] # 'a'~'p'
array2 = torch.randn((16, 4, 5))
my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
batch_size = 4,
collate_fn = collate_fn)
for batch_data in enumerate(my_dataloader):
# show batch_data
print("New Batch")
print(type(batch_data), len(batch_data), batch_data[0], type(batch_data[1]))
print(len(batch_data[1]), type(batch_data[1][0]))
print(batch_data[1][0][0], type(batch_data[1][0][1]), batch_data[1][0][1].shape)
for batch_index, batch_data in enumerate(my_dataloader):
# show batch_data
print("Batch", batch_index)
for i in range(len(batch_data)):
print(type(batch_data[i]), len(batch_data[i]))
print(batch_data[i][0], type(batch_data[i][1]), batch_data[i][1].shape)
附錄2
...
my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
batch_size = 4,
collate_fn = collate_fn)
for batch_index, batch_data in enumerate(my_dataloader):
# show batch_data
print("Batch", batch_index)
print(batch_data)