從Tensorflow模型文件中解析並顯示網絡結構圖(CKPT模型篇) 1 解析CKPT網絡結構 2 自動將CKPT轉pb,並提取網絡圖中節點 3 測試 4 源碼地址

最近看到一個巨牛的人工智能教程,分享一下給大家。教程不僅是零基礎,通俗易懂,而且非常風趣幽默,像看小說一樣!覺得太牛了,所以分享給大家。平時碎片時間可以當小說看,【點這裏可以去膜拜一下大神的“小說”】

上一篇文章《從Tensorflow模型文件中解析並顯示網絡結構圖(pb模型篇)》中介紹瞭如何從pb模型文件中提取網絡結構圖並實現可視化,本文介紹如何從CKPT模型文件中提取網絡結構圖並實現可視化。理論上,既然能從pb模型文件中提取網絡結構圖,CKPT模型文件自然也不是問題,但是其中會有一些問題。

1 解析CKPT網絡結構

解析CKPT網絡結構的第一步是讀取CKPT模型中的圖文件,得到圖的Graph對象後即可得到完整的網絡結構。讀取圖文件示例代碼如下所示。

    saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True)
    graph = tf.get_default_graph()
    with tf.Session( graph=graph) as sess:
        sess.run(tf.global_variables_initializer()) 
        saver.restore(sess,ckpt_path) 

調用graph.get_operations()後即可得到當前圖的所有計算節點,在利用Operation對象與Tensor對象之間的相互引用關係即可推斷網絡結構。但是需要注意的是,從meta文件中導入的圖中獲取計算節點存在如下問題。

包含反向梯度下降計算的所有節點
某些計算節點是按基礎計算(加減乘除等)節點拆分成多個計算節點的,如BatchNorm,但其實是可以直接合併成一個節點的。

pb模型文件可以避免上面第一個問題,將CKPT模型轉pb模型後,可以自動將反向梯度下降相關計算節點移除。對於第二點,pb模型文件會自動將基礎計算組成一個計算節點,但是對於Tensor操作的函數如Slice等函數是無法合併的。因此,對於第2個問題,將CKPT模型轉pb模型後,可以減少這類問題,但是無法避免。徹底避免的方法只能通過自己針對性地實現。經過以上分析,得出的結論是非常有必要將CKPT模型轉pb模型。

2 自動將CKPT轉pb,並提取網絡圖中節點

如果將CKPT自動轉pb模型,那麼就可以複用上一篇文章《從Tensorflow模型文件中解析並顯示網絡結構圖(pb模型篇)》的代碼。示例代碼如下所示。

def read_graph_from_ckpt(ckpt_path,input_names,output_name ):   
    saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True)
    graph = tf.get_default_graph()
    with tf.Session( graph=graph) as sess:
        sess.run(tf.global_variables_initializer()) 
        saver.restore(sess,ckpt_path) 
        output_tf =graph.get_tensor_by_name(output_name) 
        pb_graph = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [output_tf.op.name]) 
     
    with tf.Graph().as_default() as g:
        tf.import_graph_def(pb_graph, name='')  
    with tf.Session(graph=g) as sess:
        OPS=get_ops_from_pb(g,input_names,output_name)
    return OPS

其中函數get_ops_from_pb在上一篇文章《從Tensorflow模型文件中解析並顯示網絡結構圖(pb模型篇)》中已經實現。

3 測試

《MobileNet V1官方預訓練模型的使用》文中介紹的MobileNet V1網絡結構爲例,下載MobileNet_v1_1.0_192文件並壓縮後,得到mobilenet_v1_1.0_192.ckpt.data-00000-of-00001mobilenet_v1_1.0_192.ckpt.indexmobilenet_v1_1.0_192.ckpt.meta文件。我們還需要知道mobilenet_v1_1.0_192.ckpt模型對應的輸入和輸出Tensor對象的名稱,官方提供的壓縮包文件中並沒有告知。一種方法是運行官方代碼,把輸入Tensor的名稱打印出來。但是運行官方代碼本身就需要一定的時間和精力,在在上一篇文章《從Tensorflow模型文件中解析並顯示網絡結構圖(pb模型篇)》的代碼實現中已經實現了將原始網絡結構對應的字符串寫入到ori_network.txt文件中。因此,可以先隨意填寫輸入名稱和輸出名稱,待生成ori_network.txt文件後,從文件中可以直觀看到原始網絡結構。ori_network.txt文件部分內容如下所示。

通過該文件可知,輸入Tensor的名稱爲:batch:0,輸出Tensor名稱爲:MobilenetV1/Predictions/Reshape_1:0。有了這些信息後,調用函數read_graph_from_ckpt得到靜態圖的節點列表對象ops,調用函數gen_graph(ops,"save/path/graph.html")後,在目錄save/path中得到graph.html文件,打開graph.html後,顯示結果如下。

4 源碼地址

https://github.com/huachao1001/CNNGraph

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章