TFLearn載入模型NotFoundError: Key ...BatchNormalization/is_training not found in checkpoint解決方法

最近在使用TFLearn來載入AffectNet的TrainedNetwork,採用深度學習提取Valence & Arousal。關於AffectNet是IEEE Transactions on Affective Computing, 2017的論文成果,全名是《AffectNet: A Database for Facial Expression, Valence, and Arousal Computing in the Wild》,作者公開了database和model,可以到項目網頁來獲取信息和申請數據庫和模型,論文可以在arXiv網頁上獲取到。

在通過model.load()的方法載入時報了一個非常長的錯,看起來非常嚇人,核心來說就是:

NotFoundError: Key ResNeXtBlock/BatchNormalization/is_training not found in checkpoint

查了許多資料,發現出現這樣的原因,主要就是因爲模型和代碼不匹配。

通過inspect_checkpoint查看checkpoint模型中的網絡結構,代碼:

import tensorflow as tf
from tensorflow.python.tools import inspect_checkpoint as chkp

chkp.print_tensors_in_checkpoint_file("./model_resnet_-332000", tensor_name=None, all_tensors=True)

剛開頭就顯示出來關於is_training的信息

這說明之前報錯中的is_training在模型中是False的狀態,那在代碼中是什麼樣的呢?

找到代碼中的關鍵一句函數,也是報錯中有顯示到的batch_normalization,這兩者關聯在一起,出錯在這裏的概率最大

net = tflearn.batch_normalization(net)

查看TFLearn關於batch_normalization的文檔說明

tflearn.layers.normalization.batch_normalization (incoming, beta=0.0, gamma=1.0, epsilon=1e-05, decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')

Normalize activations of the previous layer at each batch.

Arguments

  • trainablebool. If True, weights will be trainable.

可以看出,有trainable參數,且默認值爲True,這就跟之前模型中的is_training爲False有衝突了。

把這一句話修改一下

net = tflearn.batch_normalization(net, trainable=False)

其他不修改,load的部分就沒有報錯執行成功了!!

另外,TFLearn存下的模型最好還是用TFLearn來載入,用TensorFlow的saver.restore()方法也會報錯,目前還不知道如何解決,如果有思路的歡迎留言!(參考TensorFlow: NotFoundError: Key not found in checkpoint

KeyError: "The name 'SGD' refers to an Operation not in the graph.

最後,把載入部分的完整代碼結構放出來,出於對AffectNet的版權考慮,如需要完整數據庫和模型,請聯繫作者,這裏也將核心代碼略去,只保留對解決這個問題最關鍵的一行代碼。

from __future__ import division, print_function, absolute_import
import tensorflow as tf
import tflearn

import numpy as np

# Core part of training code of AffectNet. Please contact the author to request the official version.

# Important change of code to solve the problem.
net = tflearn.batch_normalization(net, trainable=False)

# The rest of the training code...

# Loading trained model
model.load("./Valence_Arousal/meta-data/model_resnet_-332000")

# Model loaded and you can do whatever you want now!

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章