I am trying to implement a multi-class Dice loss function in TensorFlow. Since it is multi-class Dice, I need to convert the per-class probabilities into one-hot form. For example, if my network outputs these probabilities (assuming 4 classes):

[0.2, 0.6, 0.1, 0.1]

I need to convert this into:

[0, 1, 0, 0]

This can be done with tf.argmax followed by tf.one_hot.
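For instance, a minimal sketch of that conversion on a single pixel (the tensor names here are illustrative, not from my actual code):

import tensorflow as tf

probs = tf.constant([0.2, 0.6, 0.1, 0.1])  # per-class probabilities for one pixel
idx = tf.argmax(probs)                     # index of the largest probability -> 1
one_hot = tf.one_hot(idx, depth=4)         # -> [0., 1., 0., 0.]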

import tensorflow as tf

def generalized_dice_loss(labels, logits):
    # labels shape: [batch_size, 128, 128, 64, 1], dtype=float32
    # logits shape: [batch_size, 128, 128, 64, 7], dtype=float32
    labels = tf.cast(labels, tf.int32)
    smooth = tf.constant(1e-17)
    shape = tf.TensorShape(logits.shape).as_list()
    depth = int(shape[-1])
    # One-hot encode the integer labels:
    # [B, 128, 128, 64, 1] -> [B, 128, 128, 64, 7, 1] -> squeeze -> [B, 128, 128, 64, 7]
    labels = tf.one_hot(labels, depth, dtype=tf.int32, axis=4)
    labels = tf.squeeze(labels, axis=5)
    # Harden the predictions: argmax over the class axis, then one-hot
    logits = tf.argmax(logits, axis=4)
    logits = tf.one_hot(logits, depth, dtype=tf.int32, axis=4)
    numerator = tf.reduce_sum(labels * logits, axis=[1, 2, 3])
    denominator = tf.reduce_sum(labels + logits, axis=[1, 2, 3])
    numerator = tf.cast(numerator, tf.float32)
    denominator = tf.cast(denominator, tf.float32)
    loss = tf.reduce_mean(1.0 - 2.0 * (numerator + smooth) / (denominator + smooth))
    return loss

The problem is that tf.argmax is not differentiable, so it throws an error:

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

How can I solve this problem? Can I do the same thing without using tf.argmax?

tejal567

1 Answer

Take a look at How is the smooth dice loss differentiable?. You don't need to do the conversion (from [0.2, 0.6, 0.1, 0.1] to [0, 1, 0, 0]); just leave the outputs as continuous values between 0 and 1.

If I understand correctly, the loss function is just a surrogate for your real objective. Even though the soft version is not identical to the hard Dice score, it is fine as long as it is a good proxy; the hard version, by contrast, is not differentiable.

At evaluation time, feel free to use tf.argmax to compute the real metric, since no gradients are needed there.
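For example, here is a minimal sketch of the soft Dice loss under those assumptions: the argmax/one-hot step is dropped and tf.nn.softmax is assumed to turn the logits into probabilities (the function name soft_dice_loss is mine, not from the question):

import tensorflow as tf

def soft_dice_loss(labels, logits, smooth=1e-17):
    # labels: [batch, 128, 128, 64, 1] integer class ids
    # logits: [batch, 128, 128, 64, num_classes] raw network outputs
    num_classes = int(logits.shape[-1])
    labels = tf.one_hot(tf.cast(tf.squeeze(labels, axis=-1), tf.int32),
                        num_classes, dtype=tf.float32)
    # Keep the predictions continuous in (0, 1); no argmax, so gradients flow.
    probs = tf.nn.softmax(logits, axis=-1)
    numerator = tf.reduce_sum(labels * probs, axis=[1, 2, 3])
    denominator = tf.reduce_sum(labels + probs, axis=[1, 2, 3])
    return tf.reduce_mean(1.0 - 2.0 * (numerator + smooth) / (denominator + smooth))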

greeness