PyTorch solution
After going through the docs: you can indeed simply use the pos_weight argument of torch.nn.BCEWithLogitsLoss.
This argument weights the positive samples of each class, so if you have 270 classes you should pass a torch.Tensor of shape (270,) defining the weight for each class.
Here is a slightly modified snippet from the documentation:
import torch

# 270 classes, batch size = 64
target = torch.ones([64, 270], dtype=torch.float32)
# Raw logits from your network, no activation applied
output = torch.full([64, 270], 0.9)
# Per-class positive weights, all equal to one here. Plug in your own.
pos_weight = torch.ones([270])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(0.9))
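To pick actual values for pos_weight, the PyTorch docs suggest the ratio of negative to positive examples per class (e.g. 300 negatives and 100 positives give 3). A minimal sketch of that, assuming your labels are available as one multi-hot matrix; the toy `labels` tensor and the clamp against empty classes are my own additions for illustration:

```python
import torch

# Toy multi-hot label matrix: 4 samples, 3 classes (made-up data)
labels = torch.tensor([
    [1., 0., 0.],
    [1., 1., 0.],
    [1., 0., 0.],
    [0., 1., 1.],
])

positives = labels.sum(dim=0)            # tensor([3., 2., 1.])
negatives = labels.shape[0] - positives  # tensor([1., 2., 3.])

# Clamp so a class with no positives does not divide by zero (an assumption,
# you may prefer to handle such classes differently)
pos_weight = negatives / positives.clamp(min=1)

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.zeros_like(labels)        # stand-in for network outputs
loss = criterion(logits, labels)
```

With 270 classes the same code yields a tensor of shape (270,), which is what pos_weight expects.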
Self-made solution
If you would rather apply the per-class weights yourself (scaling the whole loss of a class, not only its positive samples), it is really easy to code:
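import torch


class WeightedMultilabel(torch.nn.Module):
    def __init__(self, weights: torch.Tensor):
        super().__init__()
        # reduction="none" keeps the per-element loss so each class can be scaled
        self.loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        # Shape (1, num_classes) so the weights broadcast over the batch
        self.weights = weights.unsqueeze(0)

    def forward(self, outputs, targets):
        return (self.loss(outputs, targets) * self.weights).mean()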
The weights tensor has to have the same length as the number of classes in your multilabel classification (270 in your case), with one weight per class.
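A quick way to convince yourself that per-class weighting behaves as expected is to compute the element-wise loss, scale each class column, and average. The sketch below does this with the functional API; `weighted_bce` is a hypothetical helper name, and the random logits and targets are placeholders:

```python
import torch
import torch.nn.functional as F


def weighted_bce(outputs, targets, weights):
    # Hypothetical helper: element-wise BCE, scaled per class, then averaged
    per_element = F.binary_cross_entropy_with_logits(
        outputs, targets, reduction="none"
    )
    return (per_element * weights.unsqueeze(0)).mean()


torch.manual_seed(0)
outputs = torch.randn(64, 270)                    # raw logits
targets = torch.randint(0, 2, (64, 270)).float()  # multi-hot targets

plain = F.binary_cross_entropy_with_logits(outputs, targets)
uniform = weighted_bce(outputs, targets, torch.ones(270))
halved = weighted_bce(outputs, targets, torch.full((270,), 0.5))
```

Uniform weights of one reproduce the plain loss, and halving every weight halves it, so anything in between only shifts the balance between classes.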
Calculating weights
Sum the labels of every sample in your dataset, divide by the minimum value, and take the inverse at the end.
A rough snippet:
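# Assumes each dataset element is the multi-hot label tensor of one sample
weights = torch.zeros_like(dataset[0])
for element in dataset:
    weights += element
# The rarest class gets weight 1, more frequent classes get less
weights = 1 / (weights / torch.min(weights))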
With this approach, the least frequent class gets the normal (unscaled) loss, while the others get weights smaller than 1.
This might cause some instability during training, though, so you may want to experiment with those values a little (maybe a log transform instead of a linear one?).
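To get a feel for the difference, here is one possible reading of the "log transform" idea next to the linear version. The formula 1 / (1 + log(ratio)) is just my assumption of a reasonable damping, not something prescribed anywhere, and the counts are made up:

```python
import torch

# Made-up per-class positive counts
counts = torch.tensor([10., 100., 1000.])
ratio = counts / counts.min()             # 1 for the rarest class

linear_weights = 1 / ratio                # tensor([1.00, 0.10, 0.01])
log_weights = 1 / (1 + torch.log(ratio))  # roughly [1.00, 0.30, 0.18]
```

Both variants keep the rarest class at weight 1, but the log version shrinks frequent classes far less aggressively, which may be gentler on training.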
Other approach
You could also consider upsampling/downsampling, though this is complicated in the multilabel case: duplicating or removing a sample also changes the counts of every other class it is labeled with, so more advanced heuristics would be needed, I think.
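If you want to try it anyway, one simple starting point is torch.utils.data.WeightedRandomSampler with a per-sample weight. The heuristic below (a sample counts as much as its rarest positive label) is purely an assumption for illustration and does not solve the co-occurrence problem described above:

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Toy multi-hot label matrix: 4 samples, 3 classes (made-up data)
labels = torch.tensor([
    [1., 0., 0.],
    [1., 1., 0.],
    [1., 0., 0.],
    [0., 1., 1.],
])

class_counts = labels.sum(dim=0).clamp(min=1)
inv_freq = 1.0 / class_counts  # tensor([1/3, 1/2, 1])

# Heuristic (assumption): weight each sample by its rarest positive label.
# Samples with no positive label get weight 0 and are never drawn.
sample_weights = (labels * inv_freq).max(dim=1).values

sampler = WeightedRandomSampler(
    sample_weights, num_samples=len(sample_weights), replacement=True
)
indices = list(sampler)  # indices to draw for one epoch
```

The sampler can then be passed to a DataLoader through its sampler argument, so samples carrying rare labels are drawn more often.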