Edit: I have tried PyTorch 1.6.0 and 1.7.1; both give me the same error.
I have a model that lets users easily switch between two architectures, A and B. The forward functions of the two architectures also take different arguments, so I wrote the following model class:
P.S. I am using a very simple example here to demonstrate the problem; the actual model is much more complicated.
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, condition):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        if condition == 'A':
            # Architecture A: single-input forward
            self.forward = self.forward_A
        elif condition == 'B':
            # Architecture B: two inputs and an extra layer
            self.linear2 = nn.Linear(10, 1)
            self.forward = self.forward_B

    def forward_A(self, x):
        return self.linear(x)

    def forward_B(self, x1, x2):
        return self.linear(x1) + self.linear2(x2)
It works well in the single-GPU case; for example, a call like the following runs fine on one GPU:
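device = 'cuda:0'
x = torch.randn(8, 10).to(device)

model = Net('B').to(device)
out = model(x, x)  # forward_B is used, no device error on a single GPU

In the multi-GPU case with nn.DataParallel, however, it throws an error: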
device = 'cuda:0'
x = torch.randn(8, 10).to(device)

model = Net('B')
model = model.to(device)
model = nn.DataParallel(model)
model(x, x)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_addmm)
How can I make this model class work with nn.DataParallel?
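For reference, the only workaround I can think of is to collapse everything into a single forward that dispatches on a stored flag, as in the sketch below (NetDispatch and self.condition are names I made up for illustration). I believe this avoids the error because forward stays an ordinary method that each DataParallel replica calls with its own replicated parameters, but it gives up the separate per-architecture forward signatures, which is exactly what I want to keep.

import torch
import torch.nn as nn

class NetDispatch(nn.Module):
    def __init__(self, condition):
        super().__init__()
        self.condition = condition
        self.linear = nn.Linear(10, 1)
        if condition == 'B':
            self.linear2 = nn.Linear(10, 1)

    def forward(self, x1, x2=None):
        # forward is a regular method here, so DataParallel replicas
        # call it against their own copies of the parameters.
        if self.condition == 'A':
            return self.linear(x1)
        return self.linear(x1) + self.linear2(x2)

Is there a way to keep the separate forward_A / forward_B methods and still use nn.DataParallel?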