[DL] 識別圖中模糊的手寫數字

MNIST是一個入門級的計算機視覺數據集。當我們開始學習編程時,第一件事往往是學習打印Hello World。在機器學習入門的領域裏,我們會用MNIST數據集來實驗各種模型。

#自動下載與安裝MNIST數據集
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

#打印MNIST中信息
print('輸入數據:', mnist.train.images)
print('訓練數據的shape:', mnist.train.images.shape)   #輸入數據的shape: (55000, 784)  55000張圖片  一張圖片爲784(28×28)像素
print('測試數據的shape:', mnist.test.images.shape)
print('驗證數據的shape:', mnist.validation.images.shape)

import pylab
im = mnist.train.images[100]    #第100張圖片
im = im.reshape(-1, 28)
pylab.imshow(im)
pylab.show()

tf.reset_default_graph()
#定義佔位符
#MNISRT數據集的維度是28×28=784 None代表第一個維度可以是任何長度。
# X代表能夠輸入任意數量的MNIST圖像,每一張圖展平成784維的向量
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])     #數字0~9,共10個類別

#定義學習參數   一個Variable代表一個可修改的張量
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

pred = tf.nn.softmax(tf.matmul(x,W) + b)   #softmax分類   構建正向傳播結構 表明只要模型中參數合適,通過具體的數據輸入,就能得到我們想要的分類

#定義一個反向傳播結構,編譯訓練模型,以得到合適的參數
#損失函數
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
#定義參數
learning_rate = 0.01
#使用梯度下降優化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

training_epochs = 25   #整個樣本集迭代25次
batch_size = 100       #訓練過程中一次取100條數據進行訓練
display_step = 1       #訓練一次就把具體的中間狀態顯示出來

#啓動session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())   #初始化OP   運行初始化

    #啓動循環開始訓練
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples/batch_size)
        #循環所有數據集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            #運行優化器
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})
            #計算平均loss值
            avg_cost += c/total_batch
        #顯示訓練中的詳細信息
        if (epoch+1) % display_step == 0:
            print("Epoch:", '%04d'%(epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print("Finished!")

    #測試model
    #由於是onehot編碼,這裏使用s了tf.argmax函數返回onehot編碼中數值爲1的那個元素的下標
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    #計算準確率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))  #eval函數就是實現list、dict、tuple與str之間的轉化

    #保存模型
    saver = tf.train.Saver()
    model_path = './mnist_model_save.cpkt'
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)
import tensorflow as tf
from MNIST_1 import pred, x, y, mnist
import pylab
saver = tf.train.Saver()
model_path = './mnist_model_save.cpkt'
#讀取模型
print("Starting 2nd session...")
with tf.Session() as sess:
    #初始化變量
    sess.run(tf.global_variables_initializer())
    #恢復模型變量
    saver.restore(sess, model_path)
    #測試model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    #計算準確率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    outputval, predv = sess.run([output, pred], feed_dict={x: batch_xs})
    print(outputval, predv, batch_ys)

    im = batch_xs[0]
    im = im.reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()

    im = batch_xs[1]
    im = im.reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()

參考自 《深度學習之Tensor Flow 入門、原理與進階實踐》–李金洪

發佈了74 篇原創文章 · 獲贊 18 · 訪問量 5萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章