今天在訓練的時候發現加載模型的時候提示找不到num_batches_tracked,感到奇怪,因爲之前已經成功訓練過一次了怎麼這次就報錯了呢,後來發現,第一次訓練的時候我用的是0.4.0的pytorch,這次用的是1.0的Pytorch,因爲torch的版本不一樣引起的問題
KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict'
得到類似這樣的報錯
以下參考自這篇文章 https://zhuanlan.zhihu.com/p/91485607
經過研究發現,在pytorch 0.4.1及後面的版本里,BatchNorm層新增了num_batches_tracked參數,用來統計訓練時的forward過的batch數目,源碼如下(pytorch0.4.1):
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
知道原因就知道怎麼處理了,我自己的模型裏沒有num_batches_tracked這個鍵,要把我預訓練模型裏的這個鍵給剔除掉
這是我對我文件裏做的修改,註釋掉的那行是原來的代碼,可以對比一下 新增加的三行和原來的這行,就是簡單的做了一個字典刪除