Here is the model structure. I combined almost everything together for the convenience of hyper-parameter tuning. ''' class MultiTaskDNN(nn.Module):
class MultiTaskDNN(nn.Module):
    """Multi-task feed-forward network: a shared MLP trunk followed by one
    independent linear output head per task.

    Args:
        n_tasks: number of output heads (one per task).
        input_dim: size of the input feature vector.
        output_dim: size of each head's output (1 = scalar per task).
        hidden_dim: sizes of the shared hidden layers, in order.
        inits: per-layer weight-initializer name; one of 'xavier_normal',
            'xavier_uniform', 'kaiming_normal', 'kaiming_uniform'
            (any other value keeps PyTorch's default Linear init).
        act_function: per-layer activation name; one of 'relu',
            'leaky_relu', 'sigmoid' (any other value applies none).
        dropouts: per-layer dropout probabilities.
        batch_norm: if True, apply BatchNorm1d after each activation.

    Raises:
        ValueError: if ``inits``, ``act_function`` or ``dropouts`` do not
            have the same length as ``hidden_dim``.
    """

    def __init__(self, n_tasks,
                 input_dim=1024,
                 output_dim=1,
                 hidden_dim=(1024, 100),      # tuples: avoid mutable defaults
                 inits=('xavier_normal', 'kaiming_uniform'),
                 act_function=('relu', 'leaky_relu'),
                 dropouts=(0.10, 0.25),
                 batch_norm=True):
        super().__init__()

        # Fail early with a clear message instead of an IndexError mid-loop.
        n_hidden = len(hidden_dim)
        for cfg_name, cfg in (('inits', inits),
                              ('act_function', act_function),
                              ('dropouts', dropouts)):
            if len(cfg) != n_hidden:
                raise ValueError(
                    f"len({cfg_name})={len(cfg)} must equal "
                    f"len(hidden_dim)={n_hidden}")

        self.n_tasks = n_tasks
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.act_function = act_function
        self.batch_norm = batch_norm

        self.layers = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.bns = nn.ModuleList()
        current_dim = input_dim
        for k, hdim in enumerate(hidden_dim):
            layer = nn.Linear(current_dim, hdim)
            # Initialize on the local handle instead of re-indexing the list.
            if inits[k] == 'xavier_normal':
                nn.init.xavier_normal_(layer.weight)
            elif inits[k] == 'kaiming_normal':
                nn.init.kaiming_normal_(layer.weight)
            elif inits[k] == 'xavier_uniform':
                nn.init.xavier_uniform_(layer.weight)
            elif inits[k] == 'kaiming_uniform':
                nn.init.kaiming_uniform_(layer.weight)
            self.layers.append(layer)
            # NOTE(review): eps=2e-1 is far above the BatchNorm1d default
            # (1e-5); kept as-is since it may be intentional — confirm.
            self.bns.append(nn.BatchNorm1d(hdim, eps=2e-1))
            self.dropouts.append(nn.Dropout(dropouts[k]))
            current_dim = hdim

        # One independent linear head per task, all sharing the trunk above.
        self.heads = nn.ModuleList(
            nn.Linear(current_dim, output_dim) for _ in range(self.n_tasks))

    def forward(self, x):
        """Run the shared trunk, then every head.

        Returns a list of ``n_tasks`` tensors, each of shape
        ``(batch, output_dim)``; to predict a single task ``i``, use
        ``model(x)[i]`` (or call ``model.heads[i]`` on the trunk output).
        """
        for k, layer in enumerate(self.layers):
            x = layer(x)
            if self.act_function[k] == 'sigmoid':
                x = torch.sigmoid(x)
            elif self.act_function[k] == 'relu':
                x = F.relu(x)
            elif self.act_function[k] == 'leaky_relu':
                x = F.leaky_relu(x)
            if self.batch_norm:  # idiomatic truth test, not `== True`
                x = self.bns[k](x)
            x = self.dropouts[k](x)
        return [head(x) for head in self.heads]
'''
Please also let me know if the structure looks right. After training this multi-task model, which has, say, 10 tasks (heads), I only want to predict task 7 — that is, head No. 7. How should I load the model and make that prediction? Thank you.
model.state_dict()
MultiTaskDNN(
(layers): ModuleList(
(0): Linear(in_features=1024, out_features=128, bias=True)
(1): Linear(in_features=128, out_features=128, bias=True)
)
(dropouts): ModuleList(
(0): Dropout(p=0.25, inplace=False)
(1): Dropout(p=0.25, inplace=False)
)
(bns): ModuleList(
(0): BatchNorm1d(128, eps=0.2, momentum=0.1, affine=True, track_running_stats=True)
(1): BatchNorm1d(128, eps=0.2, momentum=0.1, affine=True, track_running_stats=True)
)
(heads): ModuleList(
(0): Linear(in_features=128, out_features=1, bias=True)
(1): Linear(in_features=128, out_features=1, bias=True)
(2): Linear(in_features=128, out_features=1, bias=True)
(3): Linear(in_features=128, out_features=1, bias=True)
(4): Linear(in_features=128, out_features=1, bias=True)
(5): Linear(in_features=128, out_features=1, bias=True)
(6): Linear(in_features=128, out_features=1, bias=True)
(7): Linear(in_features=128, out_features=1, bias=True)
(8): Linear(in_features=128, out_features=1, bias=True)
(9): Linear(in_features=128, out_features=1, bias=True)
(10): Linear(in_features=128, out_features=1, bias=True)
)
)