
I have built a Keras model for image segmentation (U-Net). However, in my samples some misclassifications (areas) are not that important while others are crucial, so I want to assign them a higher weight in the loss function. To complicate things further, I would like some misclassifications (class 1 instead of 2) to have a very high penalty, while the inverse (class 2 instead of 1) shouldn't be penalized as much.
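One way to encode such an asymmetric penalty is a square cost matrix indexed [true_class, predicted_class], which matches the weights[c_t, c_p] indexing used in the function below. A minimal NumPy sketch, assuming three hypothetical classes (0 = background, plus classes 1 and 2 from the question) and purely illustrative penalty values:

```python
import numpy as np

n_classes = 3
# Cost matrix indexed [true_class, predicted_class]; 1.0 means "no extra penalty".
weights = np.ones((n_classes, n_classes))
weights[2, 1] = 10.0  # predicted class 1 where the truth is class 2: heavy penalty
weights[1, 2] = 2.0   # predicted class 2 where the truth is class 1: mild penalty

# Looking up the weight for a few (true, predicted) pairs:
true_cls = np.array([2, 1, 0])
pred_cls = np.array([1, 2, 0])
per_sample_weight = weights[true_cls, pred_cls]
print(per_sample_weight)  # [10.  2.  1.]
```

The diagonal entries scale the loss on correctly classified pixels; leaving them at 1.0 keeps the ordinary crossentropy there.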

The way I see it, I need to use a sum (across all of the pixels) of weighted categorical crossentropy, but the best I could find is this:

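from itertools import product

from keras import backend as K  # Keras 2.x backend API


def w_categorical_crossentropy(y_true, y_pred, weights):
    # weights: NumPy array of shape (nb_cl, nb_cl), indexed [true_class, predicted_class]
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    # One-hot mask of the predicted (max-probability) class for each sample
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.cast(K.equal(y_pred, y_pred_max), K.floatx())
    # Per-sample weight = weights[true class, predicted class]
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    # Keras 2+ takes (target, output); Keras 1 used (output, target)
    return K.categorical_crossentropy(y_true, y_pred) * final_mask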

However, this code only works on 2-D inputs of shape (batch, classes), i.e. a single prediction per sample rather than one per pixel, and my knowledge of Keras internals is lacking (the math side is not much better). Does anyone know how I can adapt it, or, even better, is there a ready-made loss function that would suit my case?

I would appreciate some pointers.

EDIT: My question is similar to "How to do point-wise categorical crossentropy loss in Keras?", except that I would like to use weighted categorical crossentropy.

johndodo

1 Answer


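You can use weight maps (as proposed in the U-Net paper). A weight map assigns a weight to every pixel, so some regions can count more in the loss and others less. Here is a sketch, inside a loss function where weight_map has shape (batch, height, width):

from keras import backend as K
loss = K.categorical_crossentropy(y_true, y_pred)  # per-pixel loss, shape (batch, height, width)
weighted_loss = loss * weight_map  # element-wise multiplication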
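The class-pair penalty from the question fits this weight-map idea: for every pixel, look up weights[true_class, predicted_class] and use the result as the weight map. Below is a minimal sketch for TensorFlow 2 / tf.keras, assuming channels-last, one-hot y_true and softmax y_pred of shape (batch, height, width, n_classes). The name make_pixelwise_weighted_cce is hypothetical, and taking the predicted class with argmax (instead of the equality-to-max mask in the question's code) is a choice of this sketch. The weight map acts only as a multiplier; gradients flow through the crossentropy term.

```python
import numpy as np
import tensorflow as tf


def make_pixelwise_weighted_cce(weights):
    """Build a loss that scales per-pixel crossentropy by weights[true, predicted].

    Assumes channels-last tensors of shape (batch, height, width, n_classes):
    one-hot y_true and softmax y_pred.
    """
    # Cost matrix indexed [true_class, predicted_class]
    w = tf.convert_to_tensor(np.asarray(weights, dtype=np.float32))

    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Unweighted crossentropy for every pixel: (batch, height, width)
        cce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        # True and predicted class index for every pixel
        true_idx = tf.argmax(y_true, axis=-1)
        pred_idx = tf.argmax(y_pred, axis=-1)
        # Weight map: weights[true_idx, pred_idx] gathered pixel by pixel
        weight_map = tf.gather_nd(w, tf.stack([true_idx, pred_idx], axis=-1))
        # Average the weighted loss over pixels -> one value per sample
        return tf.reduce_mean(cce * tf.cast(weight_map, cce.dtype), axis=[1, 2])

    return loss
```

Usage would then be along the lines of model.compile(optimizer="adam", loss=make_pixelwise_weighted_cce(weights)), with weights being a cost matrix as described in the question. Replacing tf.reduce_mean with tf.reduce_sum gives the plain sum across pixels that the question mentions.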
Simdi
I am also doing something similar to the original poster, but with binary crossentropy. I assume the same pseudocode applies? – n88b Apr 20 '19 at 03:43
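For the binary case raised in the comment, the same idea carries over: compute per-pixel binary crossentropy and multiply it by the weight map. Because a Keras loss only receives (y_true, y_pred), one common workaround, assumed here, is to pack the weight map into y_true as an extra channel. A minimal tf.keras sketch; the name bce_with_weight_channel is hypothetical, and the layout y_true = (batch, height, width, 2) with the label in channel 0 and the weight in channel 1, plus a sigmoid y_pred of shape (batch, height, width, 1), is an assumption:

```python
import numpy as np
import tensorflow as tf


def bce_with_weight_channel(y_true, y_pred):
    """Per-pixel binary crossentropy scaled by a weight map packed into y_true.

    Assumed layout: y_true[..., 0] is the 0/1 label, y_true[..., 1] the pixel weight.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    labels = y_true[..., :1]     # keep the channel axis: (batch, H, W, 1)
    weight_map = y_true[..., 1]  # (batch, H, W)
    # binary_crossentropy averages over the last axis (size 1 here) -> (batch, H, W)
    bce = tf.keras.losses.binary_crossentropy(labels, y_pred)
    # Weighted mean over pixels -> one loss value per sample
    return tf.reduce_mean(bce * weight_map, axis=[1, 2])


# Example: one 1x2 image, weight 5 on the first (foreground) pixel
y_true = np.array([[[[1.0, 5.0], [0.0, 1.0]]]], dtype=np.float32)
y_pred = np.array([[[[0.8], [0.3]]]], dtype=np.float32)
print(bce_with_weight_channel(tf.constant(y_true), tf.constant(y_pred)).numpy())
```

Under this assumed layout, training targets can be built with np.stack([mask, weights], axis=-1) from a (N, H, W) mask and a (N, H, W) weight map, and the model compiled with loss=bce_with_weight_channel.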