Tensorflow_03_Save and Restore 儲存和載入

Brief 概述

在理解了建構神經網絡的大致函數用途,且熟悉了神經網絡原理後,我們已經大致具備可以編寫神經網絡的能力了,在涉及比較複雜的神經網絡結構前,還有一件重要的事情需要了解,那就是中途存檔和事後讀取的函數,它攸關到龐大的算力和時間投入後產出的結果是否能夠被再次使用,是一個絕對必須弄清楚的環節,因此本節主要圍繞一個主題:

  • Checkpoint 檢查點

它如同會議記錄一般,可以針對性的把訓練過程記錄下來,除了避免前功盡棄之外,還可以讓我們有機會一窺訓練過程的究竟,從演變過程中尋找改善算法的方案。

p.s. 關於設備如果手邊沒有,非常建議直接使用雲端的計算服務,如 AWS, FloydHub 等平臺

其他在深度學習中常用的函數定義方法可以參考上一篇文章: Tensorflow_02_Useful Functions 常用函數大全

Checkpoint 檢查點

在初期一般訓練模型簡單且訓練速度極快,對於參數中間變化的過程我們也不會特別在意,但是到了複雜的神經網絡訓練過程時,爲參數訓練過程中途存檔這件事情就會變得非常重要,這就像我們玩電玩遊戲闖關的時候,希望最好能夠中途存檔,如果死在半路上可以直接從存檔的地方恢復遊戲。

 

Save checkpoints 儲存檢查點

同理深度學習訓練過程,一般訓練耗費時間約爲幾天乃至一週,如果中途發生機器停機或是任何意外導致訓練終止,我們可以從檢查點記錄的地方重新開始。抑或者如果我們要分析訓練過程中參數的變化走勢,檢查點也非常實用。使用的類爲:

  • tf.train.Saver(max_to_keep=None) 檔名: 「.ckpt」
  • .Saver({’save_w‘: weight}) 括弧中可以用字典的方式指定只要儲存哪一個參數
  • max_to_keep=None: 最多有幾個檢查點被保存下來,如果是 None 或是 0 則表示全保存
  • keep_checkpoint_every_n_hours=1: 設置幾個小時保存一次檢查點

變量以二進制的方式被存在名爲 .ckpt 的檔案中,內容包含了變量的名字和對應張量的數值,創建一個該類的示例,就可以呼叫裏面儲存與載入儲存文件內容的函數方法:

  • tf.train.Saver().save(sess, './file_directory', global_step=int(num))
  • sess: 表示要儲存哪個繪話裏面的參數
  • './file_directory/file_name': 儲存的路徑沿着執行訓練的 .py 文檔路徑位置繼續指定路徑,如果文件夾不存在指定目錄的話,它會自行創建。官網教程中建議檔名後面連同後綴一起加上,如下代碼...
  • global_step:指定一個數字,將一起被納入檢查點文件命名中

!!! 儲存這些參數的時候特別需要注意申明清楚參數的數據類型非常重要,它攸關到之後要呼叫回這些參數的時候是否順利,如果沒有事先申明清楚,大概率上會有錯誤發生。

下面代碼展示如何保存檢查點:

import numpy as np
import tensorflow as tf

x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

weight = tf.Variable(tf.random_uniform(shape=[1], minval=-1.0, maxval=1.0), 
                     dtype=np.float32)#, name='weight')
bias = tf.Variable(tf.zeros(shape=[1]), dtype=np.float32, name='bias')
y = weight * x_data + bias

loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
training = optimizer.minimize(loss)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

# The instance is created to call the method saving checkpoint
saver = tf.train.Saver()
save_w = tf.train.Saver({'a_name': weight})

for step in range(101):
    sess.run(training)
    if step % 10 == 0:
        print('Round {}, weight: {}, bias: {}'
              .format(step, sess.run(weight[0]), sess.run(bias[0])))
        saver.save(sess, './checkpoint/linear.ckpt', global_step=step)
        save_w.save(sess, './weight/linear.ckpt', global_step=step)
        
saver.save(sess, './checkpoint/linear.ckpt')
sess.close()


### ----- Result is shown below ----- ###
Round 0, weight: 0.6087742447853088, bias: 0.031045857816934586
Round 10, weight: 0.3177388906478882, bias: 0.18408644199371338
Round 20, weight: 0.19332920014858246, bias: 0.2503160834312439
Round 30, weight: 0.14000359177589417, bias: 0.27870404720306396
Round 40, weight: 0.11714668571949005, bias: 0.2908719480037689
Round 50, weight: 0.10734956711530685, bias: 0.29608744382858276
Round 60, weight: 0.10315024852752686, bias: 0.29832297563552856
Round 70, weight: 0.10135028511285782, bias: 0.29928117990493774
Round 80, weight: 0.10057878494262695, bias: 0.29969191551208496
Round 90, weight: 0.10024808347225189, bias: 0.2998679280281067
Round 100, weight: 0.10010634362697601, bias: 0.2999434173107147

檢查點的路徑設置需要使用 「./.../.../...」 的格式去寫路徑,尤其是開頭的 ./ 必須加上,否則在某些平臺上會出現錯誤,等代碼運行完畢後在下面 .py 文檔執行路徑下出現我們設置的儲存文件夾和文件名稱,如下圖:

在默認情況下 tf.train.Saver(max_to_keep=5) 是我們無特別設定的結果,因此只會保存離最近更新的五個參數,其他的參數將即自動刪除。

 

Read checkpoints 讀取檢查點

文件存好之後接下來就是讀取上圖中儲存的文件,儲存在文件裏面的數據是一個原封不動的 tf.Variable() 物件,有着與儲存前一模一樣的名字和屬性,甚至在呼叫回該儲存的變量時也不用初始化,是一個非常全面的保存結果, 只是需要記得: 「同樣變量名的物件需要事先存在在代碼中, 並且數據類型和長相必須一模一樣。

讀取的方式也很直觀,同樣的創建一個 tf.train.Saver() 示例,並用該示例裏面的方法 .restore() 完成讀取,讀取完畢後儲存的參數就回像起死回生一般重新回到我們的代碼中。

  • tf.train.Saver().restore(sess, 'file_directory')
  • sess: 表示我們希望把該儲存的內容重新叫回哪一個繪話中
  • './file_directory/file_name': 表示我們要呼叫的該存檔文件

p.s. 如果在儲存過程中有加上 global_step 參數,呼叫文檔名的時候就必須一起把數字也加上去,如下代碼。

呼叫儲存文件的時候有以下三種情況:

  1. 最直接: 使用 tf.train.Saver() 創建示例後,呼叫 .restore() 方法配合對應名字,成功回到訓練中途的記錄
  2. 第一個方法受阻: 繞道使用 .meta 儲存文件,並使用 tf.import_meta_graph() 示例的 .restore() 方法,同樣可以成功回到訓練中途的記錄
  3. 呼叫只儲存部分參數的記錄檔: 創建一個示例前先在 tf.train.Saver() 括弧中使用字典形式聲明好當時部分儲存的時候對應一模一樣名字的字典鍵和參數名,再用 .restore() 方法成功回到訓練中途的記錄

詳細代碼如下演示:

import tensorflow as tf

# tf.reset_default_graph()
weight = tf.Variable([33], dtype=tf.float32)#, name='weight')
bias = tf.Variable([3], dtype=tf.float32, name='bias')

saver = tf.train.Saver()
# saver = tf.train.import_meta_graph('./checkpoint/linear.ckpt.meta')
saver_2 = tf.train.Saver({'a_name': weight})
init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)
path1 = saver.restore(sess, './checkpoint/linear.ckpt-90')
path2 = saver_2.restore(sess, './weight/linear.ckpt-60')
print(sess.run(weight))
print(sess.run(bias))
sess.close()

''' 
print(sess.run(biases))

### ----- Result as follow ----- ###
FailedPreconditionError: 
Attempting to use uninitialized value Variable
[[Node: _retval_Variable_0_0 = _Retval[T=DT_FLOAT, index=0, 
  _device="/job:localhost/replica:0/task:0/device:CPU:0"](Variable)]]
'''


### ----- Result is shown below ----- ###
INFO:tensorflow:Restoring parameters from ./checkpoint/linear.ckpt-90
INFO:tensorflow:Restoring parameters from ./weight/linear.ckpt-60
[0.10315025]
[0.29986793]

可以觀察到,如果沒有成功導入內容, sess.run() 執行一個參數的時候就會被通知該參數沒有初始化,需要特別注意。另外如果重複導入同樣的值到該代碼中,那麼該值以最後一次導入爲主,如上面代碼中的 weight,最近導入的 60 個回合訓練的 weight 值比訓練 90 個回合的 bias 值還要不準得多。

 

  • tf.train.latest_checkpoint('./.../...')
  • more to update

 

!! 重要 !!  導入沒有成功, 報錯 >> ValueError: At least two variables have the same name: Variable

花了一整個晚上找方法的錯誤,原因還是在於 tf.Variable() 的格式沒有完全一樣,前面只專注在數據格式上面,但是其節點名稱必須也完全一致纔可以! 如果表明名稱 name='a_name', 那麼就都不要寫,如果表明了名稱,那就必須完全一致才行!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章