# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST data sets from the "mnist/" directory with one-hot labels.
mnist = input_data.read_data_sets("mnist/", one_hot=True)

# Graph inputs shared by the model definition and the training loop.
# X: batch of input rows with 784 features each.
X = tf.placeholder(dtype=tf.float32, shape=[None, 784])
# Y: batch of ground-truth label rows with 10 entries each.
Y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
def create_model(keep_prob=1.0, learning_rate=0.001):
    """Build a fully connected classifier on the module-level placeholders.

    The network reads the global ``X`` placeholder and applies three sigmoid
    hidden layers (1024, 512 and 128 units), each followed by dropout, then a
    linear 10-unit output layer. The loss is the mean softmax cross-entropy
    between the output logits and the global ``Y`` placeholder, minimized
    with Adam.

    Args:
        keep_prob: Probability of keeping a unit in each dropout layer. The
            default of 1.0 matches the previously hard-coded value and leaves
            dropout disabled. A plain Python number is fixed into the graph;
            pass a scalar placeholder instead if dropout should be active
            only while training.
        learning_rate: Learning rate handed to the Adam optimizer. Defaults
            to the previously hard-coded 0.001.

    Returns:
        A tuple ``(train_op, cross_entropy_loss, accuracy)``: the training
        op, the scalar loss tensor and the scalar accuracy tensor.
    """
    # Hidden layer 1: 784 -> 1024, uniform(-1, 1) initialization.
    w1 = tf.Variable(tf.random_uniform([784, 1024], -1, 1))
    b1 = tf.Variable(tf.random_uniform([1024], -1, 1))
    y1 = tf.sigmoid(tf.matmul(X, w1) + b1)
    y1 = tf.nn.dropout(y1, keep_prob=keep_prob)

    # Hidden layer 2: 1024 -> 512, truncated-normal initialization.
    w2 = tf.Variable(tf.truncated_normal([1024, 512]))
    b2 = tf.Variable(tf.truncated_normal([512]))
    y2 = tf.sigmoid(tf.matmul(y1, w2) + b2)
    y2 = tf.nn.dropout(y2, keep_prob=keep_prob)

    # Hidden layer 3: 512 -> 128, uniform(-1, 1) initialization.
    w3 = tf.Variable(tf.random_uniform([512, 128], -1, 1))
    b3 = tf.Variable(tf.random_uniform([128], -1, 1))
    y3 = tf.sigmoid(tf.matmul(y2, w3) + b3)
    y3 = tf.nn.dropout(y3, keep_prob=keep_prob)

    # Output layer: 128 -> 10 raw logits (softmax is applied inside the loss).
    w4 = tf.Variable(tf.truncated_normal([128, 10]))
    b4 = tf.Variable(tf.truncated_normal([10]))
    outputs = tf.matmul(y3, w4) + b4

    # Softmax cross-entropy loss, averaged over the batch.
    cross_entropy_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=outputs, labels=Y))
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy_loss)

    # Prediction correctness per example, then accuracy as its mean.
    pred = tf.equal(tf.argmax(outputs, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(pred, tf.float32))
    return train_op, cross_entropy_loss, accuracy
# Train the model on mini-batches of 100 and report metrics every 100 steps.
with tf.Session() as sess:
    train_step, loss_op, accuracy_op = create_model()
    sess.run(tf.global_variables_initializer())

    for step in range(10000000):
        batch_images, batch_labels = mnist.train.next_batch(100)
        _, batch_loss, _batch_acc = sess.run(
            [train_step, loss_op, accuracy_op],
            feed_dict={X: batch_images, Y: batch_labels})

        if step % 100 == 0:
            # Measure the current model's accuracy on the training, test and
            # validation data; the data sets are small, so each one is
            # evaluated in full.
            train_acc = sess.run(
                accuracy_op,
                feed_dict={X: mnist.train.images, Y: mnist.train.labels})
            test_acc = sess.run(
                accuracy_op,
                feed_dict={X: mnist.test.images, Y: mnist.test.labels})
            valid_acc = sess.run(
                accuracy_op,
                feed_dict={X: mnist.validation.images,
                           Y: mnist.validation.labels})
            print("step:%s, loss: %s, train_acc:%s, test_acc:%s, valid_acc:%s" % (
                step, batch_loss, train_acc, test_acc, valid_acc))
# Source article title: BP (backpropagation) neural network - MNIST data set classification with TensorFlow
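# (Web page residue) Post a comment
# (Web page residue) All comments
# (Web page residue) No one has commented yet. Want to be the first? Type in the comment box above and click publish.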