
Given batched RGB images as input, shape=(batch_size, width, height, 3)

And a multiclass target represented as one-hot, shape=(batch_size, width, height, n_classes)

And a model (U-Net, DeepLab) with a softmax activation in the last layer.

I'm looking for a weighted categorical cross-entropy loss function in Keras/TensorFlow.

The class_weight argument in fit_generator doesn't seem to work, and I didn't find the answer here or in https://github.com/keras-team/keras/issues/2115.

def weighted_categorical_crossentropy(weights):
    # weights = [0.9, 0.05, 0.04, 0.01]
    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        loss = ?...
        return loss

    return wcce
Mendi Barel
  • By multiclass target, do you mean more than one possible outcome is considered? – SajanGohil Dec 29 '19 at 17:47
  • What do you mean by "outcome"? Multiclass = different pixel values indicate different classes, and you can have more than 2 classes (2 classes = binary classification). – Mendi Barel Dec 29 '19 at 19:38
  • I got confused with multilabel classification, a different kind of problem where more than one class can be true. – SajanGohil Dec 30 '19 at 15:34

3 Answers


I will answer my own question:

from keras import backend as K

def weighted_categorical_crossentropy(weights):
    # weights = [0.9, 0.05, 0.04, 0.01]
    def wcce(y_true, y_pred):
        Kweights = K.constant(weights)
        if not K.is_tensor(y_pred):
            y_pred = K.constant(y_pred)
        y_true = K.cast(y_true, y_pred.dtype)
        # standard per-pixel cross-entropy, scaled by the weight of
        # each pixel's true class (picked out via the one-hot y_true)
        return K.categorical_crossentropy(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)
    return wcce

Usage:

loss = weighted_categorical_crossentropy(weights)
optimizer = keras.optimizers.Adam(lr=0.01)
model.compile(optimizer=optimizer, loss=loss)
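
To sanity-check the loss outside of a model, you can evaluate it on a tiny NumPy batch. This is just a sketch: the values are made up, and a flat (batch_size, n_classes) shape is used instead of full images, since the loss broadcasts over any leading dimensions the same way:

import numpy as np
from keras import backend as K

loss_fn = weighted_categorical_crossentropy([0.9, 0.05, 0.04, 0.01])

# two samples: one of class 0, one of class 3
y_true = np.array([[1., 0., 0., 0.],
                   [0., 0., 0., 1.]])
y_pred = np.array([[0.7, 0.1, 0.1, 0.1],
                   [0.1, 0.1, 0.1, 0.7]])

# both samples have the same raw cross-entropy (-log 0.7),
# but the first is scaled by 0.9 and the second by 0.01
print(K.eval(loss_fn(y_true, y_pred)))  # ~[0.321, 0.0036]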
Mendi Barel

I'm using the Generalized Dice Loss. It works better than weighted categorical cross-entropy in my case. My implementation is in PyTorch; however, it should be fairly easy to translate.

import torch
import torch.nn as nn

class GeneralizedDiceLoss(nn.Module):
    def __init__(self):
        super(GeneralizedDiceLoss, self).__init__()

    def forward(self, inp, targ):
        # move channels last: (batch, H, W, n_classes)
        inp = inp.contiguous().permute(0, 2, 3, 1)
        targ = targ.contiguous().permute(0, 2, 3, 1)

        # per-class weight: inverse squared class volume, so rare
        # classes contribute as much as frequent ones
        w = 1. / (torch.sum(targ, (0, 1, 2))**2 + 1e-9)

        numerator = targ * inp
        numerator = w * torch.sum(numerator, (0, 1, 2))
        numerator = torch.sum(numerator)

        denominator = targ + inp
        denominator = w * torch.sum(denominator, (0, 1, 2))
        denominator = torch.sum(denominator)

        dice = 2. * (numerator + 1e-9) / (denominator + 1e-9)

        return 1. - dice
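
A quick usage sketch (the shapes and values here are made up; it assumes predictions and targets arrive channels-first, (batch, n_classes, H, W), which forward() then permutes to channels-last):

import torch
import torch.nn.functional as F

criterion = GeneralizedDiceLoss()

# dummy softmax probabilities and one-hot targets: 4 classes, 8x8 images
inp = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)
labels = torch.randint(0, 4, (2, 8, 8))
targ = F.one_hot(labels, num_classes=4).permute(0, 3, 1, 2).float()

loss = criterion(inp, targ)  # scalar; lower is better
print(loss.item())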
Jonas Stepanik

This issue might be similar to: Unbalanced data and weighted cross entropy, which has an accepted answer.

Balraj Ashwath