Pytorch 學習(九):Pytorch 數據和模型存取

Pytorch 數據和模型存取

本方法總結自《動手學深度學習》(Pytorch版)github項目

  • Pytorch 存儲和讀取主要依靠 load 和 save 函數
  • 模型存取依靠 load_state_dict() 函數

數據存儲與讀取

import torch

path = 'p.pth'  # 'p.pt'
a = torch.tensor(1)
torch.save(a, path)
b = torch.load(path)

模型存取

  • 僅存儲/加載模型參數
model = net()
state_dict = model.state_dict()  # 模型狀態
torch.save(state_dict, path)
model2 = net()
model2.load_state_dict(torch.load(path))
  • 存儲/加載整個模型
model = net()
torch.save(model, path)
model2 = torch.load(path)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章