I have a torch.nn.Module subclass defined in the following way:
class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.sub_module_a = ....  # an nn.Module
        self.sub_module_b_dict = {
            'B': ....  # an nn.Module
        }
However, after I wrap an instance of MyModule in torch.nn.DataParallel and call .to(device), only sub_module_a is moved to CUDA. The 'B' module inside self.sub_module_b_dict is still on the CPU.
It looks like DataParallel and to(device) only handle attributes that are directly nn.Module instances on the torch.nn.Module; modules nested inside a custom container (in this case, a plain Python dictionary) seem to be ignored.
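Here is a minimal sketch of what I am observing; the nn.Linear layers and the CUDA availability check are just placeholders I chose to make the example self-contained:

    import torch
    import torch.nn as nn

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.sub_module_a = nn.Linear(4, 4)   # attribute assignment registers it as a child module
            self.sub_module_b_dict = {
                'B': nn.Linear(4, 4)              # hidden inside a plain Python dict
            }

    if torch.cuda.is_available():
        model = MyModule().to('cuda')
        print(next(model.sub_module_a.parameters()).device)            # cuda:0
        print(next(model.sub_module_b_dict['B'].parameters()).device)  # cpu
        print(list(model.children()))  # only sub_module_a is listed as a child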
Am I missing some caveats here?