tf.get_variable()和tf.Variable()的區別

Variable

最近在學習TensorFlow的過程中, 看到在定義變量的時候有兩種操作:
tf.get_variable()和tf.Variable()。

def weight_variable(shape):
    #initial = tf.truncated_normal(shape, stddev=0.1)
    #return tf.Variable(initial) # tf.get_variable()
    return tf.get_variable(name="w", shape=shape,
                           initializer=tf.truncated_normal_initializer(mean=0.0,
                                                                       stddev=0.1,
                                                                       seed=None,
                                                                       dtype=tf.float32))


with tf.Session() as sess:
    w1 = weight_variable([3, 3, 3, 1, 1])
    sess.run(tf.initialize_all_variables())
    print(sess.run(w1))

本以爲兩者沒什麼區別,但是博士師兄建議使用tf.get_variable()定義,不解查閱,於是總結了一下兩者區別,如下:

tf.Variable()

tf.Variable(initial_value=None, 
            trainable=True, 
            collections=None, 
            validate_shape=True, 
            caching_device=None, 
            name=None, variable_def=None, 
            dtype=None, expected_shape=None, 
            import_scope=None)

tf.get_variable()

tf.get_variable(name, shape=None, 
                dtype=None, initializer=None, 
                regularizer=None, trainable=True, 
                collections=None, caching_device=None,
                partitioner=None, validate_shape=True,
                custom_getter=None)

先看一段代碼:

import tensorflow as tf
w_1 = tf.Variable(3,name="w_1")
w_2 = tf.Variable(1,name="w_1")
print w_1.name
print w_2.name
#輸出
#w_1:0
#w_1_1:0
import tensorflow as tf
w_1 = tf.get_variable(name="w_1",initializer=1)
w_2 = tf.get_variable(name="w_1",initializer=2)
#錯誤信息
#ValueError: Variable w_1 already exists, disallowed. Did
#you mean to set reuse=True in VarScope?

區別:

  • 使用tf.Variable時,如果檢測到命名衝突,系統會自己處理。使用tf.get_variable()時,系統不會處理衝突,而會報錯。
  • tf.Variable()每次都在創建新的對象,與name沒有關係。而tf.get_variable()對於已經創建的同樣name的變量對象,就直接把那個變量對象返回(類似於:共享變量),tf.get_variable() 會檢查當前命名空間下是否存在同樣name的變量,可以方便共享變量。
  • tf.get_variable():對於在上下文管理器中已經生成一個v的變量,若想通過tf.get_variable函數獲取其變量,則可以通過reuse參數的設定爲True來獲取。
  • 還有一點,tf.get_variable()必須寫name,否則報錯(but instead was %s." % (name, shape))
    ValueError: Shape of a new variable (Tensor("truncated_normal:0", shape=(2, 3), dtype=float32)) must be fully defined, but instead was <unknown>.
    ),tf.Variable()不要求。
#需要注意的是tf.get_variable() 要配合reuse和tf.variable_scope() 使用
with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #創建兩個名字一樣的變量會報錯 ValueError: Variable one/v already exists 
with tf.variable_scope("one", reuse = True): #注意reuse的作用。
    c = tf.get_variable("v", [1]) #c.name == "one/v:0" 成功共享,因爲設置了reuse

assert a==c #Assertion is true, they refer to the same object.

對於tf.Variable():

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"  

assert d==e #AssertionError: they are different objects

【tensorflow 學習】tf.get_variable()和tf.Variable()的區別

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章