最近在使用TFLearn來載入AffectNet的TrainedNetwork,採用深度學習提取Valence & Arousal。關於AffectNet是IEEE Transactions on Affective Computing, 2017的論文成果,全名是《AffectNet: A Database for Facial Expression, Valence, and Arousal Computing in the Wild》,作者公開了database和model,可以到項目網頁來獲取信息和申請數據庫和模型,論文可以在arXiv或網頁上獲取到。
在通過model.load()的方法載入時報了一個非常長的錯,看起來非常嚇人,核心來說就是:
NotFoundError: Key ResNeXtBlock/BatchNormalization/is_training not found in checkpoint
查了許多資料,發現出現這樣的原因,主要就是因爲模型和代碼不匹配。
通過inspect_checkpoint查看checkpoint模型中的網絡結構,代碼:
import tensorflow as tf
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file("./model_resnet_-332000", tensor_name=None, all_tensors=True)
剛開頭就顯示出來關於is_training的信息
這說明之前報錯中的is_training在模型中是False的狀態,那在代碼中是什麼樣的呢?
找到代碼中的關鍵一句函數,也是報錯中有顯示到的batch_normalization,這兩者關聯在一起,出錯在這裏的概率最大
net = tflearn.batch_normalization(net)
查看TFLearn關於batch_normalization的文檔說明
tflearn.layers.normalization.batch_normalization (incoming, beta=0.0, gamma=1.0, epsilon=1e-05, decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')
Normalize activations of the previous layer at each batch.
Arguments
- trainable:
bool
. If True, weights will be trainable.
可以看出,有trainable參數,且默認值爲True,這就跟之前模型中的is_training爲False有衝突了。
把這一句話修改一下
net = tflearn.batch_normalization(net, trainable=False)
其他不修改,load的部分就沒有報錯執行成功了!!
另外,TFLearn存下的模型最好還是用TFLearn來載入,用TensorFlow的saver.restore()方法也會報錯,目前還不知道如何解決,如果有思路的歡迎留言!(參考TensorFlow: NotFoundError: Key not found in checkpoint)
KeyError: "The name 'SGD' refers to an Operation not in the graph.
最後,把載入部分的完整代碼結構放出來,出於對AffectNet的版權考慮,如需要完整數據庫和模型,請聯繫作者,這裏也將核心代碼略去,只保留對解決這個問題最關鍵的一行代碼。
from __future__ import division, print_function, absolute_import
import tensorflow as tf
import tflearn
import numpy as np
# Core part of training code of AffectNet. Please contact the author to request the official version.
# Important change of code to solve the problem.
net = tflearn.batch_normalization(net, trainable=False)
# The rest of the training code...
# Loading trained model
model.load("./Valence_Arousal/meta-data/model_resnet_-332000")
# Model loaded and you can do whatever you want now!