PyTorch報錯問題解決(1)

報錯信息:
RuntimeError: Attempting to deserialize object on CUDA device 3 but torch.cuda.device_count() is 1. Please use torch.load with map_location to map your storages to an existing device.

原因:加載的模型時用三個GPU訓練的,但現在的電腦只有一個GPU,所以報錯。

解決方法:

model = torch.load(model_path)
改爲
model = torch.load(model_path, map_location='cuda:0')
如果本機有多塊GPU要選擇,則改爲
model = torch.load(model_path, map_location= {'cuda:1': 'cuda:0'})
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章