在前面的文章中,已經介紹了基於SSD使用自己的數據訓練目標檢測模型(見文章:手把手教你訓練自己的目標檢測模型),本文將基於另一個目標檢測模型YOLO,介紹如何使用自己的數據進行訓練。
YOLO(You only look once)是目前流行的目標檢測模型之一,目前最新已經發展到V3版本了,在業界的應用也很廣泛。YOLO的基本原理是:首先對輸入圖像劃分成7x7的網格,對每個網格預測2個邊框,然後根據閾值去除可能性比較低的目標窗口,最後再使用邊框合併的方式去除冗餘窗口,得出檢測結果,如下圖:
YOLO的特點就是“快”,但由於YOLO對每個網格只預測一個物體,就容易造成漏檢,對物體的尺度相對比較敏感,對於尺度變化較大的物體泛化能力較差。
本文的目標仍舊是在圖像中識別檢測出可愛的熊貓
基於YOLO使用自己的數據訓練目標檢測模型,訓練過程跟前面文章所介紹的基於SSD訓練模型一樣,主要步驟如下:
1、安裝標註工具
本案例採用的標註工具是labelImg,在前面的文章介紹訓練SSD模型時有詳細介紹了安裝方法(見文章:手把手教你訓練自己的目標檢測模型),在此就不再贅述了。
成功安裝後的labelImg標註工具,如下圖:
2、標註數據
使用labelImg工具對熊貓照片進行畫框標註,自動生成VOC_2007格式的xml文件,保存爲訓練數據集。操作方式跟前面的文章介紹訓練SSD模型的標註方法一樣(見文章:手把手教你訓練自己的目標檢測模型),在此就不再贅述了。
3、配置YOLO
(1)安裝Keras
本案例選用YOLO的最新V3版本,基於Keras版本。Keras是一個高層神經網絡API,以Tensorflow、Theano和CNTK作爲後端。由於本案例的基礎環境(見文章:AI基礎環境搭建)已經安裝了tensorflow,因此,Keras底層將會調用tensorflow跑模型。Keras安裝方式如下:
# 切換虛擬環境
source activate tensorflow
# 安裝keras-gpu版本
conda install keras-gpu
# 如果是安裝 keras cpu版本,則執行以下指令
#conda install keras
keras版本的yolo3還依賴於PIL工具包,如果之前沒安裝的,也要在anaconda中安裝
# 安裝 PIL
conda install pillow
(2)下載yolo3源代碼
在keras-yolo3的github上下載源代碼(https://github.com/qqwweee/keras-yolo3),使用git進行clone或者直接下載成zip壓縮文件。
(3)導入PyCharm
打開PyCharm,新建項目,將keras-yolo3的源代碼導入到PyCharm中
4、下載預訓練模型
YOLO官網上提供了YOLOv3模型訓練好的權重文件,把它下載保存到電腦上。下載地址爲https://pjreddie.com/media/files/yolov3.weights
5、訓練模型
接下來到了關鍵的步驟:訓練模型。在訓練模型之前,還有幾項準備工作要做。
(1)轉換標註數據文件
YOLO採用的標註數據文件,每一行由文件所在路徑、標註框的位置(左上角、右下角)、類別ID組成,格式爲:image_file_path x_min,y_min,x_max,y_max,class_id
例子如下:
這種文件格式跟前面製作好的VOC_2007標註文件的格式不一樣,Keras-yolo3裏面提供了voc格式轉yolo格式的轉換腳本 voc_annotation.py
在轉換格式之前,先打開voc_annotation.py文件,修改裏面的classes的值。例如本案例在voc_2007中標註的熊貓的物體命名爲panda,因此voc_annotation.py修改爲:
import xml.etree.ElementTree as ET
from os import getcwd
sets=[('2007', 'train'), ('2007', 'val'), ('2007', 'test')]
classes = ["panda"]
新建文件夾VOCdevkit/VOC2007,將熊貓的標註數據文件夾Annotations、ImageSets、JPEGImages放到文件夾VOCdevkit/VOC2007裏面,然後執行轉換腳本,代碼如下:
mkdir VOCdevkit
mkdir VOCdevkit/VOC2007
mv Annotations VOCdevkit/VOC2007
mv ImageSets VOCdevkit/VOC2007
mv JPEGImages VOCdevkit/VOC2007
source activate tensorflow
python voc_annotation.py
轉換後,將會自動生成yolo格式的文件,包括訓練集、測試集、驗證集。
(2)創建類別文件
在PyCharm導入的keras-yolo3源代碼中,在model_data目錄裏面新建一個類別文件my_class.txt,將標註物體的類別寫到裏面,每行一個類別,如下:
(3)轉換權重文件
將前面下載的yolo權重文件yolov3.weights轉換成適合Keras的模型文件,轉換代碼如下:
source activate tensorflow
python convert.py -w yolov3.cfg yolov3.weights model_data/yolo_weights.h5
(4)修改訓練文件的路徑配置
修改train.py裏面的相關路徑配置,主要有:annotation_path、classes_path、weights_path
其中,train.py裏面的batch_size默認是32(第57行),指每次處理時批量處理的數量,數值越大對機器的性能要求越高,因此可根據電腦的實際情況進行調高或調低
(5)訓練模型
經過以上的配置後,終於全部都準備好了,執行train.py就可以開始進行訓練。
訓練後的模型,默認保存路徑爲logs/000/trained_weights_final.h5,可以根據需要進行修改,位於train.py的第85行,可修改模型保存的名稱。
6、使用模型
完成模型的訓練之後,調用yolo.py即可使用我們訓練好的模型。
首先,修改yolo.py裏面的模型路徑、類別文件路徑,如下:
class YOLO(object):
_defaults = {
"model_path": 'logs/000/trained_weights_final.h5',
"anchors_path": 'model_data/yolo_anchors.txt',
"classes_path": 'model_data/my_classes.txt',
"score" : 0.3,
"iou" : 0.45,
"model_image_size" : (416, 416),
"gpu_num" : 1,
}
通過調用 YOLO類就能使用YOLO模型,爲方便測試,在yolo.py最後增加以下代碼,只要修改圖像路徑後,就能使用自己的yolo模型了
if __name__ == '__main__':
yolo=YOLO()
path = '/data/work/tensorflow/data/panda_test/1.jpg'
try:
image = Image.open(path)
except:
print('Open Error! Try again!')
else:
r_image, _ = yolo.detect_image(image)
r_image.show()
yolo.close_session()
執行後,可愛的熊貓就被乖乖圈出來了,呵呵
通過以上步驟,我們又學習了基於YOLO來訓練自己的目標檢測模型,這樣在應用中可以結合實際需求,使用SSD、YOLO訓練自己的數據,並從中選擇出效果更好的目標檢測模型。
後面還會陸續推出更多【AI實戰】內容,敬請留意。
推薦相關閱讀
- 【AI實戰】快速掌握TensorFlow(一):基本操作
- 【AI實戰】快速掌握TensorFlow(二):計算圖、會話
- 【AI實戰】快速掌握TensorFlow(三):激勵函數
- 【AI實戰】快速掌握TensorFlow(四):損失函數
- 【AI實戰】搭建基礎環境
- 【AI實戰】訓練第一個模型
- 【AI實戰】編寫人臉識別程序
- 【AI實戰】動手訓練目標檢測模型(SSD篇)
- 【AI實戰】動手訓練目標檢測模型(YOLO篇)
- 【精華整理】CNN進化史
- 大話卷積神經網絡(CNN)
- 大話循環神經網絡(RNN)
- 大話深度殘差網絡(DRN)
- 大話深度信念網絡(DBN)
- 大話CNN經典模型:LeNet
- 大話CNN經典模型:AlexNet
- 大話CNN經典模型:VGGNet
- 大話CNN經典模型:GoogLeNet
- 大話目標檢測經典模型:RCNN、Fast RCNN、Faster RCNN
- 大話目標檢測經典模型:Mask R-CNN
- 27種深度學習經典模型
- 淺說“遷移學習”
- 什麼是“強化學習”
- AlphaGo算法原理淺析
- 大數據究竟有多少個V
- Apache Hadoop 2.8 完全分佈式集羣搭建超詳細教程
- Apache Hive 2.1.1 安裝配置超詳細教程
- Apache HBase 1.2.6 完全分佈式集羣搭建超詳細教程
- 離線安裝Cloudera Manager 5和CDH5(最新版5.13.0)超詳細教程
關注本人公衆號“大數據與人工智能Lab”(BigdataAILab),獲取更多信息。