
I'm currently using TensorFlow to address a multi-class segmentation problem with a SegNet architecture.

My classes are heavily imbalanced, so I need to integrate median frequency balancing (weighting the classes in the loss calculation). I use the following approach (based on this post) to apply the softmax. I need help extending it to add the weights; I'm not sure how to do it. Current implementation:

reshaped_logits = tf.reshape(logits, [-1, nClass])  # shape [num_pixels, nClass]
reshaped_labels = tf.reshape(labels, [-1])          # shape [num_pixels], integer class ids
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=reshaped_logits,
                                                      labels=reshaped_labels)

My idea would be:

  1. Split the logits tensor into nClass tensors,
  2. Apply the softmax on each independently,
  3. Weight them with median frequency balancing (the weight computation itself is sketched below),
  4. Finally, sum the weighted losses.

Would that be the right approach?
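
For reference, I plan to compute the per-class weights of step 3 offline with NumPy, roughly like this (untested sketch; pixel_count and image_pixel_count are hypothetical arrays I would fill from my training set):

import numpy as np

def median_frequency_weights(pixel_count, image_pixel_count):
    # pixel_count[c]: number of pixels of class c over the whole training set
    # image_pixel_count[c]: total number of pixels of the images in which class c appears
    freq = pixel_count.astype(np.float64) / image_pixel_count
    # median frequency balancing: weight(c) = median(freq) / freq(c)
    return np.median(freq) / freq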

Thanks

1 Answer

You can find the code to do that here

def _compute_cross_entropy_mean(class_weights, labels, softmax):
    # labels are expected to be one-hot encoded; softmax has shape [num_pixels, nClass]
    # per-pixel weighted cross entropy: sum over classes of -weight_c * label_c * log(softmax_c)
    cross_entropy = -tf.reduce_sum(tf.multiply(labels * tf.log(softmax), class_weights),
                                   reduction_indices=[1])

    cross_entropy_mean = tf.reduce_mean(cross_entropy,
                                        name='xentropy_mean')
    return cross_entropy_mean

where class_weights (called head in the linked code) is your class weighting vector.
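
For example, wiring it up with the reshaped tensors from the question could look roughly like this (a sketch, not taken from the linked code; it assumes nClass, logits and labels as in the question, and a weights array computed beforehand with median frequency balancing):

softmax = tf.nn.softmax(tf.reshape(logits, [-1, nClass]))           # per-pixel class probabilities
onehot_labels = tf.one_hot(tf.reshape(labels, [-1]), depth=nClass)  # dense labels for the weighted sum
class_weights = tf.constant(weights, dtype=tf.float32)              # shape [nClass]

loss = _compute_cross_entropy_mean(class_weights, onehot_labels, softmax)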
