
I want to scale the model output and renormalize it to deal with a class imbalance issue. For example, if I have a 10-class output with logits y_logits, their softmax y_pred, and a vector of class priors prior, the new output should be:

y_pred /= prior
y_pred /= sum(y_pred)

The problem is that TensorFlow's softmax_cross_entropy_with_logits function takes the logits y_logits, while I need to apply this scaling to y_pred instead. Any idea how to do that without implementing the cross-entropy loss myself?
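
For concreteness, here is a minimal NumPy sketch of the rescaling I have in mind (the helper name and shapes are just for illustration; y_logits is (batch, classes), prior is (classes,)):

import numpy as np

def rescale_with_prior(y_logits, prior):
    # standard softmax over the class axis
    y_pred = np.exp(y_logits - y_logits.max(axis=1, keepdims=True))
    y_pred /= y_pred.sum(axis=1, keepdims=True)
    # divide by the class prior, then renormalize each row to sum to 1
    y_pred /= prior
    y_pred /= y_pred.sum(axis=1, keepdims=True)
    return y_pred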

  • Does this help https://stackoverflow.com/questions/40198364/how-can-i-implement-a-weighted-cross-entropy-loss-in-tensorflow-using-sparse-sof ? – Maxim Jan 11 '18 at 15:39
  • I see that someone suggested doing it with `sparse_softmax_cross_entropy`; is there a way to do a similar thing with `softmax_cross_entropy_with_logits` instead? – Ehab AlBadawy Jan 11 '18 at 18:44
  • The same trick works for both – Maxim Jan 11 '18 at 18:54
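
Edit (following the comments): I don't have the linked answer's exact code, but the reweighting trick it describes presumably translates to softmax_cross_entropy_with_logits along these lines (one-hot labels; class_weights is a hypothetical per-class weight vector, e.g. inverse priors):

import tensorflow as tf

def weighted_softmax_CE(logits, onehot_labels, class_weights):
    # unweighted per-example cross-entropy, shape (batch,)
    per_example = tf.nn.softmax_cross_entropy_with_logits(labels=onehot_labels, logits=logits)
    # each example gets the weight of its true class
    example_weights = tf.reduce_sum(class_weights * onehot_labels, axis=1)
    return tf.reduce_mean(per_example * example_weights)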

1 Answer


For those who are facing the same problem, I found a good solution by reimplementing the cross-entropy (CE) in a numerically stable way. If you want to know why you shouldn't implement CE directly from its definition, -∑ p_i log(q_i), check out this tutorial.

The implementation I used to apply the priors works as follows:

import numpy as np
import tensorflow as tf

def modified_CE(logits=None, labels=None, priors=None):
    # subtract the per-row maximum to prevent overflow/inf in the exponentials
    # (change the hard-coded shape (7500, 1) to match your own batch size)
    scaled_logits = logits - tf.reshape(tf.reduce_max(logits, 1), shape=(7500, 1))
    # renormalize the logits as the final step of the log-softmax
    normalized_logits = scaled_logits - tf.reshape(tf.reduce_logsumexp(scaled_logits, 1), shape=(7500, 1))

    # apply the priors in log space
    normalized_logits -= tf.log(np.array(priors, dtype=np.float32))
    # renormalize again so each row is a valid log-probability distribution
    normalized_logits -= tf.reshape(tf.reduce_logsumexp(normalized_logits, 1), shape=(7500, 1))

    # average cross-entropy over the batch (labels here carry an extra leading dimension)
    return tf.reduce_mean(-tf.reduce_sum(labels[0, :, :] * normalized_logits, 1))
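
Since the log-softmax is invariant to adding a per-row constant, you should be able to get the same loss more compactly by shifting the logits by -log(priors) and calling the built-in op. A minimal sketch (the shapes and the uniform priors are illustrative; labels here are plain one-hot targets of shape (batch, classes)):

import numpy as np
import tensorflow as tf

num_classes = 10
priors = np.full(num_classes, 1.0 / num_classes, dtype=np.float32)  # illustrative uniform prior

logits = tf.placeholder(tf.float32, shape=(None, num_classes))
labels = tf.placeholder(tf.float32, shape=(None, num_classes))  # one-hot targets

# shift the logits by -log(prior); the built-in op then does the
# numerically stable log-softmax and cross-entropy for us
shifted_logits = logits - tf.log(priors)
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=shifted_logits))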