I have a BERT-based sequence classification model that takes 4 strings as input and outputs 2 labels (as logits) for each one:
my_input = [string_1, string_2, string_3, string_4]
out_logits = model(my_input).logits
out_softmax = torch.softmax(out_logits, dim=1)
out_softmax
>>> tensor([[0.8666, 0.1334],
            [0.8686, 0.1314],
            [0.8673, 0.1327],
            [0.8665, 0.1335]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
My loss function is nn.CrossEntropyLoss() and my labels are a tensor with indices corresponding to the correct labels: tensor([0., 0., 0., 1.]). Note that every label except for one is 0.
loss = loss_fun(out_softmax, labels_tensor)
# step
optim.zero_grad()
loss.backward()
optim.step()
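For reference, the documented usage of nn.CrossEntropyLoss is with raw logits and integer class indices; a sketch of that pattern (variable names are just illustrative) would be:

import torch
import torch.nn as nn

loss_fun = nn.CrossEntropyLoss()
labels_tensor = torch.tensor([0, 0, 0, 1], device="cuda")  # dtype long, one class index per example
loss = loss_fun(out_logits, labels_tensor)                 # expects logits rather than softmax outputs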
The issue I'm having, as shown above, is that the model learns to predict just one class (e.g., the first column above). I'm not entirely sure why this is happening, but I thought that penalizing the prediction that should be 1 more heavily might help.
How can I penalize that prediction more?
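For example, I understand that nn.CrossEntropyLoss accepts a per-class weight argument, but I'm not sure whether that is the right approach here (the weight values below are an arbitrary illustration):

import torch
import torch.nn as nn

# Penalize errors on class 1 five times as much as errors on class 0 (illustrative values).
class_weights = torch.tensor([1.0, 5.0], device="cuda")
loss_fun = nn.CrossEntropyLoss(weight=class_weights)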