1.起始,因爲TensorFlow優化模型使用方法,引入了tensorflow hub,讓使用更簡單,但導致後果就是以前的教程基本不能使用,官方例程因爲hub的模型基本是使用URL模式加載,剛好上不去,這就尷尬了,需要自己慢慢摸索,對萌新及其不友好,本篇主要記錄本次調試過程,方便後人繞坑。本次使用google 已經訓練好的模型inception_v3,然後對最後一層進行重新訓練,以滿足我們需要的分類要求
2.環境要求
安裝TensorFlow,CPU和GPU版本都可以,GPU比較快而已,CPU直接使用最新的即可,目前最新的GPU版本對應的cuda_10.0 和 tensorflow-gpu 1.14.0,使用cuda_10.1會出現TF無法使用cuda的問題,懷疑是Anaconda沒有及時同步導致。
3.準備工作
(1)前往https://github.com/tensorflow/tensorflow ,下載對應的TensorFlow源碼。
(2)安裝hub, pip install tensorflow-hub ,並前往https://github.com/tensorflow/hub ,下載對應的tensorflow-hub源碼
(3)準備好自己需要分類的圖片,按類型劃分好文件名字,我這裏使用的是官方提供的數據集 flower_photos,需要的自己去下載,不用科學上網
(4).在下載下來的hub源碼中找到hub-master\examples\image_retraining文件夾,運行retrain.py,開始訓練。不能科學上網的會在這裏被卡住,我這裏提供一個野生方法,本地化模型,更改模型爲本地加載,參考連接https://zhuanlan.zhihu.com/p/64069911。
下載模型文件,示例如下:
模型路徑:https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3
下載模型路徑:https://storage.googleapis.com/tfhub-modules/google/imagenet/inception_v3/feature_vector/3.tar.gz
下載後模型需要解壓纔可以正常使用,然後,運行腳本開始訓練
python H:\tf_py\hub-master\examples\image_retraining\retrain.py ^
--image_dir H:\tf_py\image_retrain\flower_photos\flower_photos ^
--tfhub_module H:\tf_py\image_retrain\inception\3 ^
--saved_model_dir H:\tf_py\image_retrain\inception\4
pause
4.檢測訓練好的模型
因爲我使用的是鮮花( 玫瑰 鬱金香 向日葵 雛菊 蒲公)的分類訓練,所以我去百度下了很多這種類型的圖片進行測試。
不幸的是TensorFlow上面的測試例程,因爲移植等問題,已經對不上這個模型的測試例程了,於是我自己碼了一個心塞。
示例代碼:
import tensorflow as tf
import tensorflow_hub as hub
import os
import re
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tensorflow.keras import layers
saved_model_dir = 'image_retrain/inception/4'
label_lookup_path = 'image_retrain/output_labels.txt'
image_path = 'image_retrain/test/'
class NodeLookup(object):
def __init__(self):
self.node_lookup = self.load(label_lookup_path)
def load(self,label_lookup_path):
proto_as_ascii_lines = tf.gfile.GFile(label_lookup_path).readlines()
node_id_to_name = {}
#一行一行讀取數據
for uid,line in enumerate(proto_as_ascii_lines):
#去掉換行符
line = line.strip('\n')
node_id_to_name[uid] = line
return node_id_to_name
#傳入分類編號1-1000返回分類名稱
def id_to_string(self,node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id]
with tf.Session() as sess:
# 如果不知道模型具體信息 可以使用saved_model_cli.py 查看該模型的輸入 輸出數據格式要求以及關鍵的Signature簽名
#python H:\tf_py\tensorflow-master\tensorflow\python\tools\saved_model_cli.py show --dir ....\mode\ --all
meta_graph_def = tf.saved_model.loader.load(sess,["serve"], saved_model_dir)
graph = tf.get_default_graph()
oputs = sess.graph.get_tensor_by_name('final_result:0')
input_image = sess.graph.get_tensor_by_name('Placeholder:0')
#遍歷目錄
for root,dirs,files in os.walk(image_path):
for file in files:
#載入圖片
image_data = Image.open(os.path.join(root,file)).resize([299,299])
image_data_array = np.array(image_data)/255.0
image_data_shape = np.reshape(image_data_array,[299,299,3])
#傳入圖片不能是tensor類型 這裏使用np轉化成矩陣數組格式
#原因出在tf.reshape(),因爲網絡訓練時用placeholder定義了輸入格式,所以輸入不能用tensor,
#而tf.reshape()返回結果就是一個tensor了,所以輸入會報錯。
predictions = sess.run(oputs,{input_image:[image_data_shape]})
predictions = np.squeeze(predictions) #轉化爲一維數據
image_path = os.path.join(root,file)
print(image_path)
plt.imshow(image_data_shape)
plt.axis('off')
plt.show()
#排序 取概率最大的5個值 然後倒序
top_k = predictions.argsort()[-5:][::-1]
node_lookup = NodeLookup()
for node_id in top_k:
#獲取分類名稱
human_string = node_lookup.id_to_string(node_id)
#獲取分類置信度
score = predictions[node_id]
print('%s (score = %.5f)' %(human_string,score))
print()
運行結果:
附錄一下錯誤調試,
GPU的童鞋需要注意,訓練的時候很容易出現cudnn錯誤,解決方法如下:
1.cudnn創建錯誤,環境沒錯的話就是顯卡內存出錯了,修改爲按需分配
Problem:Could not create cudnn handle: CUDNN_STATUS_ALLOC_FAILED
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
2.如果不清楚的模型的輸入輸出,使用saved_model_cli.py 可以解決很多問題,我被這個輸入數據卡了3天,才找到這個解決方案。
3.訓練和測試時很容易出現莫名其妙的錯誤,這個時候最好重啓一下python服務,或者刪除緩存文件,否則你會崩潰的
4.如果能科學上網,儘量科學上網把,太折騰人了