輸出TensorFlow中checkpoint內變量的幾種方法

將TF保存在checkpoint中的變量值輸出到txt或npy中的幾種簡單的可行的方法:

1,最簡單的方法,是在有model 的情況下,直接用tf.train.saver進行restore,就像 cifar10_eval.py 中那樣。然後,在sess中直接run變量的名字就可以得到變量保存的值。


在這裏以cifar10_eval.py爲例。首先,在Graph中穿件model。

    with tf.Graph().as_default() as g:  
        images, labels = cifar10.inputs(eval_data=eval_data)  
        logits = cifar10.inference(images)  
        top_k_op = tf.nn.in_top_k(logits, labels, 1)  

然後,通過tf.train.ExponentialMovingAverage.variable_to_restore確定需要restore的變量,默認情況下是model中所有trainable變量的movingaverge名字。並建立saver 對象

 variable_averages = tf.train.ExponentialMovingAverage(  
            cifar10.MOVING_AVERAGE_DECAY)  
        variables_to_restore = variable_averages.variables_to_restore()  
        saver = tf.train.Saver(variables_to_restore)  


variables_to_restore中是變量的movingaverage名字到變量的mapping(就是個字典)。我們可以打印嘗試打印裏面的變量名,

    for name in variables_to_restore:  
        print(name)  



輸出結果爲

    softmax_linear/biases/ExponentialMovingAverage



    conv2/biases/ExponentialMovingAverage



    local4/biases/ExponentialMovingAverage



    local3/biases/ExponentialMovingAverage



    softmax_linear/weights/ExponentialMovingAverage



    conv1/biases/ExponentialMovingAverage



    local4/weights/ExponentialMovingAverage



    local3/weights/ExponentialMovingAverage



    conv2/weights/ExponentialMovingAverage



    conv1/weights/ExponentialMovingAverage



然後在中通過run 變量名的方式就可以得到保存在checkpoint中的值,引文sess.run方法得到的是numpy形式的數據,就可以通過np.save或np.savetxt來保存了。


    with tf.Session() as sess:  
      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)  
      if ckpt and ckpt.model_checkpoint_path:  
        # Restores from checkpoint  
        saver.restore(sess, ckpt.model_checkpoint_path)  
        conv1_w=sess.run('conv1/weights/ExponentialMovingAverage')  
 
此時conv1_w就是conv1/weights的MovingAverage的值,並且是numpy array的形式。




2, 第二種方法是使用
tensorflow/python/tools/inspect_checkpoint.py
 中提到的tf.train.NewCheckpointReader類

這種方法不需要model,只要有checkpoint文件就行。
首先用tf.train.NewCheckpointReader讀取checkpoint文件
如果沒有指定需要輸出的變量,則全部輸出,如果指定了,則可以輸出相應的變量
 
"""A simple script for inspect checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("file_name", "", "Checkpoint filename")
tf.app.flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")


def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Prints tensors in a checkpoint file.
  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.
  If `tensor_name` is provided, prints the content of the tensor.
  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
  """
  try:
    reader = tf.train.NewCheckpointReader(file_name)
    if not tensor_name:
      print(reader.debug_string().decode("utf-8"))
    else:
      print("tensor_name: ", tensor_name)
      print(reader.get_tensor(tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")


def main(unused_argv):
  if not FLAGS.file_name:
    print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
          "[--tensor_name=tensor_to_print]")
    sys.exit(1)
  else:
    print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name)

if __name__ == "__main__":
tf.app.run()


3,第三種方法也是TF官方在tool裏面給的,稱爲freeze_graph, 在官方的這個tutorials中有介紹。
一般情況下TF在訓練過程中會保存兩種文件,一種是保存了變量值的checkpoint文件,另一種是保存了模型的Graph(GraphDef)等其他信息的MetaDef文件,
以.meta結尾Meta,但是其中沒有保存變量的值。freeze_graph.py的主要功能就是將chenkpoint中的變量值保存到模型的GraphDef中,使得在一個文件中既
包含了模型的Graph,又有各個變量的值,便於後續操作。當然變量值的保存是可以有選擇性的。
在freeze_graph.py中,首先是導入GraphDef (如果有GraphDef則可之間導入,如果沒有,則可以從MetaDef中導入). 然後是從GraphDef中的所有nodes中
抽取主模型的nodes(比如各個變量,激活層等)。再用saver從checkpoint中恢復變量的值,以constant的形式保存到抽取的Grap的nodes中,並輸出此GraphDef.
GraphDef 和MetaDef都是 基於Google Protocol Buffer 定義的。在GraphDef 中主要以node(NodeDef) 來保存模型。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章