轉載自:
https://blog.csdn.net/stesha_chen/article/details/81976415
謝謝大佬,這是我見過最適合入門者的文檔。
前期準備
訓練flower數據集(包括fine-tuning)
訓練自己的數據集(包括fine-tuning)
前期準備
前期瞭解
在tensorflow models中有官方維護和非官方維護的models,official models就是官方維護的models,裏面使用的接口都是一些官方的接口,比如tf.layers.conv2d之類。而research models是tensorflow的研究人員自己實現的一些流行網絡,不受官方支持,裏面會用到一些slim之類的非官方接口。但是因爲research models實現的網絡非常多,而且提供了完整的訓練和評價方案,所以我們現在基於research models中的實現來部署網絡。
環境配置
首先要保證tf.contrib.slim在你的tensorflow環境中是存在的,運行下面的腳本保證沒有錯誤發生。
python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once"
base代碼準備
TF的庫裏面沒有TF-slim的內容,所以我們需要將代碼clone到本地
- cd $HOME/workspace
- git clone https://github.com/tensorflow/models/
運行以下腳本確定是否可用
- cd $HOME/workspace/models/research/slim
- python -c "from nets import cifarnet; mynet = cifarnet.cifarnet"
其實我們只需要使用research中的slim的代碼,所以我是直接拷貝了slim的代碼到本地,基於slim代碼進行修改。
訓練flower數據集
下載數據並創建tfrecord
官網提供了下載並且轉換數據集的方法,運行如下腳本即可,腳本會直接下載flower數據集並且存儲爲TFRecord的格式。
- $ python download_and_convert_data.py \
- --dataset_name=flowers \
- --dataset_dir=./tmp/data/flowers
爲何官網要使用TFRecord呢?因爲TFRecord和tensorflow內部有一個加速機制。實際讀取tfrecord數據時,先以相應的tfrecord文件爲參數,創建一個輸入隊列,這個隊列有一定的容量,在一部分數據出隊列時,tfrecord中的其他數據就可以通過預取進入隊列,這個過程和網絡的計算是獨立進行的。也就是說,網絡每一個iteration的訓練不必等待數據隊列準備好再開始,隊列中的數據始終是充足的,而往隊列中填充數據時,也可以使用多線程加速。
下載pre-trained checkpoint
每個網絡對應的checkpoint可以從官網上找到,官網也提供了下載inception v3的checkpoint的例子
- $ mkdir ./tmp/checkpoints
- $ wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
- $ tar -xvf inception_v3_2016_08_28.tar.gz
- $ mv inception_v3.ckpt ./tmp/checkpoints
- $ rm inception_v3_2016_08_28.tar.gz
從頭開始訓練
官網上提供了從頭開始訓練的例子,我根據我本地訓練flowers數據集的存儲位置而對腳本稍做修改
- python train_image_classifier.py --train_dir=./tmp/train_logs \
- --dataset_name=flowers --dataset_split_name=train \
- --dataset_dir=./tmp/flowers --model_name=inception_v3
訓練過程會打印出loss值
- ......
- INFO:tensorflow:global step 10: loss = 3.3827 (0.384 sec/step)
- INFO:tensorflow:global step 20: loss = 2.9981 (0.389 sec/step)
- INFO:tensorflow:global step 30: loss = 3.8143 (0.392 sec/step)
- INFO:tensorflow:global step 40: loss = 3.3529 (0.385 sec/step)
- INFO:tensorflow:global step 50: loss = 3.1890 (0.388 sec/step)
- INFO:tensorflow:global step 60: loss = 2.2893 (0.389 sec/step)
- INFO:tensorflow:global step 70: loss = 2.5434 (0.386 sec/step)
- INFO:tensorflow:global step 80: loss = 3.1224 (0.386 sec/step)
- INFO:tensorflow:global step 90: loss = 3.4845 (0.387 sec/step)
- INFO:tensorflow:global step 100: loss = 2.2984 (0.391 sec/step)
- INFO:tensorflow:global step 110: loss = 2.5087 (0.387 sec/step)
- INFO:tensorflow:global step 120: loss = 2.8148 (0.391 sec/step)
- INFO:tensorflow:global step 130: loss = 2.4258 (0.390 sec/step)
- INFO:tensorflow:global step 140: loss = 2.9289 (0.391 sec/step)
- INFO:tensorflow:global step 150: loss = 2.5775 (0.391 sec/step)
- INFO:tensorflow:global step 160: loss = 2.5603 (0.390 sec/step)
- INFO:tensorflow:global step 170: loss = 2.8423 (0.392 sec/step)
- INFO:tensorflow:global step 180: loss = 2.3163 (0.388 sec/step)
- ......
tensorboard
打開tensorboard,tensorboard --logdir=./tmp/train_logs
可以查看tensorboard
Fine-tuning
--checkpoint_path:指定checkpoint文件的路徑。
--checkpoint_exclude_scopes:當pre-trained checkpoint對應的網絡最後一層分類的類別數量和現在數據集的類別數量不匹配時使用,可以指定checkpoint restore時哪些層的參數不恢復。
--trainable_scopes:如果只希望某些層參與訓練,其他層的參數固定時,就使用這個flag,在這個flag中指定需要訓練的參數。
- python train_image_classifier.py \
- --train_dir=./tmp/train_logs \
- --dataset_dir=./tmp/flowers \
- --dataset_name=flowers \
- --dataset_split_name=train \
- --model_name=inception_v3 \
- --checkpoint_path=./tmp/inception_v3_checkpoints/inception_v3.ckpt \
- --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
- --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
使用pre-trained checkpoint,loss會很快下降到一個比較小的值。
評價
使用上一個步驟訓練出來的checkpoint進行評估。注意一個差別,fine-tuning時--checkpoint_path需要指定到具體文件,但是評估的時候--checkpoint_path只需要指定到文件夾路徑即可,代碼會根據文件夾下的內容自動選定以最新的checkpoint來進行評估。
- python eval_image_classifier.py \
- --alsologtostderr \
- --checkpoint_path=./tmp/train_logs \
- --dataset_dir=./tmp/flowers \
- --dataset_name=flowers \
- --dataset_split_name=validation \
- --model_name=inception_v3
結果如下:
- INFO:tensorflow:Restoring parameters from ./tmp/train_logs/model.ckpt-0
- INFO:tensorflow:Evaluation [1/4]
- INFO:tensorflow:Evaluation [2/4]
- INFO:tensorflow:Evaluation [3/4]
- INFO:tensorflow:Evaluation [4/4]
- 2018-08-23 18:07:13.349935: I tensorflow/core/kernels/logging_ops.cc:79] eval/Recall_5[1]
- 2018-08-23 18:07:13.350030: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.1675]
- INFO:tensorflow:Finished evaluation at 2018-08-23-10:07:13
保存模型
- python export_inference_graph.py \
- --alsologtostderr \
- --model_name=inception_v3 \
- --output_file=./tmp/inception_v3_inf_graph.pb
可以將模型導出,後續可以直接load這個模型來使用
小結
現有數據的訓練方式就介紹完了,基本上腳本都可以解決,所以要訓練自己的數據集就需要模仿這些代碼的實現。
訓練自己的數據
創建自己的數據集
首先要準備自己的數據集,保證相同類別的圖片放在同一個文件夾下,文件夾的名字就是這個類別的名稱。注意,圖片數據最好備份一份,因爲執行完後圖片數據會全部被刪除,只保留生成的tfrecord文件,除非修改代碼刪除這個步驟
接着需要仿照download_and_convert_flowers.py中對flowers數據轉tfrecord的處理,來實現對自己的數據轉tfrecord的處理。
主要以下幾個改動點:
1.創建convert_mydata.py文件,等同於download_and_convert_flowers.py,因爲我們自己的數據不用下載,所以文件命名爲convert_mydata.py
2.在download_and_convert_data.py中添加處理,這樣運行download_and_convert_data.py時,傳入mydata數據集就可以走到convert_mydata.py裏的run函數。
- # add by stesha
- elif FLAGS.dataset_name == 'mydata':
- convert_mydata.run(FLAGS.dataset_dir)
3.convert_mydata.py的實現基本和download_and_convert_flowers.py類似,只是去掉裏面關於download部分的代碼,比如run函數中去掉dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
4._NUM_SHARDS表示有多少類別,_NUM_VALIDATION表示用多少張圖片作爲validation,根據實際情況填寫即可。
5.如果不希望自己的圖片數據在執行完後被刪掉,可以去掉run中_clean_up_temporary_files(dataset_dir)代碼。
參考代碼:convert_mydata.py
代碼實現後運行下面的腳本就可以將數據轉換成tfrecord格式了。
- python download_and_convert_data.py \
- --dataset_name=mydata \
- --dataset_dir=./data/mydata
下載pre-trained checkpoint
訓練自己的數據我打算用準確率相對比較高的inception v4,所以我們需要下載inception v4的checkpoint。
下載地址:http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz
下載完成後解壓放入一個文件夾中,比如我放入了./data/checkpoint中
從頭開始訓練
訓練時數據需要從tfrecord中讀取出來,所以代碼需要稍作改動
1.在dataset_factory.py中增加mydata數據集
2.創建mydata.py,參考flowers.py的實現,基本需要改動的只有SPLITS_TO_SIZES和_NUM_CLASSES。前者只需要將測試集和訓練集的大小寫入,後者分類的數量。參考:mydata.py
準備好後,只需要運行下面的腳本就可以開始訓練了,新的checkpoint文件會存放在指定的train_dir中。
- python train_image_classifier.py --train_dir=./data/train_logs \
- --dataset_name=mydata --dataset_split_name=train \
- --dataset_dir=./data/mydata --model_name=inception_v4
fine-tuning
- python train_image_classifier.py \
- --train_dir=./data/train_logs \
- --dataset_dir=./data/mydata \
- --dataset_name=mydata \
- --dataset_split_name=train \
- --model_name=inception_v4 \
- --checkpoint_path=./data/checkpoint/inception_v4.ckpt \
- --checkpoint_exclude_scopes=InceptionV4/Logits,InceptionV4/AuxLogits \
- --trainable_scopes=InceptionV4/Logits,InceptionV4/AuxLogits
評估
- python eval_image_classifier.py \
- --alsologtostderr \
- --checkpoint_path=./data/train_logs \
- --dataset_dir=./data/mydata \
- --dataset_name=mydata \
- --dataset_split_name=validation \
- --model_name=inception_v4
預測
tf-slim中並沒有提供predict某張圖片的腳本,我這邊簡單實現了一下,可以作爲參考。predict.py
- python predict.py --model_name=inception_v4 \
- --predict_file=./backup/mydata/km1_back/km1_back.jpg \
- --checkpoint_path=./data/train_logs
結語
使用tensorflow的slim model來訓練自己的數據集還是很簡單的,基本上要改動的代碼不多,這樣能夠方便我們很快的實施自己的想法,而且基於已經訓練好的checkpoint來fine-tuning很快也能得到不錯的精確度,使神經網絡的部署更加方便快捷。