from scratch implement crnn using pytorch :讀取訓練數據

知識點:

class Person(object):
    def __init__(self, name, gender):
        self.name = name
        self.gender = gender

    def __call__(self, friend):
        print 'My name is %s...' % self.name
        print 'My friend is %s...' % friend

現在可以對 Person 實例直接調用:

>>> p = Person('Bob', 'male')
>>> p('Tim')
My name is Bob...
My friend is Tim...
單看 p('Tim') 你無法確定 p 是一個函數還是一個類實例,所以,在Python中,函數也是對象,對象和函數的區別並不顯著
  • opencv中圖像的座標

pyopencv 函數 def resize(src, dsize, dst=None, fx=None, fy=None, interpolation=None): 故參數dsize輸入格式應該是(width,height)

在文檔中解釋是:

參數: inplace-選擇是否進行覆蓋運算

意思是是否將得到的值計算得到的值覆蓋之前的值,比如:x = x +1即對原值進行操作,然後將得到的值又直接複製到該值中,而不是覆蓋運算的例子如:= x + 1x = y這樣就需要花費內存去多存儲一個變量y,所以nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True)
意思就是對從上層網絡Conv2d中傳遞下來的tensor直接進行修改,這樣能夠節省運算內存,不用多存儲其他變量

  • 各個深度框架讀入圖像的順序

N: batch;

C: channel

H: height

W: width

Caffe 的Blob通道順序是:NCHW;

Tensorflow的tensor通道順序:默認是NHWC, 也支持NCHW,使用cuDNN會更快;

Pytorch中tensor的通道順序:NCHW

TensorRT中的tensor 通道順序: NCHW

  • pytorch加載數據

  常用到的工具有 torchvision 裏的 transforms

torch.utils.data 裏的 Dataset,DataLoader

dataloader本質是一個可迭代對象,使用iter()訪問,不能使用next()訪問;

使用iter(dataloader)返回的是一個迭代器,然後可以使用next訪問;

也可以使用`for inputs, labels in dataloaders`進行可迭代對象的訪問;

一般我們實現一個datasets對象,傳入到dataloader中;然後內部使用yeild返回每一次batch的數據

Dataloader的處理邏輯是先通過Dataset類裏面的 __getitem__ 函數獲取單個的數據,然後組合成batch。使用上主要是重構dataset,必須繼承自torch.utils.data.Dataset,內部要實現兩個函數一個是__lent__用來獲取整個數據集的大小,一個是__getitem__用來從數據集中得到一個數據片段item。

 

  • SyntheticChineseStringDataset 

該數據集是中文識別數據集,包含360多萬張訓練圖片,5824個字符,場景比較簡單,圖片是白底黑字

圖片:

      

標籤:前一部分爲圖像名稱,後一部分數字爲圖片上字符對應的字符編碼

字符編碼:char_std_5990.txt

 

  • 使用pytorch讀取SyntheticChineseStringDataset 數據集:
import torch
import os
import cv2
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader

def readfile(fileName):
    res = []
    with open(fileName,'r') as f:
        lines = f.readlines()
        for line in lines:
            res.append(line.strip())
    dic = {}
    for line in res:
        part = line.split(' ')
        dic[part[0]] = part[1:]

    return dic

# 調整圖像大小和歸一化操作
class resizeAndNormalize():
    def __init__(self,size,interpolation=cv2.INTER_LINEAR):
        # 注意對於opencv,size的格式是(w,h)
        self.size = size
        self.interpolation = interpolation
        # ToTensor屬於類  """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
        self.toTensor = transforms.ToTensor()

    def __call__(self, image):
        # (x,y) 對於opencv來說,圖像寬對應x軸,高對應y軸
        image = cv2.resize(image,self.size,interpolation=self.interpolation)
        #轉爲tensor的數據結構
        image = self.toTensor(image)
        #對圖像進行歸一化操作
        image = image.sub_(0.5).div_(0.5)
        return image

class CRNNDataSet(Dataset):
    def __init__(self,imageRoot,labelRoot):
        self.image_root = imageRoot
        self.image_dict = readfile(labelRoot)
        self.image_name = [fileName for fileName,_ in self.image_dict.items()]

    def __getitem__(self, index):
        image_path = os.path.join(self.image_root,self.image_name[index])
        keys = self.image_dict.get(self.image_name[index])
        label = [int(x) for x in keys]

        image = cv2.imread(image_path,cv2.IMREAD_GRAYSCALE)
        (height,width) = image.shape

        #由於crnn網絡輸入圖像的高爲32,故需要resize原始圖像的height
        size_height = 32
        ratio = 32/float(height)
        size_width = int(ratio * width)
        transform = resizeAndNormalize((size_width,size_height))
        #圖像預處理
        image = transform(image)
        #標籤格式轉換爲IntTensor
        label = torch.IntTensor(label)

        return image,label

    def __len__(self):
        return len(self.image_name)

trainData = CRNNDataSet(imageRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\images\\",
                          labelRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\lables\data_train.txt")

trainLoader = DataLoader(dataset=trainData,batch_size=1,shuffle=True)

for i,(data,label) in enumerate(trainLoader):
    print(i)
    print(data.shape)
    print(label)

 

 


 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章