使用 tf.cond 依條件(condition > 0)選擇要執行的分支，相當於計算圖中的 if/else 判斷
import tensorflow as tf
# Graph inputs; both placeholders are supplied at run time through feed_dict.
# `condition` drives the tf.cond branch choice, `input_tensor` is the operand.
condition = tf.placeholder(dtype=tf.int32, name="data_type")
input_tensor = tf.placeholder(dtype=tf.int32, name="input_ids")

# Constant operand that both loss branches combine with input_tensor.
A = tf.constant(123)
def compute_loss1():
    """Build the tensor |A - input_tensor|.

    Used as the true branch of the tf.cond below (selected when
    condition > 0). Takes no arguments because tf.cond calls branch
    functions without arguments; it reads the module-level tensors
    ``A`` and ``input_tensor`` instead.
    """
    return tf.abs(A - input_tensor)
def compute_loss2():
    """Build the tensor |A + input_tensor|.

    Used as the false branch of the tf.cond below (selected when
    condition <= 0). Takes no arguments because tf.cond calls branch
    functions without arguments; it reads the module-level tensors
    ``A`` and ``input_tensor`` instead.
    """
    return tf.abs(A + input_tensor)
# tf.cond calls both branch functions once while building the graph; at run
# time only the branch chosen by the predicate is executed:
# compute_loss1 when condition > 0, otherwise compute_loss2.
loss = tf.cond(condition > 0, compute_loss1, compute_loss2)

# Use the session as a context manager so it is closed (and its resources
# released) even if a run raises; the original never closed it.
with tf.Session() as sess:
    # condition > 0 -> |123 - 100| = 23
    feed_dict = {condition: 1, input_tensor: 100}
    print(sess.run(loss, feed_dict=feed_dict))
    # condition <= 0 -> |123 + 100| = 223
    feed_dict = {condition: -1, input_tensor: 100}
    print(sess.run(loss, feed_dict=feed_dict))
打印結果:
23
223