Pytorch多卡訓練踩坑記錄——inputs on different devices

pytorch中設置多卡訓練時,操作比較簡便,只要定義了模型,然後加上如下一行指令就可以。

model = nn.DataParallel(model).cuda()

但是,在pytorch中進行多卡訓練還是會遇到一些其他的問題。執行的上述指令後,pytorch會自動在每塊卡上都複製一份模型,同時將input batch size等分。比如設置了8張GPU訓練,那麼每塊卡上的batch size爲所設置的input batch size的1/8,每塊卡上也都複製了相同的一份模型,進行相同的計算。

但是,若在模型初始化時從外部傳遞參數,pytorch則不會複製8份,一般會默認儲存在0號GPU。如果這些變量只存儲和讀取,不參與運算,則不會有什麼問題,但是當這些參數需要參與到網絡計算時(例如點乘),這時便會出現“inputs on different devices”的問題。

例如出現如下問題:

RuntimeError: binary_op(): expected both inputs to be on same device, but input a is on cuda:1 and input b is on cuda:0

以上是進行兩個tensor的點對點相乘操作,其中一個tensor是模型初始化時的外部傳參,默認儲存在GPU 0,另一個tensor是網絡的中間變量,每張GPU上均有一份,這樣,在與除GPU 0 以外的其他卡上的tensor進行相乘時,就會發生以上報錯。

 

解決方案:torch.nn.Parameter

最簡單易行的解決方案便是通過torch.nn.Parameter的方式將所傳參數變成網絡的模塊參數,即成爲網絡的一部分,這時pytorch會自動將其在每塊GPU上都複製一份,可解決上述GPU不匹配問題。

例如模型初始化時傳入的參數爲tensor_input,則在module中可進行如下操作:

self.tensor_input = torch.nn.parameter(tensor_input)

以上方案親測可用,若後期發現更優方案,會再來進行補充。

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章