一份快速完整的Tensorflow模型保存和恢復教程(譯)

原文鏈接A quick complete tutorial to save and restore Tensorflow models–by ANKIT SACHAN
(英文水平有限,有翻譯不當的地方請見諒)

在本教程中,我將介紹:
- tensorflow模型是什麼樣子的?
- 如何保存一個Tensorflow模型?
- 如何恢復一個Tensorflow模型用於預測/遷移學習?
- 如何導入預訓練的模型進行微調和修改?

本教程假設你已經對訓練一個神經網絡有一定了解。否則請先看這篇教程Tensorflow Tutorial 2: image classifier using convolutional neural network再看本教程。

什麼是Tensorflow模型?

當你訓練好一個神經網絡後,你會想保存好你的模型便於以後使用並且用於生產。因此,什麼是Tensorflow模型?Tensorflow模型主要包含網絡設計(或者網絡圖)和訓練好的網絡參數的值。所以Tensorflow模型有兩個主要的文件:

a) Meta圖:
Meta圖是一個協議緩衝區(protocol buffer),它保存了完整的Tensorflow圖;比如所有的變量、運算、集合等。這個文件的擴展名是.meta

b) Checkpoint 文件
這是一個二進制文件,它保存了權重、偏置項、梯度以及其他所有的變量的取值,擴展名爲.ckpt。但是, 從0.11版本開始,Tensorflow對改文件做了點修改,checkpoint文件不再是單個.ckpt文件,而是如下兩個文件:

mymodel.data-00000-of-00001
mymodel.index

其中, .data文件包含了我們的訓練變量。除此之外,還有一個叫checkpoint的文件,它保留了最新的checkpoint文件的記錄。

總結一下,對於0.10之後的版本,tensorflow模型包含以下文件:

model files
但對於0.11之前的版本,只包含三個文件:

inception_v1.meta
inception_v1.ckpt
checkpoin

現在我們已經知道Tensorflow模型是什麼樣子的,讓我們繼續學習如何保存模型。

保存Tensorflow模型

假如你正在訓練一個用於圖像分類的卷積神經網絡(training a convolutional neural network for image classification)。通常你會先觀察損失和準確率,一旦發現網絡收斂,就可以手動停止訓練過程或者直接訓練固定迭代次數。當訓練完成後,我們想要保存所有的變量和網絡圖便於以後使用。因此在Tensorflow中, 爲了保存網絡圖和所有參數的值,我們應該創建tf.train.Saver()這個類的一個對象。

saver = tf.train.Saver()

記住Tensorflow變量只有在會話(session)中才能激活。因此,你需要在會話中調用你剛創建的對象的保存方法。

saver.save(sess, "my-test-model")

這裏,sess是一個session對象,“my-test-model”是你的模型名字。讓我們看一個完整的例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint

如果我們要在1000次迭代後保存模型,我們應該在調用保存方法時傳入步數計數:

saver.save(sess, "my_test_model", global_step=1000)

這會在模型名稱後加一個“-1000”並且會創建如下文件:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint

假設在訓練過程中,我們要每1000次迭代保存我們的模型,因此.meta文件會在第一次(1000次迭代)時創建,我們並不需要之後每1000次迭代都保存一遍這個文件(我們在2000,3000…迭代時都不需要保存這個文件,因爲這個文件始終不變)。我們只需要保存這個模型供以後使用,因爲模型圖不會變化。所以,當我們不想重寫meta圖的時候,我們這樣寫:

saver.save(sess, "my-model", global_step=step, write_meta_graph=False)

如果你只想保留4個最新的模型並且在訓練過程中每過2小時保存一次模型,你可以使用max_to_keep和keep_checkpoint_every_n_hours,就像這樣:

#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

注意,如果我們在tf.train.Saver()中不指定任何東西,它將保存所有的變量。要是我們不想保存所有的變量而只是一部分變量。我們可以指定我們想要保存的變量/集合。當創建tf.train.Saver()對象的時候,我們給它傳遞一個我們想要保存的變量的字典列表。我們來看一個例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000)

當需要的時候,這個代碼可以用來保存Tensorflow圖中的指定部分。

導入預訓練模型

如果你想要用其他人預訓練的模型進行微調,需要做兩件事:

a) 創建網絡
你可以寫python代碼來手動創建和原來一樣的模型。但是,想想看,我們已經將原始網絡保存在了.meta文件中,可以用tf.train.import()函數來重建網絡:

saver = tf.train.import_meta_graph("my_test_model-1000.meta")

記住,import_meta_graph函數將只將定義在.meta文件中的網絡添加到當前的圖上。因此,它雖然幫你創建了額圖/網絡,但我們還是需要導入我們在這個圖上訓練好的模型的參數。

b) 導入參數
我們可以調用由tf.train.Saver()創建的對象saver中的restore方法來恢復網絡中的參數。

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  new_saver.restore(sess, tf.train.latest_checkpoint('./'))

這樣,張量的值(如w1和w2)就被恢復並且可以訪問了:

with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('my-model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    print(sess.run('w1:0'))
##Model has been restored. Above statement will print the saved value of w1.

現在你已經理解了如何保存和導入Tensorflow模型。在下一節,我會介紹一個實際應用即導入任何預訓練好的模型。

使用恢復的模型

現在你已經理解如何保存和恢復Tensorflow模型,我們來寫一個實際的示例來恢復任何預訓練的模型並用它來預測、微調或者進一步訓練。無論你什麼時候用Tensorflow,你都會定義一個網絡,它有一些樣本(訓練數據)和超參數(如學習率、迭代次數等)。通常用一個佔位符(placeholder)來將所有的訓練數據和超參數輸入給網絡。下面我們用佔位符建立一個小型網絡並保存它。注意,當網絡被保存的時候,佔位符中的值並沒有被保存。

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

當我們想要恢復這個網絡的時候,我們不僅需要恢復圖和權重,還需要準備一個新的feed_dict來將訓練數據輸入到網絡中。我們可以通過graph.get_tensor_by_name方法來引用這些保存的運算和佔位符變量。

#How to access saved variable/Tensor/placeholders 
w1 = graph.get_tensor_by_name("w1:0")

## How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

如果我們只是想用不同的數據運行相同的網絡,你可以方便地用feed_dict將新的數據送到網絡中。

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 
#using new values of w1 and w2 and saved value of b1. 

要是你想在原來的計算圖中通過添加更多的層來增加更多的運算並且訓練。當然也可以實現,如下:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)

print sess.run(add_on_op,feed_dict)
#This will print 120.

但是,我們能夠只恢復原來圖中的一部分然後添加一些其它層來微調嗎?當然可以,只要通過graph.get_tensor_by_name()方法來獲取原網絡的部分計算圖並在上面繼續建立新計算圖。這裏給出了一個實際的例子。我們用meta圖導入了一個預訓練的vgg網絡,然後將最後一層的輸出個數改成2用於微調新的數據。

......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 

#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')

#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()

num_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()

希望本文能夠讓你清楚地理解Tensorflow是如何被保存和微調的。請在評論區自由分享你的問題或者疑問。

另外,爲了便於理解,我上傳了一份用MNIST數據集訓練及調用模型的例子,見鏈接:https://pan.baidu.com/s/1C-l3YZGbEsAFIClgSQN46Q 密碼:3iq8

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章