[TensorFlow Learning] Working with the MNIST Dataset

Data location

Dataset download

Model training and evaluation

#encoding=utf8
import tensorflow as tf
import pandas as pd
import argparse
import math
import numpy as np

def parse_arg():
    parser = argparse.ArgumentParser(description="Training for MNIST model.")
    parser.add_argument(
        "--train_data_dir",
        type=str,
        required=True,
        help="The path of training data.")
    parser.add_argument(
        "--test_data_dir",
        type=str,
        help="The path of test data.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=10,
        help="The number of batch size.")
    parser.add_argument(
        "--train_round",
        type=int,
        default=10,
        help="The number of train round.")
    return parser.parse_args()
# One-hot encode the labels
def one_hot(index):
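    # returns a length-10 vector with a 1 at the label index,
    # e.g. one_hot([3]) -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]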
    default = np.array([0.0]*10)
    default[index[0]] = 1
    return default

# Read training data from a CSV file
def read_data(filename):
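    # column 0 holds the digit label; the remaining columns hold the flattened pixel values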
    train_data_df = pd.read_csv(filename, sep=',', header=None)
    label = train_data_df.iloc[:,0:1].values
    label = np.array(list(map(one_hot, list(label))))
    
    feature = train_data_df.iloc[:,1:].values
    return label, feature

def generate_batch(feature, label, batch_size=1024):
    m = feature.shape[0]
    batch_num = math.ceil(m * 1.0 / batch_size)
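    # e.g. m=25, batch_size=10 -> batch_num=3 (two full batches plus one of size 5)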
    feature_batch_list = []
    label_batch_list = []
    for batch_id in range(batch_num):
        start = batch_id * batch_size
        end = min(start + batch_size, m)
        label_batch_list.append(label[start:end])
        feature_batch_list.append(feature[start:end])
    return feature_batch_list, label_batch_list, batch_num 

# Define the model structure and train it
def model(args, label, feature):
    feature_num = feature.shape[1]
    label_size = label.shape[1]
    n = feature_num
    x = tf.placeholder(dtype="float", name='x', shape=[None, n]) 
    y = tf.placeholder(dtype="float", name='Y', shape=[None, label_size]) 
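    # fully connected network: n -> 2000 -> 2000 -> 1000 -> label_size, sigmoid hidden layers, softmax output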
    W1 = tf.Variable(tf.truncated_normal([n, 2000],stddev=0.1))
    bias1 = tf.Variable(tf.zeros([2000])+ 0.1)
    L1 = tf.nn.sigmoid(tf.matmul(x, W1) + bias1)
    W2 = tf.Variable(tf.truncated_normal([2000, 2000],stddev=0.1))
    bias2 = tf.Variable(tf.zeros([2000])+ 0.1)
    L2 = tf.nn.sigmoid(tf.matmul(L1, W2) + bias2)
    W3 = tf.Variable(tf.truncated_normal([2000, 1000],stddev=0.1))
    bias3 = tf.Variable(tf.zeros([1000])+ 0.1)
    L3 = tf.nn.sigmoid(tf.matmul(L2, W3) + bias3)
    W4 = tf.Variable(tf.truncated_normal([1000, label_size],stddev=0.1))
    bias4 = tf.Variable(tf.zeros([label_size])+ 0.1)
    logits = tf.matmul(L3, W4) + bias4
    predict = tf.nn.softmax(logits, name="predict")

    # softmax_cross_entropy_with_logits expects raw logits; feeding it the softmax
    # output would apply softmax twice
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    opt = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    correct_prediction = tf.argmax(predict, 1, name="correct_prediction")
    correct_labels = tf.argmax(y, 1, name="correct_labels")
    
    accuracy=tf.reduce_mean(tf.cast(tf.equal(correct_prediction, correct_labels),tf.float32), name="accuracy")
    
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        feature_batch, label_batch, batch_num = generate_batch(feature, label, args.batch_size)
        for i in range(args.train_round):
            for batch_id in range(batch_num):
                _, l, ty, ty_predict, oy, accuracy1 = sess.run([opt, loss, correct_labels, correct_prediction, y, accuracy], feed_dict={x: feature_batch[batch_id], y: label_batch[batch_id]})
                print("Epoch: {0} Batch: {1} Loss: {2} Accuracy: {3}".format(i, batch_id, l, accuracy1))
                #print("y:{0} predict:{1} true{2}".format(ty,ty_predict,oy))
        saver.save(sess, './model.ckpt')
        print("Done")

# Restore the saved model and evaluate it on the test set
def model_test(model_dir, test_feature, test_label):
    saver = tf.train.Saver()
    with tf.Session() as sess:
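        # latest_checkpoint finds the newest checkpoint under model_dir ('./model.ckpt' saved during training)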
        ckpt = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, ckpt)
        graph = tf.get_default_graph()
        x=graph.get_operation_by_name('x').outputs[0]
        y=graph.get_operation_by_name('Y').outputs[0]
        accuracy = graph.get_operation_by_name('accuracy').outputs[0]
        correct_prediction = graph.get_operation_by_name('correct_prediction').outputs[0]
        correct_labels = graph.get_operation_by_name('correct_labels').outputs[0]
        ty, ty_predict, accuracy1 = sess.run([correct_labels, correct_prediction, accuracy], feed_dict={x: test_feature, y: test_label})
        print("Test accuracy: {0}".format(accuracy1))

if __name__ == '__main__':
    args = parse_arg()
    label, feature = read_data(args.train_data_dir) 
    model(args, label, feature)
    if args.test_data_dir:
        test_label, test_feature = read_data(args.test_data_dir)
        model_test('./', test_feature, test_label)
Run command:

python train_dnn.py --train_data_dir=MNIST_data/mnist_train.csv --batch_size=10240 --train_round=100 --test_data_dir=MNIST_data/mnist_test.csv
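
read_data assumes each CSV row contains the digit label in column 0 followed by the 784 flattened pixel values of a 28x28 image, with no header row. A minimal sanity check of that layout (a sketch, assuming the standard MNIST-in-CSV files referenced in the command above):

#encoding=utf8
import pandas as pd

df = pd.read_csv("MNIST_data/mnist_train.csv", sep=',', header=None)
assert df.shape[1] == 785                   # 1 label column + 28*28 pixel columns
assert df.iloc[:, 0].between(0, 9).all()    # labels are digits 0-9
print("rows: {0}, columns: {1}".format(df.shape[0], df.shape[1]))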

Model performance

After 100 training epochs on the 60,000 training samples, accuracy on the test set (10,000 samples) was 0.7603.
Training set: 60k samples, test set: 10k samples.

Training epochs  Accuracy
100              0.7606
200              0.9234