Further Optimizing a Neural Network with TensorFlow

This article was first published on my personal blog, QIMING.INFO. If you repost it, please include a link and attribution.

In an earlier post on this site, 《TensorFlow實現簡單神經網絡》 (Implementing a Simple Neural Network with TensorFlow), we used TensorFlow to classify the MNIST handwritten digit dataset and reached an accuracy of 92%. In this post we optimize that network and push the accuracy to 98%.

1 Optimization Strategies

The main levers available when optimizing a neural network are:

  • A suitable loss function
  • A suitable activation function
  • A suitable optimizer
  • The number of layers in the network
  • The learning rate schedule
  • Handling overfitting
  • More training data and more training epochs

In this example, the cross-entropy function is a better choice of loss than the quadratic cost function; tanh() is used as the activation function, and Adam is used as the optimizer.
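For comparison, here is a minimal sketch of the two cost functions (the y and logits placeholders are stand-ins for the one-hot labels and the pre-softmax network output; this uses the same TF1-style API as the full code below):

import tensorflow as tf

# Stand-in tensors for illustration only
y = tf.placeholder(tf.float32, [None, 10])       # one-hot labels
logits = tf.placeholder(tf.float32, [None, 10])  # raw (pre-softmax) output

# Quadratic (MSE) cost, as used in the previous article:
loss_mse = tf.reduce_mean(tf.square(y - tf.nn.softmax(logits)))

# Cross-entropy cost, as used in this article; note it takes the
# raw logits, not the softmax output:
loss_ce = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))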

More layers are not always better: an overly complex network applied to a small dataset is prone to overfitting. This example uses two hidden layers.

When setting the learning rate, a rate that is too large makes the parameters oscillate back and forth instead of converging to a minimum, while one that is too small slows optimization considerably. A good compromise is to start with a relatively large learning rate to quickly reach a fairly good solution, then gradually shrink it as training proceeds so the model is more stable in the later stages.
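The code below implements this by reassigning lr at the start of each epoch. As an alternative, TF1 also provides tf.train.exponential_decay, which ties the same schedule to a global step counter; a minimal sketch (decay_steps=550 assumes one decay per epoch of 550 batches, matching the batch settings below):

import tensorflow as tf

global_step = tf.Variable(0, trainable=False)
# Start at 0.001 and multiply by 0.95 after every decay_steps training steps;
# staircase=True makes the decay happen in discrete jumps, once per "epoch"
lr = tf.train.exponential_decay(0.001, global_step,
                                decay_steps=550, decay_rate=0.95,
                                staircase=True)
# Passing global_step to minimize() makes the optimizer increment it
# automatically on every training step:
# train_step = tf.train.AdamOptimizer(lr).minimize(loss, global_step=global_step)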

To guard against overfitting, this example uses dropout.
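As a quick standalone illustration of the mechanism (not part of the model code): tf.nn.dropout keeps each unit with probability keep_prob and rescales the kept units by 1/keep_prob, so the expected activation is unchanged:

import tensorflow as tf

x = tf.ones([1, 10])
dropped = tf.nn.dropout(x, 0.7)  # each unit kept with probability 0.7

with tf.Session() as sess:
    # Kept units are scaled up to 1/0.7 ≈ 1.43; dropped units become 0
    print(sess.run(dropped))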

In deep learning, more training data solves many problems, but that is not an option here, since this example already uses all of MNIST's training data. We can, however, train for more epochs: this example raises the 21 epochs of the previous post to 51.

Alright, let's write the code and see how well it works~

2 Code and Explanation

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
# Load the dataset
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Size of each mini-batch
batch_size = 100
# Number of batches per epoch
n_batch = mnist.train.num_examples // batch_size

# Define placeholders for the images and the one-hot labels
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
# Dropout keep probability (probability of keeping a unit)
keep_prob = tf.placeholder(tf.float32)
# Learning rate as a variable so it can be lowered during training
lr = tf.Variable(0.001, dtype=tf.float32)

# Build the network
# First hidden layer: 1000 units
W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
b1 = tf.Variable(tf.zeros([1000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

# Second hidden layer: 500 units
W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
b2 = tf.Variable(tf.zeros([500]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

# Output layer: 10 units, one per digit class
W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
logits = tf.matmul(L2_drop, W3) + b3
prediction = tf.nn.softmax(logits)

# Cross-entropy cost function; note that softmax_cross_entropy_with_logits
# expects the raw logits, not the softmax output (feeding it the softmax
# output would apply softmax twice, a common bug that degrades training)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# Train with the Adam optimizer
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

# Op to initialize all variables
init = tf.global_variables_initializer()

# Boolean vector: whether each prediction matches the label
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))  # argmax returns the index of the largest value along the axis
# Accuracy: mean of the boolean vector cast to float
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(51):
        # Lower the learning rate each epoch: lr = 0.001 * 0.95^epoch
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7})

        # Accuracy on the test set (dropout disabled: keep_prob=1.0)
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        # Accuracy on the training set
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        # Print epoch number, test accuracy, and training accuracy
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))

3 Results

Iter 0,Testing Accuracy 0.9439,Training Accuracy 0.9438
Iter 1,Testing Accuracy 0.9515,Training Accuracy 0.9538364
Iter 2,Testing Accuracy 0.9582,Training Accuracy 0.96207273
Iter 3,Testing Accuracy 0.9616,Training Accuracy 0.9679273
Iter 4,Testing Accuracy 0.9659,Training Accuracy 0.9701818
Iter 5,Testing Accuracy 0.9668,Training Accuracy 0.9737818
Iter 6,Testing Accuracy 0.9691,Training Accuracy 0.9764364
Iter 7,Testing Accuracy 0.9718,Training Accuracy 0.979
Iter 8,Testing Accuracy 0.9707,Training Accuracy 0.9800364
Iter 9,Testing Accuracy 0.9716,Training Accuracy 0.98210907
Iter 10,Testing Accuracy 0.9744,Training Accuracy 0.9829818
Iter 11,Testing Accuracy 0.973,Training Accuracy 0.98376364
Iter 12,Testing Accuracy 0.9743,Training Accuracy 0.9856
Iter 13,Testing Accuracy 0.9749,Training Accuracy 0.9863091
Iter 14,Testing Accuracy 0.9755,Training Accuracy 0.9862546
Iter 15,Testing Accuracy 0.974,Training Accuracy 0.98661816
Iter 16,Testing Accuracy 0.9763,Training Accuracy 0.9874
Iter 17,Testing Accuracy 0.9751,Training Accuracy 0.9886909
Iter 18,Testing Accuracy 0.9768,Training Accuracy 0.98914546
Iter 19,Testing Accuracy 0.9756,Training Accuracy 0.98987275
Iter 20,Testing Accuracy 0.9766,Training Accuracy 0.9896182
Iter 21,Testing Accuracy 0.9771,Training Accuracy 0.9906545
Iter 22,Testing Accuracy 0.9786,Training Accuracy 0.9912364
Iter 23,Testing Accuracy 0.9781,Training Accuracy 0.99152726
Iter 24,Testing Accuracy 0.9782,Training Accuracy 0.9915636
Iter 25,Testing Accuracy 0.9778,Training Accuracy 0.9921273
Iter 26,Testing Accuracy 0.9799,Training Accuracy 0.99243635
Iter 27,Testing Accuracy 0.979,Training Accuracy 0.99258184
Iter 28,Testing Accuracy 0.9798,Training Accuracy 0.99285454
Iter 29,Testing Accuracy 0.9784,Training Accuracy 0.99294543
Iter 30,Testing Accuracy 0.9789,Training Accuracy 0.99307275
Iter 31,Testing Accuracy 0.9794,Training Accuracy 0.99325454
Iter 32,Testing Accuracy 0.9786,Training Accuracy 0.9934545
Iter 33,Testing Accuracy 0.9791,Training Accuracy 0.9937818
Iter 34,Testing Accuracy 0.9797,Training Accuracy 0.9938545
Iter 35,Testing Accuracy 0.9799,Training Accuracy 0.9941273
Iter 36,Testing Accuracy 0.9802,Training Accuracy 0.99407274
Iter 37,Testing Accuracy 0.9807,Training Accuracy 0.99438184
Iter 38,Testing Accuracy 0.9814,Training Accuracy 0.9944182
Iter 39,Testing Accuracy 0.9805,Training Accuracy 0.99447274
Iter 40,Testing Accuracy 0.9809,Training Accuracy 0.9945091
Iter 41,Testing Accuracy 0.9813,Training Accuracy 0.9946182
Iter 42,Testing Accuracy 0.9811,Training Accuracy 0.99474543
Iter 43,Testing Accuracy 0.9809,Training Accuracy 0.9948364
Iter 44,Testing Accuracy 0.9812,Training Accuracy 0.99485457
Iter 45,Testing Accuracy 0.9814,Training Accuracy 0.99487275
Iter 46,Testing Accuracy 0.9824,Training Accuracy 0.9948909
Iter 47,Testing Accuracy 0.9817,Training Accuracy 0.9950182
Iter 48,Testing Accuracy 0.982,Training Accuracy 0.9950909
Iter 49,Testing Accuracy 0.9821,Training Accuracy 0.9951091
Iter 50,Testing Accuracy 0.982,Training Accuracy 0.9951091

As you can see, after 51 epochs of training, accuracy on the test data has reached 98.2%, and accuracy on the training data has reached 99.5%.

