Today I was helping a friend debug her TensorFlow program and ran into a NaN problem. It took a long while to track down, but it was finally solved.
The culprit turned out to be tf.sqrt. See this Stack Overflow question: Why is my loss function returning nan?
The explanation given there:
It was coming from the fact that x was approaching a tensor with all zeros for entries. This was making the derivative of sigma wrt x a NaN. The solution was to add a small value to that quantity.
In other words, the problem is tf.sqrt(x) when x is 0: the derivative of sqrt(x) is 1/(2*sqrt(x)), which is undefined at x = 0, so the gradient becomes NaN.
The fix: add a tiny constant so x can never be exactly 0, i.e.:
tf.sqrt(x+1e-8)
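To see why the epsilon helps, here is a minimal sketch of the math, using numpy instead of TensorFlow so it runs without the framework installed. The derivative of sqrt(x) is 1/(2*sqrt(x)), which blows up at x = 0; adding the same 1e-8 used above keeps it finite:

```python
import numpy as np

def sqrt_grad(x):
    # Analytic derivative of sqrt(x): 1/(2*sqrt(x)).
    # Undefined at x = 0 -- this is what poisons the gradient.
    with np.errstate(divide="ignore"):
        return 1.0 / (2.0 * np.sqrt(x))

bad = sqrt_grad(np.float32(0.0))        # inf: gradient blows up at 0
eps = 1e-8
good = sqrt_grad(np.float32(0.0) + eps) # finite once epsilon is added

print(np.isinf(bad), np.isfinite(good))  # True True
```

Once an inf like this gets multiplied by a zero elsewhere in backprop, it turns into NaN, which is why the loss suddenly reports nan.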
I would also recommend a similar Zhihu question, well worth reading: "Why does training a network with TensorFlow produce loss = nan while accuracy stays stuck at a fixed value?"