
I am currently using the following loss function:

loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))  # pass labels/logits by keyword; the positional order is not (logits, labels)
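For reference, the TensorFlow documentation gives the per-element value of tf.nn.sigmoid_cross_entropy_with_logits for logits x and labels z as max(x, 0) - x * z + log(1 + exp(-|x|)), a rearrangement of -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)) that avoids overflow. The NumPy sketch below mirrors that formula; it is an illustration, not TensorFlow's actual code, and the function names and small example arrays are made up. np.mean over all elements corresponds to tf.reduce_mean with no axis argument.

```python
import numpy as np

def sigmoid_cross_entropy_with_logits(labels, logits):
    """Elementwise sigmoid cross-entropy in the numerically stable form
    max(x, 0) - x * z + log(1 + exp(-|x|)), with logits x and labels z."""
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(labels, dtype=np.float64)
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

def naive_sigmoid_cross_entropy(labels, logits):
    """Textbook form -z*log(p) - (1-z)*log(1-p) with p = sigmoid(x).
    Mathematically equal, but can produce inf or nan for large |x|."""
    p = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    z = np.asarray(labels, dtype=np.float64)
    return -(z * np.log(p) + (1 - z) * np.log(1 - p))

# Hypothetical batch: 2 examples, 3 classes.
logits = np.array([[2.0, -1.0, 0.5], [-3.0, 0.0, 4.0]])
labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])

# Mean over every (example, class) element, like tf.reduce_mean(...) with no axis.
loss = np.mean(sigmoid_cross_entropy_with_logits(labels, logits))
print(loss)
```

Because the mean runs over every class, each positive label contributes only a small fraction of the reported loss when there are ~1000 classes.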

However, my loss quickly approaches zero, since there are ~1000 classes and only a handful of positive labels (ones) for any example (see attached image), and the algorithm is simply learning to predict almost entirely zeros. I'm worried that this is preventing learning, even though the loss continues to creep slightly towards zero. Are there any alternative loss functions that I should consider?
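To put a number on why the loss gets small so quickly, here is a minimal sketch assuming a hypothetical example with 1000 classes and 5 positive labels. It compares the mean binary cross-entropy of a model that ignores its input and outputs the same probability for every class (either the positive rate 5/1000 or 0.5).

```python
import numpy as np

# Hypothetical setup: 1000 classes, 5 positive labels for this example.
num_classes = 1000
num_positive = 5
labels = np.zeros(num_classes)
labels[:num_positive] = 1.0

def mean_bce(labels, probs):
    """Mean binary cross-entropy over all classes (what reduce_mean reports)."""
    probs = np.clip(probs, 1e-12, 1 - 1e-12)  # avoid log(0)
    return -np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))

# Input-independent prediction: the positive rate for every class.
base_rate = num_positive / num_classes
constant_loss = mean_bce(labels, np.full(num_classes, base_rate))

# Input-independent prediction: 0.5 for every class, for comparison.
uniform_loss = mean_bce(labels, np.full(num_classes, 0.5))

print(f"constant {base_rate} -> mean loss {constant_loss:.4f}")
print(f"constant 0.5   -> mean loss {uniform_loss:.4f}")
```

The near-all-zeros prediction already scores about 0.03, against ln 2 (about 0.69) for 0.5 everywhere, without using any information from the input.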


reese0106
  • There is nothing wrong with the model first discovering the mean prediction (all zeros in your case) and then training properly. This is equivalent to first learning p(y) and then p(y|x), i.e. it starts with the marginal distribution and then conditions it on the input. – lejlot Sep 01 '17 at 21:42
  • Agreed. I think in most cases you would expect to see a logarithmic pattern in your loss over time, and this is no exception: it'll learn the low-hanging fruit first, and the lowest-hanging fruit is all zeros. If your model is built well, the loss should keep going down. – James Sep 01 '17 at 21:50
  • I recognize that there is nothing wrong with this, but am just wondering if there are any alternatives to consider in a scenario such as this. – reese0106 Sep 01 '17 at 21:55
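The point about learning p(y) first can be checked directly. For a single class with positive fraction p, the constant prediction c that minimises mean binary cross-entropy satisfies -p/c + (1 - p)/(1 - c) = 0, i.e. c = p. The sketch below confirms this numerically on a hypothetical random dataset (2000 examples, 4 classes with made-up positive rates) by grid-searching the best constant per class.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dataset: 2000 examples, 4 classes with different positive rates.
true_rates = np.array([0.01, 0.05, 0.2, 0.5])
labels = (rng.random((2000, 4)) < true_rates).astype(float)

def mean_bce(labels, probs):
    """Mean binary cross-entropy; probs may be a scalar constant."""
    probs = np.clip(probs, 1e-12, 1 - 1e-12)
    return -np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))

# Grid of candidate constant probabilities with step 0.001.
candidates = np.linspace(0.001, 0.999, 999)

# For each class, the constant that minimises the loss over the whole dataset.
best = np.array([
    candidates[np.argmin([mean_bce(labels[:, k], c) for c in candidates])]
    for k in range(labels.shape[1])
])

# Empirical marginal p(y_k = 1) for each class.
empirical = labels.mean(axis=0)
print(best)
print(empirical)
```

The best constants land on the empirical positive rates (up to the grid step), so an input-independent model converges to the marginal p(y), which for rare labels is close to all zeros.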

1 Answer


Have you tried projecting the multi-label target vector into multiple one-hot vectors?

Bear with me for a moment. For brevity I will build the loss function in numpy.

Apply a sigmoid to your model's outputs and call the result y; these are the probabilities for each class. For simplicity, here I will sample them from the uniform distribution on [0, 1).

import numpy as np

y = np.random.uniform(0, 1, size=[5])  # inferred per-class probabilities (stand-in for sigmoid outputs)
y_true = np.array([1, 0, 0, 1, 0])  # multi-label target vector
projection = y_true * np.identity(5)  # expand each label into its own one-hot row (all-zero rows for negative labels)
cross_entropy = -projection * np.log(y)  # cross-entropy for each label
loss = np.sum(cross_entropy)  # sum the cross-entropies over the labels
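Written out, the projection y_true * np.identity(5) broadcasts y_true across columns, so after multiplying by -log(y) and summing, only the diagonal entries for positive labels survive. The sketch below (with a fixed random seed, an assumption for reproducibility) checks that the projected loss equals the masked sum -sum(y_true * log(y)), and computes ordinary binary cross-entropy alongside it, which additionally includes a -log(1 - y) term for each negative label.

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.uniform(0, 1, size=5)          # stand-in for sigmoid outputs
y_true = np.array([1, 0, 0, 1, 0])     # multi-label target vector

# Loss from the answer: one one-hot row per positive label, then sum.
projection = y_true * np.identity(5)
projected_loss = np.sum(-projection * np.log(y))

# Same value without building the 5x5 matrix: only positive labels contribute.
masked_loss = -np.sum(y_true * np.log(y))

# Ordinary binary cross-entropy also has a term for every negative label.
bce_loss = -np.sum(y_true * np.log(y) + (1 - y_true) * np.log(1 - y))

print(projected_loss, masked_loss, bce_loss)
```

The masked form gives the same number in O(n) memory instead of building an n x n matrix, which matters with ~1000 classes.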

I believe that now the biggest weight in calculating the gradients will fall on the nonzero elements (the positive labels), and the gradients will point in a direction that pleases all the labels.
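One way to see where the gradient actually lands: if y = sigmoid(x) for logits x and the loss is -sum(t * log(y)) with targets t, then the derivative with respect to x_j is -t_j * (1 - sigmoid(x_j)). The sketch below (function names and random logits are illustrative assumptions) compares that analytic gradient with central finite differences.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def projected_loss(logits, y_true):
    # Same value as the projection loss above, written as a masked sum.
    return -np.sum(y_true * np.log(sigmoid(logits)))

def analytic_grad(logits, y_true):
    # d/dx_j of -t_j * log(sigmoid(x_j)) is -t_j * (1 - sigmoid(x_j)).
    return -y_true * (1.0 - sigmoid(logits))

def numeric_grad(f, x, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    grad = np.zeros_like(x)
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        grad[j] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
logits = rng.normal(size=5)
y_true = np.array([1.0, 0.0, 0.0, 1.0, 0.0])

g_analytic = analytic_grad(logits, y_true)
g_numeric = numeric_grad(lambda x: projected_loss(x, y_true), logits)
print(g_analytic)
print(g_numeric)
```

All of the gradient does fall on the positive labels: the negative labels receive exactly zero gradient, so nothing in this loss by itself pushes their probabilities down. Standard sigmoid cross-entropy differs precisely by the (1 - t) * log(1 - sigmoid(x)) term that supplies that push.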

Am I missing something?

prometeu