【AI實戰】快速掌握TensorFlow(二):計算圖、會話 原


在前面的文章中,我們已經完成了AI基礎環境的搭建(見文章:Ubuntu + Anaconda + TensorFlow + GPU + PyCharm搭建AI基礎環境),以及初步瞭解了TensorFlow的特點和基本操作(見文章:快速掌握TensorFlow(一)),接下來將繼續學習掌握TensorFlow。

本文主要是學習掌握TensorFlow的計算圖、會話操作。

 

 
計算圖是TensorFlow的核心概念,使用圖(Graph)來表示計算任務,由節點和邊組成。TensorFlow由前端負責構建計算圖,後端負責執行計算圖。
爲了執行圖的計算,圖必須在會話(Session)裏面啓動,會話將圖的操作分發到CPU、GPU等設備上執行。
下面將介紹如何在TensorFlow裏面創建會話、圖以及基本操作。

1、圖(Graph)
TensorFlow Python庫已經有一個默認圖 (default graph),如果沒有創建新的計算圖,則默認情況下是在這個default graph裏面創建節點和邊。
在圖裏面添加節點非常方便。例如現在要創建這樣的計算圖,兩個張量相加,如下圖:
 
代碼如下:

import tensorflow as tf 
a=tf.constant([1.0,2.0], name='a') 
b=tf.constant([3.0,4.0], name='b') 
result = tf.add(a,b)

現在默認圖就有了三個節點,兩個constant(),和一個add()。
爲了真正使兩個張量相加並得到結果,就必須在會話裏面啓動這個圖。

2、會話(Session)
要啓動計算圖,首先要創建一個Session對象。
使用tf.Session()創建會話,調用run()函數執行計算圖。如果沒有傳入任何創建參數,會話構造器將啓動默認圖。如果要指定某個計算圖,則傳入計算圖參數(如g1),則創建會話方式爲tf.Session(graph=g1)創建會話(Session)主要有以下三種方式:
(1)創建一個會話

#啓動默認圖
sess=tf.Session()
result_value = sess.run(result)
print(result_value)
# ==> [4.0 6.0]

# 任務完成, 關閉會話.
sess.close()

(2) 創建一個會話
Session在使用完後需要關閉以釋放資源,除了顯式調用close外,也可以使用“with”代碼塊 來自動完成關閉動作。代碼如下:

with tf.Session() as sess:
    result_value = sess.run(result)
    print(result_value)
    # ==> [4.0 6.0]

(3)創建一個默認的會話

sess=tf.Session()
with sess.as_default():
    result_value = result.eval()
    print(result_value)

當指定默認會話後,可以通過tf.Tensor.eval函數來計算一個張量的取值。

(4)創建一個交互式會話
在交互式環境下(例如IPython),使用設置默認會話的方式來獲取張量的取值更加方便,TensorFlow提供了一種在交互式環境下直接構建默認會話的函數:tf.InteractiveSession,該函數會自動將生成的會話註冊爲默認會話,使用 tf.Tensor.eval()代替 Session.run(),代碼如下:

sess= tf.InteractiveSession()
result_value = result.eval()
print(result_value)
sess.close()

3、構建多個計算圖
在TensorFlow中可以構建多個計算圖,計算圖之間的張量和運算是不會共享的,通過這種方式,可以在同個項目中構建多個網絡模型,而相互之間不會受影響。
使用tf.Graph()函數構建圖,構建多個計算圖的方式如下:

# 構建計算圖g1
g1=tf.Graph()
with g1.as_default():
    # 在計算圖g1中定義變量'v',並設置初始值爲0。
    v=tf.get_variable('v',initializer=tf.zeros_initializer()(shape = [1]))
    
# 構建計算圖g2
g2=tf.Graph()
with g2.as_default():
    # 在計算圖g2中定義變量'v',並設置初始值微1。
    v=tf.get_variable('v',initializer=tf.ones_initializer()(shape = [1]))

# 在計算圖g1中讀取變量'v'的取值
with tf.Session(graph=g1) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope('',reuse=True):
        print(sess.run(tf.get_variable('v')))
        # 輸出結果[0.]

# 在計算圖g2中讀取變量'v'的取值
with tf.Session(graph=g2) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope('',reuse=True):
        print(sess.run(tf.get_variable('v')))
        # 輸出結果[1.]。

4、指定運行設備
如果電腦有多個GPU,可以在圖、會話中指定要運行的設備
(1)在圖中指定運行設備

g=tf.Graph()
# 指定計算運行的設備。
with g.device('/gpu:0'):
    result=tf.add(a,b)

(2)在會話中指定運行設備

with tf.Session() as sess:
  with tf.device("/gpu:0"):
    result=tf.add(a,b)

運行的設備用字符串進行標識,目前支持的設備包括:

  • "/cpu:0": 機器的 CPU
  • "/gpu:0": 機器的第一個 GPU,如果有的話
  • "/gpu:1": 機器的第二個 GPU,以此類推

通過以上介紹,已經瞭解了圖、會話的基本操作,使用圖 (graph) 來表示計算任務,使用會話 (Session) 來執行圖。

接下來,我們將有更多講解TensorFlow的精彩內容,敬請期待!

 

推薦相關閱讀

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章