TensorFlow1.x入門(11)——模型的保存與恢復

系列文章

1. 計算圖的創建與啓動

2. 變量的定義及其操作

3. Feed與Fetch

4. 線性迴歸

5. 構建非線性迴歸模型

6. 簡單分類問題

7. Dropout與優化器

8. 手動調整學習率與TensorBoard

9. 卷積神經網絡(CNN)

10. 循環神經網絡(RNN)

模型的保存與恢復

引言

利用TensorFlow訓練好模型可以對測試集的數據進行預測,用於評估模型的好壞。但是每次執行一個預測任務時,均從頭訓練一下模型,則會耗費大量的時間與資源,並且有可能結果不能完全的復現。
所以TensorFlow提供了模型的保存與恢復的接口,當你訓練好模型後,可以將它持久化在本地,再次使用時可以直接恢復進行預測,不需要再重新訓練。

知識點

saver=tf.train.Saver()定義一個保存模型的對象,也是固定寫法。
saver.save(sess, save_path=r"...")寫在session中用於在每次迭代結束後保存一下模型,也可以保存效果最好的模型。
saver.restore(sess, save_path=r"...")同樣寫在session中用於恢復模型的參數,其中save_path就是模型保存的地址。

示例

#%% md
# 模型的保存與恢復
#%%
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#%% md
加載數據
#%%
mnist = input_data.read_data_sets("MNIST", one_hot=True)
#%% md
設置參數batch_size的大小,計算迭代的總批次
#%%
batch_size = 100
n_batches = mnist.train.num_examples // batch_size
#%% md
構建網絡
#%%
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
#%%
w = tf.Variable(tf.truncated_normal([784,10], stddev=0.1))
b = tf.Variable(tf.zeros([10]) + 0.1)
#%% md
預測輸出
#%%
prediction = tf.nn.softmax(tf.matmul(x, w) + b)

#%% md
定義損失函數
#%%
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
#%% md
定義優化器
#%%
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#%% md
計算正確率
#%%
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#%% md
定義初始化的init
#%%
init = tf.global_variables_initializer()
#%% md
定義保存的對象
#%%
saver = tf.train.Saver()
#%% md
訓練模型並保存
#%%
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(100):
        for batch in range(n_batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run([train_step], {x:batch_xs, y:batch_ys})
        saver.save(sess, save_path='saved_model/mymodel')
        acc, loss = sess.run([accuracy, cross_entropy], {x:mnist.test.images, y:mnist.test.labels})
        print("Iter: " + str(epoch) + " Loss: " + str(loss) + ", Testing Acc: " + str(acc))
#%% md
恢復模型進行對比
1. 未恢復參數的模型效果
2. 完全恢復模型參數的效果
#%%
with tf.Session() as sess:
    sess.run(init)
    acc, loss = sess.run([accuracy, cross_entropy], {x:mnist.test.images, y:mnist.test.labels})
    print(" Loss: " + str(loss) + ", Testing Acc: " + str(acc))
#%%
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, 'saved_model/mymodel')
    acc, loss = sess.run([accuracy, cross_entropy], {x:mnist.test.images, y:mnist.test.labels})
    print(" Loss: " + str(loss) + ", Testing Acc: " + str(acc))
#%%

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章