Pytorch 數據和模型存取
本方法總結自《動手學深度學習》(Pytorch版)github項目
- Pytorch 存儲和讀取主要依靠 load 和 save 函數
- 模型存取依靠 load_state_dict() 函數
數據存儲與讀取
import torch
path = 'p.pth' # 'p.pt'
a = torch.tensor(1)
torch.save(a, path)
b = torch.load(path)
模型存取
- 僅存儲/加載模型參數
model = net()
state_dict = model.state_dict() # 模型狀態
torch.save(state_dict, path)
model2 = net()
model2.load_state_dict(torch.load(path))
- 存儲/加載整個模型
model = net()
torch.save(model, path)
model2 = torch.load(path)