Error(s) in loading state_dict for XXX Unexpected key(s) in state_dict, 找不到num_batches_tracked

今天在訓練的時候發現加載模型的時候提示找不到num_batches_tracked,感到奇怪,因爲之前已經成功訓練過一次了怎麼這次就報錯了呢,後來發現,第一次訓練的時候我用的是0.4.0的pytorch,這次用的是1.0的Pytorch,因爲torch的版本不一樣引起的問題

KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict'

得到類似這樣的報錯

以下參考自這篇文章 https://zhuanlan.zhihu.com/p/91485607

經過研究發現,在pytorch 0.4.1及後面的版本里,BatchNorm層新增了num_batches_tracked參數,用來統計訓練時的forward過的batch數目,源碼如下(pytorch0.4.1): 

    if self.training and self.track_running_stats:
        self.num_batches_tracked += 1
        if self.momentum is None:  # use cumulative moving average
            exponential_average_factor = 1.0 / self.num_batches_tracked.item()
        else:  # use exponential moving average
            exponential_average_factor = self.momentum

知道原因就知道怎麼處理了,我自己的模型裏沒有num_batches_tracked這個鍵,要把我預訓練模型裏的這個鍵給剔除掉

這是我對我文件裏做的修改,註釋掉的那行是原來的代碼,可以對比一下 新增加的三行和原來的這行,就是簡單的做了一個字典刪除

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章