TensorFlow保存還原模型的正確方式,Saver的save和restore方法,親測可用
許多TensorFlow初學者想把自己訓練的模型保存,並且還原繼續訓練或者用作測試。但是TensorFlow官網的介紹太不實用,網上的資料又不確定哪個是正確可行的。
今天David 9 就來帶大家手把手入門親測可用的TensorFlow保存還原模型的正確方式,使用的是網上最多的Saver的save和restore方法, 並且把關鍵點爲大家指出。
今天介紹最爲可行直接的方式來自這篇Stackoverflow:https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model 親測可用:
保存模型:
import tensorflow as tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#Create a saver object which will save all the variables
saver = tf.train.Saver()
#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1
#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)
必須強調的是:這裏4,5,6,11行中的name=’w1′, name=’w2′, name=’bias’, name=’op_to_restore’ 千萬不能省略,這是恢復還原模型的關鍵。
還原模型:
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated
還原當然是用restore方法,這裏的18,19,23行就是剛纔的name關鍵字指定的Tensor變量,必須找對才能進行還原恢復。
其他的關鍵在代碼和註釋中可以一眼看出, 這裏不加贅述了。
參考文獻:
- https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model
- https://nathanbrixius.wordpress.com/2016/05/24/checkpointing-and-reusing-tensorflow-models/
- https://stackoverflow.com/questions/42685994/how-to-get-a-tensorflow-op-by-name
- http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
- https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops/saving_and_restoring_variables