Keras 中加入lambda層無法正常載入模型問題

剛剛解決了這個問題,現在記錄下來

 

問題描述

當使用lambda層加入自定義的函數後,訓練沒有bug,載入保存模型則顯示Nonetype has no attribute 'get'

 

 

問題解決方法:

這個問題是由於缺少config信息導致的。lambda層在載入的時候需要一個函數,當使用自定義函數時,模型無法找到這個函數,也就構建不了。

 

m = load_model(path,custom_objects={"reduce_mean":self.reduce_mean,"slice":self.slice})

其中,reduce_mean 和slice定義如下

    def slice(self,x, turn):
        """ Define a tensor slice function
        """
        return x[:, turn, :, :]
    def reduce_mean(self, X):
        return K.mean(X, axis=-1)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章