使用tensorflow開源框架搭建一個簡單的手寫數字識別(3)————測試集

mnist_test.py

#coding:utf-8
import time
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward
TEST_INTERVAL_SECS = 5

def test(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
        y = mnist_forward.forward(x, None)

        ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        ema_store = ema.variables_to_restore()  
        saver = tf.train.Saver(ema_store)

        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))  #求平方差
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) #計算準確率

        while True:
            with tf.Session() as sess:  
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH) #加載ckpt模型
                if ckpt and ckpt.model_checkpoint_path: #如果已有ckpt模型則恢復
                    print(ckpt.model_checkpoint_path) 
                    saver.restore(sess, ckpt.model_checkpoint_path) #恢復會話
                    global_step = ckpt.model_checkpoint_path.split("/")[-1].split('-') [-1]  #恢復輪數
                    accuracy_score = sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels})   #計算準確率
                    print("After %s training step(S), test accuracy = %g" % (global_step, accuracy_score))    #打印提示
                else: #如果沒有模型
                    print('No checkpoint file found')  #給出提示
                    return
            time.sleep(TEST_INTERVAL_SECS)

def main():
    mnist = input_data.read_data_sets(".\\MNIST_data\\", one_hot = True)
    test(mnist)


if __name__ == "__main__":
    main()

(1)滑動平均值模型的讀取

使用variables_to_restore函數,可以使在加載模型的時候將影子變量直接映射到變量本身。所以我們在獲取變量的滑動平均值的時候只需要獲取變量本身值而不需要獲取影子變量值。

我對這個函數的理解是這個函數將原模型的滑動平均值重新命名爲原參數名(也就是將影子變量直接映射到變量本身)。這樣就完成對滑動平均值的加載。最後,我們發現這個函數的主要功能是重命名字典。

更加詳細具體的內容大家可以參考這個博客:

http://www.cnblogs.com/shiluoliming/p/9023307.html

(2)數據類型的轉換

使用tf.cast(x, dtype, name=None)

將x的數據格式轉化爲dtype。

 

(3)模型的保存與讀取tf.train.Saver()

保存和恢復都需要實列化一個tf.train.Saver.

saver = tf.train.Saver()

之後在訓練循環(也就是上一個博客的mnist_backward.py文件訓練)中,定期調用saver.save()方法,向文件夾中寫入包含當前模型中所有可訓練變量的checkpoint文件。

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章