
Edit: I have tried PyTorch 1.6.0 and 1.7.1; both give me the same error.

I have a model that lets users easily switch between two architectures, A and B. The forward functions of the two architectures are also different, so I have the following model class:

P.S. I am just using a very simple example here to demonstrate my problem; the actual model is much more complicated.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, condition):
        super().__init__()
        self.linear = nn.Linear(10, 1)

        # Pick the forward function based on the requested architecture
        if condition == 'A':
            self.forward = self.forward_A
        elif condition == 'B':
            self.linear2 = nn.Linear(10, 1)
            self.forward = self.forward_B

    def forward_A(self, x):
        return self.linear(x)

    def forward_B(self, x1, x2):
        return self.linear(x1) + self.linear2(x2)

It works well in the single-GPU case. In the multi-GPU case, however, it throws an error.

device = 'cuda:0'
x = torch.randn(8, 10).to(device)

model = Net('B')
model = model.to(device)
model = nn.DataParallel(model)

model(x, x)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_addmm)

How can I make this model class work with nn.DataParallel?

Raven Cheuk

2 Answers


You are forcing the input x and the model onto the 'cuda:0' device, but when working on multiple GPUs you should not pin them to any particular device.
Try:

x = torch.randn(8, 10)
model = Net('B')
model = nn.DataParallel(model, device_ids=[0, 1]).cuda()  # assuming 2 GPUs
pred = model(x, x)
Shai
  • Sorry, it does not work either; I have this error instead: "RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu". I am using torch 1.6.0; I am not sure if it is a bug in this specific version – Raven Cheuk Jul 13 '21 at 00:55
  • I have tried PyTorch 1.6.0 and 1.7.1. Both versions give me the same result. – Raven Cheuk Jul 13 '21 at 01:16
  • Now the error becomes: "RuntimeError: Caught RuntimeError in replica 1 on device 1." – Raven Cheuk Jul 13 '21 at 06:47
  • @RavenCheuk do you have a more specific error? What `RuntimeError` exactly? Can you run with `nn.DataParallel` with only one device and see if you get a more detailed description of the error? – Shai Jul 13 '21 at 07:18
  • The full error is "RuntimeError: Expected tensor for 'out' to have the same device as tensor for argument #2 'mat1'; but device 0 does not equal 1 (while checking arguments for addmm)" If I run nn.DataParallel with only one device, there is no error. – Raven Cheuk Jul 13 '21 at 07:20

This problem goes away if you define two separate wrapper modules, each with its own forward function, instead of reassigning self.forward inside a single class.

You then wrap the chosen wrapper in nn.DataParallel as usual (see the sketch below).
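
A minimal sketch of that idea might look like this (the wrapper names NetA and NetB are just illustrative; each one defines forward explicitly, so nn.DataParallel can replicate the module and call each replica's own forward instead of a method bound to the original model):

import torch
import torch.nn as nn

class NetA(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

class NetB(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.linear2 = nn.Linear(10, 1)

    def forward(self, x1, x2):
        return self.linear(x1) + self.linear2(x2)

# Choose the architecture, then wrap it in DataParallel as usual
# (assuming at least one CUDA device is available)
model = NetB()
model = nn.DataParallel(model).cuda()

x = torch.randn(8, 10).cuda()
pred = model(x, x)  # inputs are scattered across the available GPUs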
