我們在訓練分類模型時,需要輸出模型預測的正確率用以評估,下面的代碼片段可以實現這個功能。
# y_pred是模型的輸出值,取值在[0,1]
# label是真實值,0或1
one = tf.ones_like(y_pred)
zero = tf.zeros_like(y_pred)
label_pred = tf.where(y_pred < 0.5, x=zero, y=one)
acc_op = tf.metrics.accuracy(
labels=label, predictions=label_pred, name='acc_op')
參考
- tf.compat.v1.metrics.accuracye: https://www.tensorflow.org/api_docs/python/tf/compat/v1/metrics/accuracy
- tf.where:https://www.tensorflow.org/api_docs/python/tf/where