系列文章
引言
利用TensorFlow訓練好模型可以對測試集的數據進行預測,用於評估模型的好壞。但是每次執行一個預測任務時,均從頭訓練一下模型,則會耗費大量的時間與資源,並且有可能結果不能完全的復現。
所以TensorFlow提供了模型的保存與恢復的接口,當你訓練好模型後,可以將它持久化在本地,再次使用時可以直接恢復進行預測,不需要再重新訓練。
知識點
saver=tf.train.Saver()
定義一個保存模型的對象,也是固定寫法。
saver.save(sess, save_path=r"...")
寫在session中用於在每次迭代結束後保存一下模型,也可以保存效果最好的模型。
saver.restore(sess, save_path=r"...")
同樣寫在session中用於恢復模型的參數,其中save_path
就是模型保存的地址。
示例
#%% md
# 模型的保存與恢復
#%%
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#%% md
加載數據
#%%
mnist = input_data.read_data_sets("MNIST", one_hot=True)
#%% md
設置參數batch_size的大小,計算迭代的總批次
#%%
batch_size = 100
n_batches = mnist.train.num_examples // batch_size
#%% md
構建網絡
#%%
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
#%%
w = tf.Variable(tf.truncated_normal([784,10], stddev=0.1))
b = tf.Variable(tf.zeros([10]) + 0.1)
#%% md
預測輸出
#%%
prediction = tf.nn.softmax(tf.matmul(x, w) + b)
#%% md
定義損失函數
#%%
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
#%% md
定義優化器
#%%
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#%% md
計算正確率
#%%
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#%% md
定義初始化的init
#%%
init = tf.global_variables_initializer()
#%% md
定義保存的對象
#%%
saver = tf.train.Saver()
#%% md
訓練模型並保存
#%%
with tf.Session() as sess:
sess.run(init)
for epoch in range(100):
for batch in range(n_batches):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run([train_step], {x:batch_xs, y:batch_ys})
saver.save(sess, save_path='saved_model/mymodel')
acc, loss = sess.run([accuracy, cross_entropy], {x:mnist.test.images, y:mnist.test.labels})
print("Iter: " + str(epoch) + " Loss: " + str(loss) + ", Testing Acc: " + str(acc))
#%% md
恢復模型進行對比
1. 未恢復參數的模型效果
2. 完全恢復模型參數的效果
#%%
with tf.Session() as sess:
sess.run(init)
acc, loss = sess.run([accuracy, cross_entropy], {x:mnist.test.images, y:mnist.test.labels})
print(" Loss: " + str(loss) + ", Testing Acc: " + str(acc))
#%%
with tf.Session() as sess:
sess.run(init)
saver.restore(sess, 'saved_model/mymodel')
acc, loss = sess.run([accuracy, cross_entropy], {x:mnist.test.images, y:mnist.test.labels})
print(" Loss: " + str(loss) + ", Testing Acc: " + str(acc))
#%%