RuntimeError Expected tensor for argument #1 'input' to have the same device as tensor for argument

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

記錄一下bug,一開始以爲是gpu並行設備指定的問題,後來發現是模型中的函數調用問題。

主要整理來自於pytorch的issue

以下面兩個module爲例,如果用DataParallel, testModule會報錯 而testModule2不會。

import torch
from torch import nn

class testModule(nn.Module):
    def __init__(self):
        super(testModule, self).__init__()
        self.g = nn.Conv2d(in_channels=1, out_channels=1,
                         kernel_size=1, stride=1, padding=0)
        self.operation_function = self._realOperation

    def forward(self, x):
        output = self.operation_function(x)
        return output

    def _realOperation(self, x):
        x = self.g(x)
        return x

class testModule2(nn.Module):
    def __init__(self):
        super(testModule2, self).__init__()
        self.g = nn.Conv2d(in_channels=1, out_channels=1,
                         kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        x = self.g(x)
        return x

if __name__ == '__main__':
        input = torch.rand(4, 1, 1, 1).cuda()
        net = testModule()
        net2 = testModule2()
        gpu_num = torch.cuda.device_count()
        print('GPU NUM: {:2d}'.format(gpu_num))
        if gpu_num > 1:
            net = torch.nn.DataParallel(net, list(range(gpu_num))).cuda()
            net2 = torch.nn.DataParallel(net2, list(range(gpu_num))).cuda()
        out2 = net2(input)
        print(out2.size())
        out = net(input)
        print(out.size())
        self.operation_function = self._realOperation

原因很簡單,在上面這句話涉及到一個屬性綁定對應的方法,當將模塊廣播到不同的GPU時,此屬性(不是張量)僅被複制,這意味着此模塊的所有廣播副本都具有引用相同綁定方法的屬性,並且此方法綁定到同一實例,因此使用相同的self.g,僅在GPU 0上具有所有參數。因此在GPU 1上出錯。

在testModule2中,在每個廣播副本的前面,動態找到的self.g是該副本的g屬性,其參數會廣播到相應的GPU。

解決方法:

涉及到方法綁定的地方直接放到forward中執行,也就是把self._realOperation放到forward中就可以了

另外還有將self.operation_function編寫爲該類的另一種方法(這個測試了一下好像不行,不知道是不是自己的寫法有問題

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章