代碼:
import tensorflow as tf
# Demonstrates the two TF1-style ways of creating a variable inside an
# explicit graph and reading its value back through a session.

# Initial values shared by both demos. Named `input_data` rather than
# `input` so the builtin input() is not shadowed for the rest of the module.
input_data = [[1, 2, 3], [4, 5, 6]]

# Method 1: create the variable with the tf.Variable constructor.
graph = tf.compat.v1.Graph()
with graph.as_default():
    input_tf = tf.Variable(input_data, dtype=tf.float32, name="input")
    print("input_tf shape: ", input_tf.get_shape().as_list())
    print("input_tf dtype: ", input_tf.dtype)
    with tf.compat.v1.Session(graph=graph) as sess:
        # The variable has no value in the session until its initializer
        # op has been run.
        sess.run(tf.compat.v1.global_variables_initializer())
        print("input_tf value :\n", sess.run(input_tf))
print()

# Method 2: create the variable with the tf.compat.v1.get_variable function,
# which takes its starting value from an initializer instead of a tensor.
graph1 = tf.compat.v1.Graph()
# Initializer that fills the variable with the values in `input_data`.
initializer = tf.compat.v1.constant_initializer(input_data)
with graph1.as_default():
    # This variable lives in a separate graph from the one above, so reusing
    # the name "input" does not clash. dtype is stated explicitly to mirror
    # method 1 (float32 is also get_variable's default).
    input1_tf = tf.compat.v1.get_variable(
        name="input", shape=[2, 3], dtype=tf.float32, initializer=initializer
    )
    print("input1_tf shape: ", input1_tf.get_shape().as_list())
    print("input1_tf dtype: ", input1_tf.dtype)
    with tf.compat.v1.Session(graph=graph1) as sess:
        # Run the initializer so the constant values are assigned.
        sess.run(tf.compat.v1.global_variables_initializer())
        print("input1_tf value :\n", sess.run(input1_tf))
輸出:
input_tf shape:  [2, 3]
input_tf dtype:  <dtype: 'float32'>
input_tf value :
 [[1. 2. 3.]
 [4. 5. 6.]]

input1_tf shape:  [2, 3]
input1_tf dtype:  <dtype: 'float32'>
input1_tf value :
 [[1. 2. 3.]
 [4. 5. 6.]]