import tensorflow as tf
# Graph inputs: an int32 placeholder fed at run time through feed_dict (its
# shape is left unconstrained here; tf.cond needs the resulting predicate to be
# a scalar), and the two constant tensors the conditional chooses between.
condition = tf.placeholder(dtype=tf.int32, name="condition")
A = tf.constant(123)
B = tf.constant(321)
def func1():
    """Return the constant tensor ``A``; passed to ``tf.cond`` as the true branch."""
    return A
def func2():
    """Return the constant tensor ``B``; passed to ``tf.cond`` as the false branch."""
    return B
# `condition > 0` compares a tensor with a plain Python int (a non-tensor);
# TensorFlow overloads `>` to build a boolean tensor, which tf.cond uses as its
# predicate to pick between the results of func1 and func2.
y = tf.cond(condition > 0, func1, func2)

# Use the session as a context manager so its resources are released when done
# (the original never called sess.close()).
with tf.Session() as sess:
    # A positive condition selects func1's result (A); otherwise func2's (B).
    for feed_dict in ({condition: 1}, {condition: -1}):
        print(sess.run(y, feed_dict=feed_dict))
# Printed output (打印結果):
# 123
# 321