
I have a BERT-based sequence classification model that takes 4 strings as input and outputs 2 labels for each one:

my_input = [string_1, string_2, string_3, string_4]
out_logits = model(my_input).logits
out_softmax = torch.softmax(out_logits, dim=1)
out_softmax 
>>> tensor([[0.8666, 0.1334],
        [0.8686, 0.1314],
        [0.8673, 0.1327],
        [0.8665, 0.1335]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

My loss function is nn.CrossEntropyLoss() and my labels are tensors with indices corresponding to the correct labels: tensor([0., 0., 0., 1.]). Note that every label except for one is 0.

loss = loss_fun(out_softmax, labels_tensor)
# step
optim.zero_grad()
loss.backward()
optim.step()

The issue I'm having, as can be seen above, is that the model learns to predict just one class (e.g., the first column above). I'm not entirely sure why this is happening, but I thought that penalizing the prediction that should be 1 more heavily might help.

How can I penalize more that prediction?

Penguin
  • You can always try to calculate the class weights and adding it as a parameter to the loss function, but from my experience this doesn't really work THAT effectively and 99% of the time balancing out the data is better. Is your data imbalanced? What's the ratio of 0:1? – Sean Dec 23 '22 at 01:47
  • `nn.CrossEntropyLoss` expects unnormalised logits as input, so you should pass `out_logits` rather than `out_softmax` to `loss_fun`. On data imbalance, it’s maybe better to check Cross Validated instead, e.g., https://stats.stackexchange.com/questions/247871/what-is-the-root-cause-of-the-class-imbalance-problem – kmkurn Dec 23 '22 at 19:54

1 Answer


You can pass a weight tensor (one weight for each class) to the constructor of nn.CrossEntropyLoss to get such a weighting:

Parameters:

weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C

where C is the number of classes.
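For the two-class setting in the question, that could look like the following sketch. The weight values here are illustrative, not tuned; note also that nn.CrossEntropyLoss expects raw logits and integer class indices, not softmax outputs or float labels:

```python
import torch
import torch.nn as nn

# Up-weight class 1 so mistakes on it are penalised more heavily.
# The values are illustrative; tune them for your data.
class_weights = torch.tensor([1.0, 3.0])
loss_fun = nn.CrossEntropyLoss(weight=class_weights)

# Raw model outputs (logits) -- do NOT apply softmax before the loss.
logits = torch.tensor([[2.0, -1.0],
                       [1.5, -0.5],
                       [1.8, -0.9],
                       [1.7, -0.8]])
# Class indices must be integers (long), not floats.
labels = torch.tensor([0, 0, 0, 1])

loss = loss_fun(logits, labels)
```

Because the one example with label 1 is mispredicted here, the weighted loss comes out larger than the unweighted one, which is exactly the extra penalty you asked for.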

But you should also think about alternatives; see the comment by @Sean above or, e.g., this question.
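If you do go the weighting route, one common heuristic (just one option, not the only way) is to derive the weights from the inverse class frequencies in your training labels:

```python
import torch

labels = torch.tensor([0, 0, 0, 1])

# Count occurrences of each class, then weight inversely to frequency.
counts = torch.bincount(labels, minlength=2).float()   # tensor([3., 1.])
weights = counts.sum() / (len(counts) * counts)        # rarer class gets a larger weight

loss_fun = torch.nn.CrossEntropyLoss(weight=weights)
```

With a 3:1 class ratio as above, class 1 ends up weighted three times as heavily as class 0.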

cheersmate