注:此博客基於tensorflow官網完整教程,具體數據下載處可去http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html
MNIST是在機器學習領域中的一個經典問題。該問題解決的是把28x28像素的灰度手寫數字圖片識別爲相應的數字,其中數字的範圍從0到9.
60000行訓練數據集 mnist.train
10000行測試數據集 mnist.test
mnist.train.images [60000,784] 維度1索引圖片,維度2索引像素點
mnist.train.labels [60000,10] 標籤數據”one-hot vectors”(一個one-hot向量除了一位數字爲1以
外,其餘爲0)
1、下載安裝數據集
提供一份自動下載和安裝數據集 input_data.py
from tensorflow.examples.tutorials.mnist import input_data
mnist1 = input_data.read_data_sets("MINST_data", one_hot=True)
'''one-hot,Label是一個10維的向量,只有一個值爲1,如果是數字0,那麼對應的Label就是[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]。'''
2、定義
placeholder是佔位符,第一個參數是數據類型dtype,第二個是tensor的shape。
Softmax Regression會對10類分別估算出一個概率,例如是0的概率爲80%,數字1的概率是2%,那麼它就會取最後那個概率最大的那個數
import tensorflow as tf
sess = tf.InteractiveSession() # 使用這個命令會將這個session註冊爲默認的session,之後也會默認在這個session裏跑。
x = tf.placeholder(tf.float32, [None, 784])
'''接下來就是創建權重和偏差,這裏因爲就舉個例子,所以就初始化爲0就可以了,如果是其它複雜的例子,對初始化比較敏感的話,就不能這麼簡單的進行初始化了。'''
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
#Softmax Regression的實現
y = tf.nn.softmax(tf.matmul(x, W) + b)
3、損失函數,優化算法
根據損失來找到最好的模型
y是預測的概率,y_是正確的標籤
reduction_indices = [1]: 一種壓縮方法具體見我的其他博文
reduce_mean:平均值
reduce_sum:求和
GradientDescentOptimizer(0.5):梯度下降,學習率爲0.5
#交叉熵
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices = [1]))
1. #使用隨機梯度下降進行優化,這裏把學習率設爲0.5,使用全局參數初始化器並直接執行它的run。
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
init=tf.global_variables_initializer()
sess.run(init)
4、訓練數據
迭代執行訓練操作
迭代1000次,每次100
for i in range(1000):
batch = mnist1.train.next_batch(100)
sess.run(train_step,feed_dict={x: batch[0], y: batch[1]})
5、準確率
argmax函數,給出某個tensor對象在某一堆上其數據最大值的所在的索引值。
(y,1):y 所索引的向量,1表示按行索引,0表示按列索引。
#計算分類是否正確,給出一組布爾值
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
#計算準確率,先轉換爲浮點數,取平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy.eval({x: mnist1.test.images, y_: mnist1.test.labels}))
此預測模型準確率大概爲91%左右,準確率不夠高,原因是因爲這個模型比較簡單!