Pytorch多GPU訓練DataParallel的使用

Pytorch官網有個簡單的示例

https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html

其實用起來還是比較簡單的,大致如下:

from torch.nn import DataParallel


model = model.cuda()
model = DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

# AttributeError: 'DataParallel' object has no attribute XXX
model.module.XXX

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章