tf.data接口,一個batch裏計算多種loss

使用if condition判斷一下

import tensorflow as tf

condition = tf.placeholder(tf.int32, name="data_type")
input_tensor = tf.placeholder(tf.int32, name="input_ids")
A = tf.constant(value=123)

def compute_loss1():
    return tf.abs(A - input_tensor)

def compute_loss2():
    return tf.abs(A + input_tensor)

loss = tf.cond(condition > 0, compute_loss1, compute_loss2)

sess = tf.Session()

feed_dict = {condition:1, input_tensor:100}
print(sess.run(loss,feed_dict=feed_dict))

feed_dict = {condition:-1, input_tensor:100}
print(sess.run(loss,feed_dict=feed_dict))

打印結果:
23
223

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章