OCR模型訓練

{"type":"doc","content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"OCR 從流程上包括兩步:"},{"type":"text","marks":[{"type":"strong"}],"text":"文本檢測"},{"type":"text","text":"和"},{"type":"text","marks":[{"type":"strong"}],"text":"文本識別"},{"type":"text","text":",即將圖片輸入到文本檢測算法中得到一個個的文本框,將每個文本框分別送入到文本識別算法中得到識別結果。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"1. 基於深度學習的文本檢測算法大致分爲兩類:"},{"type":"text","marks":[{"type":"italic"}],"text":"基於候選框迴歸的算法*"},{"type":"text","marks":[{"type":"italic"},{"type":"italic"}],"text":"和*"},{"type":"text","marks":[{"type":"italic"},{"type":"italic"},{"type":"italic"}],"text":"基於分割的算法。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"bulletedlist","content":[{"type":"listitem","attrs":{"listStyle":null},"content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"基於候選框迴歸的文本檢測"},{"type":"text","text":",是源於目標檢測算法,然後結合文本框的特點改造而成的,包括 CTPN、EAST 和 Seglink 算法等。CTPN 是基於 faster RCNN 改進的算法,在 CNN 後加入 RNN 網絡,主要思想是把文本行切分成小的細長矩形進行檢測再拼接起來;SegLink 算法的檢測思路與 CTPN 類似,也是先檢測文本行的小塊然後拼起來,但網絡結構上採取了 SSD 的思路,在多個特徵圖尺度上進行文本檢測,然後將多尺度的結果融合起來,另外輸出中加入了角度信息的迴歸;EAST 算法,它是直接回歸的整個文本行的座標,而不是細長矩形拼接,網絡結構上利用了 Unet 的上採樣結構來提取特徵,融入了淺層和深層的信息,並且在輸出層迴歸了角度信息,可以檢測斜框。"}]}]}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"bulletedlist","content":[{"type":"listitem","attrs":{"listStyle":null},"content":[{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"基於分割的文本檢測"},{"type":"text","text":",其基本思路是通過分割網絡進行像素級別的語義分割,再基於分割的結果構建文本行,包括 PixelLink、Psenet 和 Craft 算法等。PixelLink 算法,網絡結構上採用 FCN 提取特徵,直接通過實例分割結果中提取文本位置,輸出的特徵圖包括像素分類特徵圖和像素 link 特徵圖。Psenet 算法,網絡結構上採用 FPN 特徵金字塔提取特徵,對每個分割區域預測出多個分割結果,然後提出一種新穎的漸進擴展算法,將多個分割的結果進行融合。Craft 算法,網絡結構上採用 UNet 的結構,輸出的特徵圖包括 Region score 特徵圖和像素 Affinity score 特徵圖,另外特徵圖中使用了高斯函數,將預測像素點分類的問題轉成了像素點的迴歸問題,能更好的適應文字沒有嚴格包圍邊界的特點。"}]}]}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"2. 基於深度學習的文本識別算法則相對較爲統一,一般都採用CNN+RNN+CTC 的結構,俗稱 CRNN 結構,因爲這種結構的識別效果很好,且泛化性好,工業上大多都用的這種結構,然後在該框架上做一些改進,如更換 CNN 主幹網絡,縮減卷積層以提高速度縮減空間,或者改進 RNN 加入 Attention 結構等。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"本文主要介紹了我們在生產上使用的文本檢測和文本識別算法。算法的訓練流程一般包括以下步驟:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"1. 準備訓練數據"},{"type":"text","text":",有的是需要標註(如文本檢測中),有的主要是造數據(如文本識別中);"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"2. 定義算法網絡"},{"type":"text","text":",這裏主要是明確輸入和輸出;"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"3. 準備好 batch 數據集"},{"type":"text","text":",這裏主要是處理輸入的圖片和標籤數據,標籤數據結構與第 2 步中的網絡輸出對應,例如 craft 要進行高斯函數計算等,而文本識別中則無需處理,直接將造好的數據輸入即可;"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"4. 定義"},{"type":"text","text":" "},{"type":"text","marks":[{"type":"strong"}],"text":"loss"},{"type":"text","text":",優化器和學習率等參數;"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"5. 訓練"},{"type":"text","text":",這裏主要是定義每批次數據訓練的操作策略,如保存策略,日誌策略,測試策略等。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":2},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"OCR文本檢測"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"我們在文本定位中採用的是 Craft 算法,它是一種基於分割的算法,無需進行大量候選框的迴歸,也無需進行 NMS 後處理,因此極大提升了速度,並且它是字符級別的文本檢測器,定位的是字符,對於尺寸縮放不敏感,無需多尺度訓練和預測來解決尺度方差問題,最後其泛化性能也能達到 SOTA 的水平。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"1、訓練數據標註"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"該方法是基於分割的,背景文字是指的原本就在票據上的文字,如“姓名”、“出生年月”等文字,前景文字是指的待識別的文字,也就是用戶後填進去的內容。標註步驟就是將這些文字框出來,標上相應的類別。我們採用自己開發的標註工具,這裏也可以使用開源的 labelme 工具,生成的標註文件如下所示,第一行是圖片所在路徑,從第二行開始就是座標框信息,最後一位是類別。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.infoq.cn\/resource\/image\/8e\/75\/8e0aaf835b47ef437d0b4f2eba40d975.png","alt":null,"title":"","style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"2、網絡設計"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"下圖是網絡結構圖,整體採用了 Unet 的主結構,主幹網絡用的 vgg16,輸入圖片首先經過 vgg16 後,接 UNet 的上採樣結構,其作用是使得深層和淺層的特徵圖進行拼接作爲輸出。然後再接一系列的卷積操作,充分提取特徵。最後輸出的特徵圖包括 Region score 特徵圖和像素 Affinity score 特徵圖。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.infoq.cn\/resource\/image\/33\/e2\/33406df677832d5719c71fc76cf393e2.png","alt":null,"title":"","style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"網絡的代碼如下所示:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"class CRAFT(nn.Module):\n def __init__(self, pretrained=False, freeze=False, phase='test'):\n super(CRAFT, self).__init__()\n\n \"\"\" Base network \"\"\"\n self.basenet = vgg16_bn(pretrained, freeze)\n\n \"\"\" 固定部分參數,用於遷移學習\"\"\"\n if phase == 'train':\n for p in self.parameters():\n p.requires_grad=False\n\n \"\"\" U network \"\"\"\n self.upconv1 = double_conv(1024, 512, 256)\n self.upconv2 = double_conv(512, 256, 128)\n self.upconv3 = double_conv(256, 128, 64)\n self.upconv4 = double_conv(128, 64, 32)\n\n num_class = 2\n self.conv_cls = nn.Sequential(\n nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),\n nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),\n nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),\n nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),\n nn.Conv2d(16, num_class, kernel_size=1),\n )\n\n init_weights(self.upconv1.modules())\n init_weights(self.upconv2.modules())\n init_weights(self.upconv3.modules())\n init_weights(self.upconv4.modules())\n init_weights(self.conv_cls.modules())\n\n def forward(self, x):\n \"\"\" Base network \"\"\"\n sources = self.basenet(x)\n\n \"\"\" U network \"\"\"\n y = torch.cat([sources[0], sources[1]], dim=1)\n y = self.upconv1(y)\n\n y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)\n y = torch.cat([y, sources[2]], dim=1)\n y = self.upconv2(y)\n\n y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)\n y = torch.cat([y, sources[3]], dim=1)\n y = self.upconv3(y)\n\n y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)\n y = torch.cat([y, sources[4]], dim=1)\n feature = self.upconv4(y)\n\n y = self.conv_cls(feature)\n\n return y.permute(0,2,3,1), feature\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"我們也從代碼中把網絡結構打印出來,可以看到最後一層輸出的結構,最終網絡的輸出結構是(batchsize, 2, w, h),即通道數爲 2 的特徵圖。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.infoq.cn\/resource\/image\/59\/da\/59703d34abe6b04b53b26acf403b4cda.png","alt":null,"title":"","style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"3、訓練標籤生成"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"圖片的標籤數據包括 region_score 和 affinity_score 兩個特徵圖,region_score 表示給定的像素是字符中心的概率,affinity_score 表示相鄰兩個字符中間空白區域中心的概率。其特徵圖不像二進制分割圖那樣用離散方式標記每個像素,本文使用高斯熱圖對字符中心的概率進行編碼,將分類問題轉化爲迴歸問題,另外採用高斯熱度圖的好處是它能很好地處理沒有嚴格包圍的邊界區域,因爲文字不像傳統目標檢測的物體,它沒有明確的輪廓邊界。生成高斯熱圖的流程圖如下所示:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.infoq.cn\/resource\/image\/7d\/c5\/7d9883c77e8b0be5b11be3dcb9932fc5.png","alt":null,"title":"","style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"高斯熱圖的生成代碼如下所示:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"def generate_transformed_gaussian_kernel(h, w, points):\n '''\n 使用透視變換的高斯核建模region或affinity\n h:圖像的高\n w:圖像的寬\n points:維度(4,2)\n '''\n # 生成高斯核\n minX, minY = points[0]\n maxX, maxY = points[0]\n for i in range(1,4):\n minX = min(points[i][0],minX)\n minY = min(points[i][1],minY)\n maxX = max(points[i][0],maxX)\n maxY = max(points[i][1],maxY)\n kernel_w = int((maxX - minX + 1) \/\/ 2 * 2)\n kernel_h = int((maxY - minY + 1) \/\/ 2 * 2)\n\n kernel_size = 31\n kernel = np.zeros((kernel_size, kernel_size))\n kernel[kernel_size\/\/2, kernel_size\/\/2] = 1\n kernel = gaussian_filter(kernel, 10, mode='constant')\n\n kernel_size = max(kernel_h, kernel_w)\n kernel = cv2.resize(kernel,(kernel_size,kernel_size))\n\n # 將高斯核透視變換,座標(列,行)\n src = np.float32([(0,0),(0,kernel_size),(kernel_size,kernel_size),(kernel_size,0)]) # 左上,左下,右下,右上\n tgt = np.float32(points)\n M = cv2.getPerspectiveTransform(src, tgt)\n dst = cv2.warpPerspective(kernel, M, (w,h))\n\n # 轉換到[0.001,1]之間\n mini = dst[np.where(dst>0)].min()\n maxi = dst[np.where(dst>0)].max()\n h = 1\n l = 0.001 # 與預訓練模型的分佈保持一致\n dst[np.where(dst>0)] = ((h-l)*dst[np.where(dst>0)]-h*mini+l*maxi) \/ (maxi-mini)\n\n return dst\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"調用上述函數計算好 region_map 和 affinity_map 的結果,用 .npy 格式保存起來,然後在數據類中調用,自定義數據類的代碼如下所示:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"class MyDataset(Dataset):\n def __init__(self, root):\n self.root = root\n self.imglist = [f.split('.')[0] for f in os.listdir(os.path.join(root, 'img'))]\n\n def __getitem__(self, index):\n # read img, region_map, affinity_map\n img_path = os.path.join(self.root, 'img', self.imglist[index]+'.jpg')\n# img = plt.imread(img_path)\n img = np.array(plt.imread(img_path))\n\n region_path = os.path.join(self.root, 'region', \n self.imglist[index].split('_')[0]+'_region_'\n +self.imglist[index].split('_')[1]+'.npy')\n region_map = np.load(region_path).astype(np.float32)\n\n affinity_path = os.path.join(self.root, 'affinity', \n self.imglist[index].split('_')[0]+'_affinity_'\n +self.imglist[index].split('_')[1]+'.npy')\n affinity_map = np.load(affinity_path).astype(np.float32)\n\n # 保證圖像長和寬是2的倍數\n h, w, c = img.shape\n if h % 2 != 0 or w % 2 != 0:\n h = int(h \/\/ 2 * 2)\n w = int(w \/\/ 2 * 2)\n img = cv2.resize(img, (w, h))\n region_map = cv2.resize(region_map, (w, h))\n affinity_map = cv2.resize(affinity_map, (w, h))\n\n # preprocess\n img = normalizeMeanVariance(img)\n img = torch.from_numpy(img).permute(2, 0, 1) # [h, w, c] to [c, h, w]\n\n region_map = cv2.resize(region_map, (w\/\/2, h\/\/2))\n region_map = torch.tensor(region_map).unsqueeze(2)\n affinity_map = cv2.resize(affinity_map, (w\/\/2, h\/\/2))\n affinity_map = torch.tensor(affinity_map).unsqueeze(2)\n gt_map = torch.cat((region_map,affinity_map), dim=2)\n\n return {'img':img, 'gt':gt_map}\n\n\n def __len__(self):\n return len(self.imglist)\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"4、損失函數設計"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"由於輸出的特徵圖採用高斯函數構建,因此分割的損失函數也由交叉熵損失函數換成了迴歸用的 MSE 損失函數,優化器選用經典的 SGD。代碼如下所示,"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"criterion=nn.MSELoss(size_average=False).to(device)\noptimizer=torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),1e-7,\n momentum=0.95,\n weight_decay=0)\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"5、網絡訓練設計"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"由於文字定位需要的數據量極大,而真實數據集通常很少,標註也較困難,這裏使用的是 finetune 方式,即載入預訓練權值然後微調訓練的方式,用較少的訓練集就能達到很好的效果。代碼如下所示:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"if __name__ == '__main__':\n \"\"\"參數設置\"\"\"\n device = 'cuda' # cpu 或 cuda\n dataset_path = '.\/data' # 自己數據集的路徑\n pretrained_path = '.\/pretrained\/craft_mlt_25k.pth' # 預訓練模型的存放路徑\n model_path = '.\/models' # 現在訓練的模型要存儲的路徑\n\n\n dataset = MyDataset(dataset_path)\n loader = DataLoader(dataset, batch_size=1, shuffle=True)\n net = CRAFT(phase='train').to(device)\n net.load_state_dict(copyStateDict(torch.load(pretrained_path, map_location=device)))\n criterion=nn.MSELoss(size_average=False).to(device)\n optimizer=torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),1e-7,\n momentum=0.95,\n weight_decay=0)\n if not os.path.exists(model_path):\n os.mkdir(model_path)\n\n for epoch in range(500):\n epoch_loss = 0\n for i, data in enumerate(loader):\n img = data['img'].to(device)\n gt = data['gt'].to(device)\n\n # forward\n y, _ = net(img)\n loss = criterion(y, gt)\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n epoch_loss += loss.detach()\n print('epoch loss_'+str(epoch),':',epoch_loss\/len(loader))\n torch.save(net.state_dict(), os.path.join(model_path,str(epoch)+'.pth'))\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"6、測試結果"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"測試結果如下圖所示,其中左圖是預測出的高斯熱圖,爲了便於進行一系列的圖像操作,所以統一 resize 成了正方形,右圖是由高斯熱圖轉化出的矩形框圖,也就是最終可以放入識別模型中的切片框圖,爲了保護隱私信息做了模糊處理。可以看到,票據中的字被不同顏色的框給框出來了,並且分好了類別,其中紅色的爲背景字,藍色的爲前景字。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.infoq.cn\/resource\/image\/1a\/8e\/1a5c6c85638b3860a657721c5739778e.png","alt":null,"title":"","style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"heading","attrs":{"align":null,"level":2},"content":[{"type":"text","marks":[{"type":"strong"}],"text":"OCR 文本識別"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"我們在文本識別中採用的是 CRNN 算法,它結構非常簡單,就是 CNN+RNN+CTC 的結構,CNN 用來提取圖像特徵,RNN 用來提取文字的序列特徵,CTC 用來對齊輸出與標籤來計算 loss。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"1、訓練數據生成"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"在之前的《 OCR 數據處理篇》已經介紹過如何生成訓練數據,生成的圖像如下所示。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.infoq.cn\/resource\/image\/37\/b2\/37b2de44f0a54yyd75915994e26d4eb2.png","alt":null,"title":"","style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"訓練數據的生成有兩種方式,離線生成和在線生成,離線方式意思是先將圖片數據生成好存入硬盤中,然後讀取;在線方式是指:在每個 batch 的訓練開始前動態的生成訓練圖片,從而不會保存圖片。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"OCR 識別模型的訓練需要大量的數據,通常需要的數據量是字符集的 1000 倍,例如要訓練一個能認 5000 字的模型,至少需要 500 萬條數據才能訓練,這麼多的小圖片數據存入內存中,一是佔用內存,二是小文件的讀取會非常慢。因此在 OCR 識別模型的訓練時通常會採用在線的方法。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"生成訓練數據前需要先準備字符集,將字符集處理成如下的 txt 文件,一行爲一個字。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.infoq.cn\/resource\/image\/49\/e1\/49b67caea17a790a8cf40b55deb04ae1.png","alt":null,"title":"","style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"在線生成數據的代碼如下,主要就是自定義一個pytorch的Dataset類,它自帶的__getitem__方法是個迭代器, 每個 batch 載入數據時會自動調用該方法,取出 batch_size 大小的數據,然後定義好字符集、字體、背景、顏色等信息就可以了,也可以制定一些隨機策略生成。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"char_name = 'chinese_word.txt'\ndata_set = Generator(cfg.word.get_charset(char_name), args.direction, char_name = char_name, word_times = 1000)\nif args.distributed:\n val_sampler = torch.utils.data.distributed.DistributedSampler(val_set)\nelse:\n val_sampler = torch.utils.data.RandomSampler(val_set)\ndata_loader = DataLoader(data_set, batch_size=args.batch_size, sampler=train_sampler,\n num_workers=args.workers)\n\nclass Generator(Dataset):\n def __init__(self, alpha, direction='horizontal', char_name = 'chinese_word.txt', word_times = 100):\n \"\"\"\n :param alpha: 所有字符\n :param direction: 文字方向:horizontal|vertical\n \"\"\"\n super(Generator, self).__init__()\n self.alpha = alpha\n self.direction = direction\n self.alpha_list = list(alpha)\n self.min_len = 5\n self.max_len_list = [16, 19, 24, 26]\n self.max_len = max(self.max_len_list)\n self.font_size_list = [30, 25, 20, 18]\n self.font_path_list = list(FONT_CHARS_DICT.keys())\n self.font_list = [] # 二位列表[size,font]\n self.word_times = word_times\n for size in self.font_size_list:\n self.font_list.append([ImageFont.truetype(font_path, size=size)\n for font_path in self.font_path_list])\n if self.direction == 'horizontal':\n self.im_h = 32\n self.im_w = 512\n else:\n self.im_h = 512\n self.im_w = 32\n def get_allchar():\n f = codecs.open(os.path.join('.\/data', char_name),\n mode='r', encoding='utf-8')\n lines = f.readlines()\n f.close()\n charlist = [l.strip() for l in lines]\n return charlist\n self.charlist = get_allchar()\ndef __getitem__(self, item):\n image, indices, target_len = self.gen_image()\n if self.direction == 'horizontal':\n image = np.transpose(image[:, :, np.newaxis], axes=(2, 1, 0)) # [H,W,C]=>[C,W,H]\n else:\n image = np.transpose(image[:, :, np.newaxis], axes=(2, 0, 1)) # [H,W,C]=>[C,H,W]\n # 標準化\n image = image.astype(np.float32) \/ 255.\n image -= 0.5\n image \/= 0.5\n target = np.zeros(shape=(self.max_len,), dtype=np.long)\n target[:target_len] = indices\n if self.direction == 'horizontal':\n input_len = self.im_w \/\/ 4 - 3\n else:\n input_len = self.im_w \/\/ 16 - 1\n return image, target, input_len, target_len\ndef __len__(self):\nreturn len(self.alpha) * self.word_times\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"2、網絡設計"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"是由 cnn+rnn 組成,代碼如下所示:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"class CRNN(nn.Module):\n def __init__(self, num_classes, **kwargs):\n super(CRNN, self).__init__(**kwargs)\n self.cnn = nn.Sequential(OrderedDict([\n ('conv_block_1', _ConvBlock(1, 64)), # [B,64,W,32]\n ('max_pool_1', nn.MaxPool2d(2, 2)), # [B,64,W\/2,16]\n ('conv_block_2', _ConvBlock(64, 128)), # [B,128,W\/2,16]\n ('max_pool_2', nn.MaxPool2d(2, 2)), # [B,128,W\/4,8]\n ('conv_block_3_1', _ConvBlock(128, 256)), # [B,256,W\/4,8]\n ('conv_block_3_2', _ConvBlock(256, 256)), # [B,256,W\/4,8]\n ('max_pool_3', nn.MaxPool2d((2, 2), (1, 2))), # [B,256,W\/4,4]\n ('conv_block_4_1', _ConvBlock(256, 512, bn=True)), # [B,512,W\/4,4]\n ('conv_block_4_2', _ConvBlock(512, 512, bn=True)), # [B,512,W\/4,4]\n ('max_pool_4', nn.MaxPool2d((2, 2), (1, 2))), # [B,512,W\/4,2]\n ('conv_block_5', _ConvBlock(512, 512, kernel_size=2, padding=0)) # [B,512,W\/4,1]\n ]))\n self.rnn1 = nn.GRU(512, 256, batch_first=True, bidirectional=True)\n self.rnn2 = nn.GRU(512, 256, batch_first=True, bidirectional=True)\n self.transcript = nn.Linear(512, num_classes)\n\ndef forward(self, x):\n \"\"\"\n :param x: [B, 1, W, 32]\n :return: [B, W,num_classes]\n \"\"\"\n x = self.cnn(x) # [B,512,W\/16,1]\n x = torch.squeeze(x, 3) # [B,512,W]\n x = x.permute([0, 2, 1]) # [B,W,512]\n x, h1 = self.rnn1(x)\n x, h2 = self.rnn2(x, h1)\n x = self.transcript(x)\n return x\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"自定義一個 pytorch 的 Module 類,cnn 層這裏用的結構是 vgg16,ConvBlock 結構如下所示,就是卷積,batchNormalization 加 relu。這裏也可以根據需要換成 resnet 或者 densenet 主幹網絡。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"class _ConvBlock(nn.Sequential):\n def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bn=False):\n super(_ConvBlock, self).__init__()\n self.add_module('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))\n if bn:\n self.add_module('norm', nn.BatchNorm2d(out_channels))\n self.add_module('relu', nn.ReLU(inplace=True))\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"然後 rnn 層是一個雙向的 rnn,這裏爲了加速用的 gru 代替 lstm,最後接一個線性層,最終輸出爲(batchsize, unit, class_num),其中 unit 是根據識別的切片長度不同而變化的,class_num 是字符集的個數,因爲最後是計算 softmax。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"3、損失函數的設計"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"在上述 rnn 的輸出單元 units 後,因爲每個切片的字符數量不同,字體大小樣式不同,導致每個 unit 的輸出與結果的字符並不是一一對應的,因此採用了 CTC_loss 的損失函數。用的 torch.nn.CTCLoss(),是 pytorch 自帶的函數。代碼如下所示:"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"text"},"content":[{"type":"text","text":"model = crnn.CRNN(len(data_set.alpha))\nmodel = model.to(device)\ncriterion = CTCLoss()\ncriterion = criterion.to(device)\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"4、網絡訓練設計"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"每個 epoch 的訓練過程如下:輸入圖片—》預測—》計算損失—》反向傳播更新參數—》保存模型,pytorch 可以動態的把整個訓練步驟用代碼形式寫出來,因此很容易編寫和調試中間步驟。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"def train_one_epoch(model, criterion, optimizer, data_loader, val_set, device, epoch, args):\n model.train()\n epoch_loss = 0.0\n # for image, target, input_len, target_len in tqdm(data_loader):\n for i, sample_batched in enumerate(data_loader):\n image, target, input_len, target_len = sample_batched\n image = image.to(device)\n # print(target, target_len, input_len)\n outputs = model(image.to(torch.float32)) # [B,N,C]\n outputs = torch.log_softmax(outputs, dim=2)\n outputs = outputs.permute([1, 0, 2]) # [N,B,C]\n loss = criterion(outputs[:], target, input_len, target_len)\n # 梯度更新\n model.zero_grad()\n loss.backward()\n optimizer.step()\n # 當前輪的loss\n epoch_loss += loss.item() * image.size(0)\n # 每訓練一個batch打印一次 loss 和 acc\n if i % 100 == 0:\n print('[epoch:%d, %d | %d] Loss: %.03f'\n % (epoch, i, len(data_loader), epoch_loss \/ (i + 1)))\n if np.isnan(loss.item()):\n print(target, input_len, target_len)\n epoch_loss = epoch_loss \/ len(data_loader.dataset)\n # 打印日誌,保存權重\nprint('Epoch: {}\/{} loss: {:03f}'.format(epoch + 1, args.epochs, epoch_loss))\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"heading","attrs":{"align":null,"level":3},"content":[{"type":"text","text":"5、結果展示"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"測試代碼如下,主要就是模型的加載調用和預測結果的處理。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"codeblock","attrs":{"lang":"java"},"content":[{"type":"text","text":"def inference_image(net, alpha, image_path):\n image = load_image(image_path)\n image = torch.FloatTensor(image)\n predict = net(image)[0].detach().numpy() # [W,num_classes]\n label = np.argmax(predict[:], axis=1)\n label = [alpha[class_id] for class_id in label]\n label = [k for k, g in itertools.groupby(list(label))]\n label = ''.join(label).replace(' ', '')\nreturn label\n\ndef main(args):\n alpha = cfg.word.get_charset('chinese_word.txt')\n if args.direction == 'horizontal':\n net = crnn.CRNN(num_classes=len(alpha))\n else:\n net = crnn.CRNNV(num_classes=len(alpha))\n net.load_state_dict(torch.load(args.weight_path, map_location='cpu')['model'])\n net.eval()\n # load image\n if args.image_dir:\n image_path_list = [os.path.join(args.image_dir, n) for n in os.listdir(args.image_dir)]\n image_path_list.sort()\n for image_path in image_path_list:\n label = inference_image(net, alpha, image_path)\n print(\"image_path:{},label:{}\".format(image_path, label))\n else:\n label = inference_image(net, alpha, args.image_path)\n print(\"image_path:{},label:{}\".format(args.image_path, label))\n"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"測試準確率在 98% 左右,測試的樣例結果如下所示,3 張全部識別正確。"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"image","attrs":{"src":"https:\/\/static001.infoq.cn\/resource\/image\/a0\/58\/a0eba6ayy93864aa62f2ffaaafb5b958.png","alt":null,"title":"","style":[{"key":"width","value":"75%"},{"key":"bordertype","value":"none"}],"href":"","fromPaste":false,"pastePass":false}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null}},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"文章轉載自:金科優源匯(ID:jkyyh2020)"}]},{"type":"paragraph","attrs":{"indent":0,"number":0,"align":null,"origin":null},"content":[{"type":"text","text":"原文鏈接:"},{"type":"link","attrs":{"href":"https:\/\/mp.weixin.qq.com\/s\/WnrNLNLb5X0VXidhZgkDrQ","title":"xxx","type":null},"content":[{"type":"text","text":"OCR模型訓練"}]}]}]}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章