
I have the neural network below for binary prediction. My classes are heavily imbalanced: class 1 occurs only 2% of the time. Only the last few layers are shown:

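# (inside the model's __init__; earlier layers omitted)
self.batch_norm2 = nn.BatchNorm1d(num_filters)
self.fc2 = nn.Linear(np.sum(num_filters), fc2_neurons)
self.batch_norm3 = nn.BatchNorm1d(fc2_neurons)
self.fc3 = nn.Linear(fc2_neurons, 1)  # single output value for the binary target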
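To make the shapes concrete, here is a minimal runnable sketch of a head built from the layers above. The class name BinaryHead, the layer sizes, and the order of operations in forward (including the ReLU) are assumptions, since the original forward pass is not shown; num_filters is treated here as the total feature count.

```python
import torch
import torch.nn as nn


class BinaryHead(nn.Module):
    """Hypothetical head mirroring the last layers shown above."""

    def __init__(self, num_filters: int, fc2_neurons: int):
        super().__init__()
        self.batch_norm2 = nn.BatchNorm1d(num_filters)
        self.fc2 = nn.Linear(num_filters, fc2_neurons)
        self.batch_norm3 = nn.BatchNorm1d(fc2_neurons)
        self.fc3 = nn.Linear(fc2_neurons, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed order: BN -> Linear -> ReLU -> BN -> Linear (the activation is a guess)
        x = self.fc2(self.batch_norm2(x))
        x = self.batch_norm3(torch.relu(x))
        # Raw logit of shape (N, 1); no sigmoid here because BCEWithLogitsLoss applies it
        return self.fc3(x)
```

With a batch of 8 feature vectors of size 64, BinaryHead(64, 32) returns logits of shape (8, 1), which pairs with float targets of the same shape.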

My loss is shown below. Is this the correct way to calculate the pos_weight parameter? I looked at the official documentation, and it shows that for multiclass classification pos_weight needs one value per class. I am not sure whether the binary case is different. I tried passing 2 values and got an error.

My question: for a binary problem, should pos_weight be a single value, unlike multiclass classification, where it needs to be a list/array with length equal to the number of classes?

BCE_With_LogitsLoss=nn.BCEWithLogitsLoss(pos_weight=class_wts[0]/class_wts[1])

My y variable is a single variable containing 0 or 1 for the actual class, and the neural network outputs a single value.

-------------------------------------------------- Update 1

Based on the answer by Shai, I have the following questions:

  1. BCEWithLogitsLoss: if it is a multiclass problem, how should the pos_weight parameter be used?
  2. Is there any example of using focal loss in PyTorch? I found some links, but most of them were old, dating back 2 or 3 years or more.
  3. For training I am oversampling class 1. Is focal loss still appropriate?
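On question 1, a minimal sketch assuming a multi-label setup, where each of C outputs is an independent yes/no decision (the case BCEWithLogitsLoss and pos_weight are built for): pos_weight then holds one negative-to-positive ratio per output column and is broadcast over the batch. For mutually exclusive classes with a single softmax output, the per-class weight argument of nn.CrossEntropyLoss is the usual counterpart instead. The targets below are made up.

```python
import torch
import torch.nn as nn

# Hypothetical multi-label targets: N=6 samples, C=3 independent binary labels
targets = torch.tensor([
    [1., 0., 0.],
    [0., 0., 1.],
    [0., 1., 0.],
    [0., 0., 0.],
    [1., 0., 0.],
    [0., 0., 0.],
])

num_pos = targets.sum(dim=0)                 # positives per column: [2, 1, 1]
num_neg = targets.shape[0] - num_pos         # negatives per column: [4, 5, 5]
pos_weight = num_neg / num_pos.clamp(min=1)  # one ratio per column, shape (C,)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.randn(6, 3)                   # model output: one logit per label
loss = criterion(logits, targets)            # scalar (mean over all N * C entries)
```

The clamp only guards against a column with no positives; it is a design choice for the sketch, not something the loss requires.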
user2543622


The documentation of pos_weight is indeed a bit unclear. For BCEWithLogitsLoss with a single output, pos_weight should be a torch.tensor of size 1:

BCE_With_LogitsLoss=nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_wts[0]/class_wts[1]]))
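A minimal sketch of the same idea, assuming pos_weight is meant to be the ratio of negative to positive examples (the convention used in the PyTorch documentation's own example). The label tensor y is made up to contain exactly 2% positives.

```python
import torch
import torch.nn as nn

# Hypothetical training labels: 1 positive in every 50 samples (2%)
y = torch.zeros(1000)
y[::50] = 1.0                                # 20 positives, 980 negatives

num_pos = y.sum()
num_neg = y.numel() - num_pos
pos_weight = (num_neg / num_pos).reshape(1)  # 1-element tensor: 980 / 20 = 49

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(1000, 1)                # model output, shape (N, 1)
loss = criterion(logits, y.unsqueeze(1))     # targets must be float and match the logits' shape
```

If class_wts holds inverse-frequency weights (for example from scikit-learn's compute_class_weight with class_weight='balanced'), then class_wts[0]/class_wts[1] works out to #positives/#negatives, the reciprocal of the ratio above, and class_wts[1]/class_wts[0] would be the one matching it. Which applies depends on how class_wts was computed, which is not shown.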

However, in your case, where the positive class occurs only 2% of the time, I think setting pos_weight alone will not be enough.
Please consider using Focal loss:
Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár. Focal Loss for Dense Object Detection. ICCV 2017.
Apart from describing Focal loss, this paper gives a very good explanation of why cross-entropy (CE) loss performs so poorly under class imbalance. I strongly recommend reading it.
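For question 2 in the update, here is a sketch of the binary (sigmoid) focal loss from the paper, written for a single-logit output like fc3 in the question. The function name binary_focal_loss_with_logits is hypothetical; recent torchvision releases also ship a maintained version as torchvision.ops.sigmoid_focal_loss.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss_with_logits(logits, targets, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss (Lin et al., 2017), averaged over all elements.

    alpha=None disables class weighting; with gamma=0 and alpha=None the
    result is exactly the mean binary cross-entropy.
    """
    # Per-element BCE computed stably from logits; this equals -log(p_t)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # Probability the model assigns to the true class of each element
    p_t = p * targets + (1 - p) * (1 - targets)
    # Modulating factor shrinks the loss of easy, well-classified examples
    loss = (1 - p_t) ** gamma * ce
    if alpha is not None:
        # alpha weights positives, (1 - alpha) weights negatives
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()
```

The defaults alpha=0.25 and gamma=2 are the values the paper reports for object detection; for a 2% positive rate they are starting points to tune on validation data, not settled choices.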

Other alternatives are listed here.

Shai