知識點:
-
python特殊函數 __call__() 實現類變成一個可調用對象
class Person(object):
def __init__(self, name, gender):
self.name = name
self.gender = gender
def __call__(self, friend):
print 'My name is %s...' % self.name
print 'My friend is %s...' % friend
現在可以對 Person 實例直接調用:
>>> p = Person('Bob', 'male')
>>> p('Tim')
My name is Bob...
My friend is Tim...
單看 p('Tim') 你無法確定 p 是一個函數還是一個類實例,所以,在Python中,函數也是對象,對象和函數的區別並不顯著
-
opencv中圖像的座標
pyopencv 函數 def resize(src, dsize, dst=None, fx=None, fy=None, interpolation=None): 故參數dsize輸入格式應該是(width,height)
在文檔中解釋是:
參數: inplace-選擇是否進行覆蓋運算
意思是是否將得到的值計算得到的值覆蓋之前的值,比如:x = x +1
即對原值進行操作,然後將得到的值又直接複製到該值中,而不是覆蓋運算的例子如:
y = x + 1
x = y
這樣就需要花費內存去多存儲一個變量y,所以nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True)
意思就是對從上層網絡Conv2d中傳遞下來的tensor直接進行修改,這樣能夠節省運算內存,不用多存儲其他變量
-
各個深度框架讀入圖像的順序
N: batch;
C: channel
H: height
W: width
Caffe 的Blob通道順序是:NCHW;
Tensorflow的tensor通道順序:默認是NHWC, 也支持NCHW,使用cuDNN會更快;
Pytorch中tensor的通道順序:NCHW
TensorRT中的tensor 通道順序: NCHW
-
pytorch加載數據
常用到的工具有 torchvision 裏的 transforms
torch.utils.data 裏的 Dataset,DataLoader
dataloader本質是一個可迭代對象,使用iter()訪問,不能使用next()訪問;
使用iter(dataloader)返回的是一個迭代器,然後可以使用next訪問;
也可以使用`for inputs, labels in dataloaders`進行可迭代對象的訪問;
一般我們實現一個datasets對象,傳入到dataloader中;然後內部使用yeild返回每一次batch的數據
Dataloader的處理邏輯是先通過Dataset類裏面的 __getitem__
函數獲取單個的數據,然後組合成batch。使用上主要是重構dataset
,必須繼承自torch.utils.data.Dataset
,內部要實現兩個函數一個是__lent__
用來獲取整個數據集的大小,一個是__getitem__
用來從數據集中得到一個數據片段item。
- SyntheticChineseStringDataset
該數據集是中文識別數據集,包含360多萬張訓練圖片,5824個字符,場景比較簡單,圖片是白底黑字
圖片:
標籤:前一部分爲圖像名稱,後一部分數字爲圖片上字符對應的字符編碼
字符編碼:char_std_5990.txt
- 使用pytorch讀取SyntheticChineseStringDataset 數據集:
import torch
import os
import cv2
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
def readfile(fileName):
res = []
with open(fileName,'r') as f:
lines = f.readlines()
for line in lines:
res.append(line.strip())
dic = {}
for line in res:
part = line.split(' ')
dic[part[0]] = part[1:]
return dic
# 調整圖像大小和歸一化操作
class resizeAndNormalize():
def __init__(self,size,interpolation=cv2.INTER_LINEAR):
# 注意對於opencv,size的格式是(w,h)
self.size = size
self.interpolation = interpolation
# ToTensor屬於類 """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
self.toTensor = transforms.ToTensor()
def __call__(self, image):
# (x,y) 對於opencv來說,圖像寬對應x軸,高對應y軸
image = cv2.resize(image,self.size,interpolation=self.interpolation)
#轉爲tensor的數據結構
image = self.toTensor(image)
#對圖像進行歸一化操作
image = image.sub_(0.5).div_(0.5)
return image
class CRNNDataSet(Dataset):
def __init__(self,imageRoot,labelRoot):
self.image_root = imageRoot
self.image_dict = readfile(labelRoot)
self.image_name = [fileName for fileName,_ in self.image_dict.items()]
def __getitem__(self, index):
image_path = os.path.join(self.image_root,self.image_name[index])
keys = self.image_dict.get(self.image_name[index])
label = [int(x) for x in keys]
image = cv2.imread(image_path,cv2.IMREAD_GRAYSCALE)
(height,width) = image.shape
#由於crnn網絡輸入圖像的高爲32,故需要resize原始圖像的height
size_height = 32
ratio = 32/float(height)
size_width = int(ratio * width)
transform = resizeAndNormalize((size_width,size_height))
#圖像預處理
image = transform(image)
#標籤格式轉換爲IntTensor
label = torch.IntTensor(label)
return image,label
def __len__(self):
return len(self.image_name)
trainData = CRNNDataSet(imageRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\images\\",
labelRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\lables\data_train.txt")
trainLoader = DataLoader(dataset=trainData,batch_size=1,shuffle=True)
for i,(data,label) in enumerate(trainLoader):
print(i)
print(data.shape)
print(label)