# 基於 TensorFlow 的 BP 神經網絡：MNIST 數據集分類

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# MNIST splits read from the "mnist/" directory, with one-hot encoded labels.
mnist = input_data.read_data_sets("mnist/", one_hot=True)

# Graph inputs of the classifier; both are supplied through feed_dict.
X = tf.placeholder(dtype=tf.float32, shape=[None, 784])  # flattened images, 784 pixels each
Y = tf.placeholder(dtype=tf.float32, shape=[None, 10])   # one-hot ground-truth labels


def _dense_layer(inputs, fan_in, fan_out, activation=None):
    """Append one fully connected layer ``activation(inputs @ w + b)`` to the graph.

    Args:
        inputs: 2-D float tensor whose last dimension is ``fan_in``.
        fan_in: Number of input units.
        fan_out: Number of output units.
        activation: Optional element-wise op (e.g. ``tf.sigmoid``) applied to the
            affine output; ``None`` returns the raw affine output (logits).

    Returns:
        The layer's output tensor with ``fan_out`` columns.
    """
    # Glorot/Xavier uniform initialisation: U(-limit, limit) with
    # limit = sqrt(6 / (fan_in + fan_out)). Unscaled U(-1, 1) or N(0, 1)
    # weights on layers with hundreds of inputs push sigmoid pre-activations
    # far from zero, saturating the units so gradients almost vanish.
    limit = (6.0 / (fan_in + fan_out)) ** 0.5
    weights = tf.Variable(tf.random_uniform([fan_in, fan_out], -limit, limit))
    # Zero biases are sufficient: the random weights already break symmetry.
    biases = tf.Variable(tf.zeros([fan_out]))
    affine = tf.matmul(inputs, weights) + biases
    return affine if activation is None else activation(affine)


def create_model(learning_rate=0.001):
    """Build a 784-1024-512-128-10 sigmoid MLP classifier for MNIST.

    The graph reads the module-level placeholders ``X`` (flattened images,
    784 values per row) and ``Y`` (one-hot labels over 10 classes).

    Args:
        learning_rate: Step size for the Adam optimizer (default 0.001).

    Returns:
        A tuple ``(train_op, cross_entropy_loss, accuracy)``:
        ``train_op`` performs one Adam update on the batch fed to ``X``/``Y``;
        ``cross_entropy_loss`` is the mean softmax cross-entropy over that batch;
        ``accuracy`` is the fraction of rows whose arg-max logit equals the
        arg-max of the label.
    """
    layer_sizes = (784, 1024, 512, 128, 10)

    # Sigmoid hidden layers. No dropout is added: tf.nn.dropout with
    # keep_prob=1 keeps every unit and is therefore an identity op.
    hidden = X
    for fan_in, fan_out in zip(layer_sizes[:-2], layer_sizes[1:-1]):
        hidden = _dense_layer(hidden, fan_in, fan_out, tf.sigmoid)

    # Output layer produces raw logits; softmax is applied inside the loss op.
    outputs = _dense_layer(hidden, layer_sizes[-2], layer_sizes[-1])

    # Mean softmax cross-entropy between the logits and the one-hot labels.
    cross_entropy_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=outputs, labels=Y))
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy_loss)

    # A prediction is correct when the predicted class equals the labelled class.
    pred = tf.equal(tf.argmax(outputs, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(pred, tf.float32))

    return train_op, cross_entropy_loss, accuracy


# Training schedule. TRAIN_STEPS is far more batches than MNIST needs, so in
# practice the loop runs until it is interrupted by hand.
TRAIN_STEPS = 10000000
BATCH_SIZE = 100
LOG_EVERY = 100  # evaluate and print every LOG_EVERY steps

with tf.Session() as sess:
    train_op1, cross_entropy_loss1, accuracy1 = create_model()

    sess.run(tf.global_variables_initializer())

    for i in range(TRAIN_STEPS):
        xs, ys = mnist.train.next_batch(BATCH_SIZE)
        # Fetch only the update and the batch loss; per-batch accuracy is never
        # reported, so computing it on every step would be wasted work.
        _, loss_ = sess.run([train_op1, cross_entropy_loss1], feed_dict={X: xs, Y: ys})
        if i % LOG_EVERY != 0:
            continue
        # Measure accuracy on the complete train, test and validation splits.
        # Each split is fed in a single run, which costs one full forward pass
        # over every image per log line.
        train_acc, test_acc, valid_acc = (
            sess.run(accuracy1, feed_dict={X: split.images, Y: split.labels})
            for split in (mnist.train, mnist.test, mnist.validation)
        )
        print("step:%s, loss: %s, train_acc:%s, test_acc:%s, valid_acc:%s"
              % (i, loss_, train_acc, test_acc, valid_acc))



# 發表評論
# 所有評論
# 還沒有人評論，想成為第一個評論的人嗎？請在上方評論欄輸入並點擊發布。
# 相關文章