本章的前期工作以及神經網絡的搭建:https://blog.csdn.net/ileopard/article/details/102763645
一、可視化界面設計
使用 tkinter來設計可視化界面
1.新建窗體
from tkinter import Label, Menu, DoubleVar, Button, Tk, filedialog
window = Tk() # 創建窗口
window.title("用戶頁面") # 窗口標題
window.geometry('240x360') # 窗口大小,小寫字母x
# 這裏可以在窗體內添加其他的控件
# 以上是窗口的主體
window.mainloop() # 結束(不停循環刷新)
2.添加控件(Label, Menu, DoubleVar, Button)
# 最後這個菜單欄沒有用到,但是還是把它放在這,以後可能會用到
# ---------------窗口菜單欄
menubar = Menu(window) # 在窗口上添加菜單欄
filemenu = Menu(menubar, tearoff=0) # filemenu放在menu中
submenu = Menu(filemenu) # submenu放在filemenu中
ssubmenu = Menu(submenu) # ssubmenu放在submenu中
menubar.add_cascade(label='File', menu=filemenu) # add_cascade用來創建下拉欄,filemenu命名爲File
filemenu.add_command(label='Open', command=Open_image) # add_command用來創建命令欄,不可有子項
filemenu.add_cascade(label='1', menu=submenu) # submenu 命名爲1
submenu.add_cascade(label='2', menu=ssubmenu) # ssubmenu 命名爲2
window.config(menu=menubar) # 創建完畢
# --------------------------
下面是本次所用到的控件:
# label,如果要在label中設置圖片,記得一定要設置參數bitmap
Input_image = Label(width=200,
height=200,
bitmap='warning',
bg='white').grid(row=0, column=0, padx=20)
# 使用grid佈局,像表格一樣的佈局,其中padx表示x距離外部邊界的大小,ipadx表示與內部的
# sticky表示位置,w表示西,e表示東
testLabel = Label(window,
text="testAccuracy: ", # 文本
font=('Arial', 10), # 字體和大小
width=10,
height=2, # 字體所佔的寬度和高度
).grid(row=1, column=0, sticky='w', pady=5, padx=36)
testAccuracy = Label(window,
textvar=textTest, # 文本
font=('Arial', 10), # 字體和大小
width=10,
height=2, # 字體所佔的寬度和高度
bg='white'
).grid(row=1, column=0, sticky='e', pady=5, padx=36)
textTest.set(0.0)
# 使用Button。
startB = Button(
window,
text='開始',
width=8, height=2,
command=application # 執行函數體,而不是得到函數執行的結果
).grid(row=3, column=0, pady=5)
這樣大致就做好了界面設計
二、具體功能的實現
大致流程:
下面主要進行三步:
- 從電腦中輸入圖片(Open_image())
- 預處理輸入圖片()
- 加載之前訓練好的模型進行預測
1.從電腦中輸入圖片
- 使用filedialog從電腦中選擇圖片,返回絕對路徑。
- 之後通過該路徑,打開圖片,將其放入可視化界面的Input_image中。
- 最後返回該圖片
# 打開電腦中的圖片
def Open_image():
global Input_image, File
File = filedialog.askopenfilename(parent=window,
initialdir=ImagePath,
title='Choose an image.')
img = Image.open(File)
img_resized = img.resize((28 * 4, 28 * 4), Image.ANTIALIAS)
filename = ImageTk.PhotoImage(img_resized)
Input_image = Label(image=filename)
Input_image.image = filename
Input_image.grid(row=0, column=0)
return img_resized
2.預處理輸入圖片
- 將圖片resize爲28*28大小
- 將圖片灰度化
- 由於模型的要求是黑底白字,但輸入的圖是白底黑字,所以需要對每個像素點的值改爲 255 減去原值以得到互補的反色。
- 把圖片reshape成一維數組(784個像素點)
- 將現有的RGB圖從0-255之間的數變爲0-1之間的浮點數
- 返回預處理好的圖片
# 預處理函數
def predicted(img):
# 將圖片resize爲28*28大小
reIm = img.resize((28, 28), Image.ANTIALIAS)
# 將圖片灰度化
im_arr = np.array(reIm.convert('L'))
threshold = 50 # 設定合理的閾值
# 二值化
for i in range(28):
for j in range(28):
im_arr[i][j] = 255 - im_arr[i][j]
if im_arr[i][j] < threshold:
im_arr[i][j] = 0
else:
im_arr[i][j] = 255
# 將圖片轉化爲一維數組(784個像素點)
nm_arr = im_arr.reshape([1, 784])
nm_arr = nm_arr.astype(np.float32)
# 將現有的RGB圖從0-255之間的數變爲0-1之間的浮點數
img_ready = np.multiply(nm_arr, 1.0 / 255.0)
return img_ready
3.加載模型進行預測
- 復現之前定義的計算圖(神經網絡),記得佔位
- 通過checkpoint文件找到最新保存的模型位置
- 進行預測,返回預測值
# 加載模型
def restore_model(testPicArr):
# 復現之前定義的計算圖
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
y = mnist_forward.forward(x, None)
# 得到概率最大的預測值
preValue = tf.argmax(y, 1)
# 計算模型在測試集上的準確率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 實現滑動平均模型,參數MOVING_AVERAGE_DECAY用於控制模型更新的速度。訓練過程會對每一個變量維護一個影子變量。
# 這個影子變量的初始值就是相應變量的初始值,每次更新時,影子變量就會隨之更新
variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
# 通過checkpoint文件找到最新保存的模型位置
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
preValue = sess.run(preValue, feed_dict={x: testPicArr})
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
textTest.set(accuracy_score)
return preValue
else:
print("No checkpoint file found")
return -1
三、預測結果
這樣可視化預測就完成了。
但是有一個問題就是該模型只能用於預測白底黑字的手寫數字圖片,還需要改進。