pytorch導入模型參數

  • 背景介紹:
  1. 我的想法是把一個預訓練的網絡的參數導入到我的模型中,但是預訓練模型的參數只是我模型參數的一小部分,怎樣導進去不出差錯了,請來聽我說說。
  • 解法
  1. 首先把你需要添加參數的那一小部分模型提取出來,並新建一個類進行重新定義,如圖向Alexnet中添加前三層的參數,重新定義前三層。
  2. 接下來就是導入參數

  3.         checkpoint = torch.load(config.pretrained_model)
            # change name and load parameters
            model_dict = model.net1.state_dict()
            checkpoint = {k.replace('features.features', 'featureExtract1'): v for k, v in checkpoint.items()}
            checkpoint = {k:v for k,v in checkpoint.items() if k in model_dict.keys()}
    
            model_dict.update(checkpoint)
            model.net1.load_state_dict(model_dict)
  4. 程序如上圖所示,主要是第三、四句,第三是替換,別人訓練的模型參數的鍵和自己的定義的會不一樣,所以需要替換成自己的;第四句有個if用於判斷導入需要的參數。其他語句都相當於是模板,套用即可。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章