如何保存和讀取pytorch模型

如何保存和讀取pytorch模型


相信大家也會遇到這樣的問題吧,在使用pytorch訓練自己模型的時候,如果不將我們訓練的模型保存起來,我們每一次都是從頭開始訓練我們的模型,這樣真的很麻煩。其實在我的上一篇博客中我已經發現這個問題了。

1.保存模型

#定義保存模型函數
def save_model(the_model,PATH):
    torch.save(the_model.state_dict(),PAT

當我們的模型訓練完畢之後,我們只需調用一下該函數就可以了

save_model(cnn,'cnn.pth')
#這裏的cnn就是我要保存的訓練好的模型,cnn.pth就是要保存爲的名稱,
#一般來說pytorch的模型後綴都是.pth

2.讀取模型

例如我們想要在另外的一個python文件中讀取我們之前已經保存好的模型,我們需要先創建一個和之前模型一樣的空模型來接收。

import torch
from cnn_test import CNN

best_model=CNN()
#定義一個與之前模型一致結構的模型來接收
best_model.load_state_dict(torch.load('cnn.pth'))
#加載之前的模型,這裏的‘cnn.pth’就是我上一步保存的模型文件
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章