Variable
While learning TensorFlow recently, I noticed there are two ways to define a variable:
tf.get_variable() and tf.Variable().
def weight_variable(shape):
    # initial = tf.truncated_normal(shape, stddev=0.1)
    # return tf.Variable(initial)  # the tf.Variable() alternative
    return tf.get_variable(name="w", shape=shape,
                           initializer=tf.truncated_normal_initializer(mean=0.0,
                                                                       stddev=0.1,
                                                                       seed=None,
                                                                       dtype=tf.float32))

with tf.Session() as sess:
    w1 = weight_variable([3, 3, 3, 1, 1])
    sess.run(tf.global_variables_initializer())  # initialize_all_variables() is deprecated
    print(sess.run(w1))
I assumed there was no real difference between the two, but a senior PhD student advised me to define variables with tf.get_variable(). Puzzled, I looked it up and summarized the differences below:
tf.Variable()
tf.Variable(initial_value=None, trainable=True, collections=None,
            validate_shape=True, caching_device=None, name=None,
            variable_def=None, dtype=None, expected_shape=None,
            import_scope=None)
tf.get_variable()
tf.get_variable(name, shape=None, dtype=None, initializer=None,
                regularizer=None, trainable=True, collections=None,
                caching_device=None, partitioner=None,
                validate_shape=True, custom_getter=None)
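The signatures already hint at the main difference: tf.Variable() takes a concrete initial_value from which shape and dtype are inferred, while tf.get_variable() takes a mandatory name plus a shape and an initializer. A minimal side-by-side sketch (written against the TF 1.x API through tf.compat.v1 so it also runs under TF 2.x; on an actual 1.x install you can use tf directly):

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()
tf1.reset_default_graph()

# tf.Variable: the initial value is required and fixes shape/dtype; name is optional.
v = tf1.Variable(tf.zeros([2, 3]), name="v")

# tf.get_variable: name is required; shape/dtype come from explicit arguments,
# and the value comes from an initializer rather than a concrete tensor.
g = tf1.get_variable("g", shape=[2, 3], dtype=tf.float32,
                     initializer=tf1.zeros_initializer())

print(v.name)  # v:0
print(g.name)  # g:0
```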
First, consider this snippet:
import tensorflow as tf

w_1 = tf.Variable(3, name="w_1")
w_2 = tf.Variable(1, name="w_1")
print(w_1.name)
print(w_2.name)
# Output:
# w_1:0
# w_1_1:0
import tensorflow as tf

w_1 = tf.get_variable(name="w_1", initializer=1)
w_2 = tf.get_variable(name="w_1", initializer=2)
# Error:
# ValueError: Variable w_1 already exists, disallowed. Did
# you mean to set reuse=True in VarScope?
Differences:
- With tf.Variable(), a naming conflict is handled automatically: TensorFlow appends a suffix to make the name unique. tf.get_variable() does not resolve the conflict; it raises an error instead.
- tf.Variable() creates a new object on every call, regardless of the name. tf.get_variable(), for a name that has already been created, simply returns the existing variable object (i.e. a shared variable): it checks whether a variable with the same name exists in the current namespace, which makes variable sharing convenient.
- If a variable v has already been created inside a variable scope, tf.get_variable() can retrieve it by setting the scope's reuse parameter to True.
- One more point: tf.get_variable() requires a name (it is the first positional argument), whereas tf.Variable() does not. tf.get_variable() also requires a fully defined shape; otherwise it fails with an error such as:

ValueError: Shape of a new variable (Tensor("truncated_normal:0", shape=(2, 3), dtype=float32)) must be fully defined, but instead was <unknown>.
Note that tf.get_variable() is meant to be used together with reuse and tf.variable_scope():
with tf.variable_scope("one"):
    a = tf.get_variable("v", [1])  # a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1])  # creating two variables with the same name fails:
                                   # ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse=True):  # note the role of reuse
    c = tf.get_variable("v", [1])  # c.name == "one/v:0"; sharing succeeds because reuse is set
assert a == c  # Assertion is true: they refer to the same object.
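As a side note, TensorFlow 1.4 and later also accept reuse=tf.AUTO_REUSE, which creates the variable on the first call and reuses it afterwards, so the two separate scope blocks above can be collapsed into one helper. A small sketch (via tf.compat.v1 so it also runs under TF 2.x):

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()
tf1.reset_default_graph()

def dense_weights():
    # AUTO_REUSE: create "layer/w" on the first call, return the same
    # object on every later call -- no explicit reuse=True block needed.
    with tf1.variable_scope("layer", reuse=tf1.AUTO_REUSE):
        return tf1.get_variable("w", shape=[3, 3])

w1 = dense_weights()
w2 = dense_weights()
print(w1 is w2)  # True: both calls return the same shared variable
```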
In contrast, with tf.Variable():
with tf.variable_scope("two"):
    d = tf.get_variable("v", [1])                      # d.name == "two/v:0"
    e = tf.Variable(1, name="v", expected_shape=[1])   # e.name == "two/v_1:0"
assert d == e  # AssertionError: they are different objects
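This sharing behaviour is the practical reason to prefer tf.get_variable(): the same weights can be applied to several inputs without accidentally creating duplicates. A minimal sketch (the linear helper is my own illustration, not from the original post; again via tf.compat.v1):

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()
tf1.reset_default_graph()

def linear(x, reuse):
    # Fetch (or create) the weight matrix by name inside one scope.
    with tf1.variable_scope("linear", reuse=reuse):
        w = tf1.get_variable(
            "w", shape=[4, 2],
            initializer=tf1.truncated_normal_initializer(stddev=0.1))
        return tf1.matmul(x, w)

x1 = tf1.placeholder(tf.float32, [None, 4])
x2 = tf1.placeholder(tf.float32, [None, 4])
y1 = linear(x1, reuse=False)  # creates linear/w
y2 = linear(x2, reuse=True)   # reuses linear/w instead of creating linear/w_1
print(len(tf1.trainable_variables()))  # 1: a single shared weight matrix
```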