
I want to use TensorFlow's Object Detection API for a custom problem. I have already created the dataset, but it's pretty unbalanced. The dataset has 3 classes, and my main problem is that one class has about 16k samples while another has only about 2.5k samples.

So I think I have to balance the dataset. Someone told me that there is something called sample/class weights (not sure if this is 100% correct), which balances the samples for training so that the biggest class has a smaller impact on training than the smallest class.

I'm not able to find this balancing method. Can someone please give me a hint where to start?

desertnaut
Tripton
  • Possible duplicate of [Unbalanced data and weighted cross entropy](https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy) – P-Gn Jun 02 '18 at 16:55
  • Does this answer your question? [Class weights for balancing data in TensorFlow Object Detection API](https://stackoverflow.com/questions/51862997/class-weights-for-balancing-data-in-tensorflow-object-detection-api) – desertnaut Mar 10 '22 at 10:32

1 Answer


You can compute the normal cross entropy, which gives you a ? x 1 tensor X of per-sample losses.

If you want class number N to count T times more, set entry N of class_weight to T (and the other entries to 1) and do:

X = X * tf.reduce_sum(tf.multiply(one_hot_label, class_weight), axis = 1)

Here `tf.multiply` scales the one-hot label by whatever weight you want, and `tf.reduce_sum` converts each label vector to a scalar, so you end up with a ? x 1 tensor filled with the class weightings. Then you simply multiply the tensor of losses by the tensor of weightings to achieve the desired result.
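To make the shapes concrete, here is a minimal sketch of the same idea. It uses NumPy rather than TensorFlow so it runs on its own; the function name `weighted_cross_entropy` and the toy inputs are assumptions for illustration, and this is not how the Object Detection API computes its loss.

```python
import numpy as np

def weighted_cross_entropy(logits, one_hot_label, class_weight):
    """Per-sample softmax cross entropy, scaled by the weight of each sample's true class."""
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Plain cross entropy: one loss value per sample.
    losses = -(one_hot_label * log_probs).sum(axis=1)
    # NumPy counterpart of tf.reduce_sum(tf.multiply(one_hot_label, class_weight), axis=1):
    # selects the weight of the true class for every sample.
    sample_weight = (one_hot_label * class_weight).sum(axis=1)
    return losses * sample_weight

# Toy batch (hypothetical): two samples, two classes, uninformative logits.
logits = np.zeros((2, 2))
one_hot_label = np.array([[1.0, 0.0], [0.0, 1.0]])
class_weight = np.array([1.0, 6.4])
weighted = weighted_cross_entropy(logits, one_hot_label, class_weight)
```

With identical logits both samples have the same unweighted loss, so the second entry of `weighted` is 6.4 times the first.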

Since one class is 6.4 times more common than the other (16k / 2.5k), I would apply the weightings 1 and 6.4 to the more common and less common class respectively. This means that every time the less common class occurs, it has 6.4 times the effect of the more common class, so it's as if the model saw the same number of samples from each.

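You might want to rescale the weightings so that they add up to the number of classes. This matches the default case in which all of the weightings are 1. Dividing by the total gives 1/7.4 and 6.4/7.4, which add up to 1; multiplying by the number of classes (2 in this two-class comparison) gives 2/7.4 ≈ 0.27 and 12.8/7.4 ≈ 1.73.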
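The arithmetic can be sketched in a few lines of plain Python. The helper name `class_weights` and the class names `"common"` and `"rare"` are hypothetical; only the sample counts come from the question.

```python
def class_weights(counts):
    """Inverse-frequency weights, rescaled so they sum to the number of classes."""
    largest = max(counts.values())
    # Raw weights: 1 for the largest class, largest/n for every other class.
    raw = {name: largest / n for name, n in counts.items()}
    # Rescale so that all-ones would be the balanced default.
    scale = len(raw) / sum(raw.values())
    return {name: w * scale for name, w in raw.items()}

# Counts taken from the question; class names are placeholders.
counts = {"common": 16000, "rare": 2500}
weights = class_weights(counts)
```

The rescaling keeps the 6.4 ratio between the two classes while leaving the overall loss magnitude close to the unweighted case.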

Cory Nezin
  • Thank you for your answer! As far as I understand, that's for editing the network source file. Do you know where to edit for the Object Detection API? Or is there another way of doing this for the Object Detection API? – Tripton Jun 03 '18 at 16:59
  • Sorry, not familiar with that API. However, if you don't want to muck around in the internals, you could also build/modify your dataset so that it oversamples the less frequent class. – Cory Nezin Jun 04 '18 at 01:47
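
The oversampling suggested in the last comment could look roughly like the sketch below. It assumes each example carries a single class label, which is a simplification: in object detection one image can contain several classes, so this is only a starting point. The function name `oversample` and the toy data are hypothetical.

```python
import random
from collections import defaultdict

def oversample(examples, seed=0):
    """Repeat examples of smaller classes until every class matches the largest one."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for item, label in examples:
        by_label[label].append((item, label))
    target = max(len(group) for group in by_label.values())
    balanced = []
    for group in by_label.values():
        balanced.extend(group)
        # Draw extra copies with replacement to reach the target size.
        balanced.extend(rng.choices(group, k=target - len(group)))
    rng.shuffle(balanced)
    return balanced

# Toy dataset (hypothetical): four "common" examples and one "rare" example.
examples = [("img_a", "common")] * 4 + [("img_b", "rare")]
balanced = oversample(examples)
```

Because the extra copies are exact duplicates, combining this with data augmentation may help avoid overfitting to the repeated examples.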