立即學習:https://edu.csdn.net/course/play/24719/279510?utm_source=blogtoedu
目錄
一、用訓練模型進行預測代碼
import cv2
import tensorflow as tf
with tf.Session() as sess:
"""加載模型"""
loader = tf.train.import_meta_graph("./model/model.meta")
loader.restore(sess, './model/model')
"""開始識別"""
capture = cv2.VideoCapture(0)
while True:
ret, frame = capture.read() # 獲取一幀
show_img = frame.copy() # show_img是原圖像frame的拷貝
crop_img = frame[100:200, 100:200] # crop_img是在原圖像frame上的部分截取
cv2.rectangle(show_img, (100, 100), (200, 200), (0, 0, 255), 5) # 在show_img上畫出截取部分的框框
cv2.imshow('frame', show_img) # 顯示圖片show_img
k = cv2.waitKey(1) # OpenCV延遲1毫秒、同時檢測是否有按鍵被按下(如果有鍵被按下時,將鍵值返回給q)
# 按下Q鍵退出
if k == ord('q'):
break
# 按下S鍵進行識別
elif k == ord('s'):
out = sess.run("output:0", feed_dict={"data_in:0": [crop_img]})
res = ['我', '好', '帥'][out[0]]
print("識別結果:", res)
cv2.imshow("crop", crop_img)
capture.release()
cv2.destroyAllWindows()
二、思路總結
1、模型的導入
要有一個會話sess,要構建好模型載入器,要將模型文件路徑填入,恢復到會話中
2、測試圖像的輸入
和準備數據集時同樣的方法,使用cv調用攝像頭,框選部分圖像,作爲輸入
3、根據測試圖像進行預測
在會話sess中運行模型中與輸出相應的"name:後綴",同時將測試圖像的數據feed給模型中與輸入相應的"name:後綴",得到的返回值即爲輸出結果。
這裏需要注意的是,本質上sess讀取的模型中就蘊含了原來訓練時神經網絡的結構,所以name要和原來訓練時神經網絡中變量的name相對應
三、API總結
API |
作用 |
使用示例 |
tf.train.import_meta_graph |
構建tf會話載入器 |
loader = tf.train.import_meta_graph("./model/model.meta") |
*.restore |
將訓練好的模型文件加載到會話中 |
loader.restore(sess, './model/model') |