I have written a custom training loop using tf.GradientTape(). My data has 2 classes, and they are imbalanced: class1 contributes almost 80% of the samples and class2 the remaining 20%. To account for this imbalance, I am trying to write a custom loss function that applies the corresponding class weights when computing the loss, i.e. I want to use class_weights = [0.2, 0.8]. I am not able to find similar examples.

However, all the examples I am finding use the model.fit approach, where it's easy to pass class_weights. I cannot find an example that uses class_weights with a custom training loop built on tf.GradientTape.

I did go through the suggestions of using sample_weight, but I don't have data wherein I can specify per-sample weights, so my preference is to use class weights.

I am using BinaryCrossentropy as my loss function, but I want to scale the loss based on the class_weights. That's where I am stuck: how do I tell BinaryCrossentropy to consider the class_weights?

Is my approach of using a custom loss function correct, or is there a better way to make use of class_weights when training with a custom training loop (not using model.fit)?

np2314

2 Answers

You can write your own loss function: inside it, call BinaryCrossentropy, multiply the result by the weight you want, and return that.
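
A minimal sketch of that idea (my own example, not from the original answer): constructing BinaryCrossentropy with reduction=tf.keras.losses.Reduction.NONE makes it return one loss value per sample instead of a single scalar, so each sample's loss can be weighted by its class before averaging.

    import tensorflow as tf

    # Assumed class weights: index 0 -> class1 (80% of data), index 1 -> class2 (20%)
    class_weights = tf.constant([0.2, 0.8])

    # No reduction, so we get a per-sample loss vector rather than one scalar
    bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

    def weighted_bce(y_true, y_pred):
        # y_true: shape (batch, 1) with values 0.0 or 1.0; y_pred: shape (batch, 1)
        per_sample = bce(y_true, y_pred)  # shape (batch,)
        # Look up each sample's weight according to its label
        weights = tf.gather(class_weights, tf.cast(tf.squeeze(y_true, axis=-1), tf.int32))
        return tf.reduce_mean(weights * per_sample)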

amin
  • amin, thanks for your suggestion. I am struggling to do exactly that because of the following observation: when I get the loss from the BinaryCrossentropy loss function, it's returned as a single value, so I am not able to compute the weighted average. I was trying to get the loss for each class so that I could apply the weight for each class, but it looks like the loss function doesn't work that way. – Pravin Girase Dec 29 '20 at 10:56
  • I looked at the discussion thread at https://github.com/keras-team/keras/issues/2115, but there are too many discussion points, so I am somewhat lost. I tried to use one of the weighted loss functions but got errors related to reshaping. Maybe that thread is discussing categorical cross-entropy, whereas I am trying to use BinaryCrossentropy. – Pravin Girase Dec 29 '20 at 11:09

Here's an implementation that should work for n classes instead of just 2.

For your example of an 80:20 split, calculate the weights as below (assuming 100 samples in total).

  1. Weight calculation (ref: Handling Class Imbalance: TensorFlow):

    import tensorflow as tf

    total_samples, num_classes = 100, 2
    count_for_class_0, count_for_class_1 = 80, 20  # 80% / 20% split
    weight_class_0 = (1 / count_for_class_0) * (total_samples / num_classes)  # 0.625
    weight_class_1 = (1 / count_for_class_1) * (total_samples / num_classes)  # 2.5
    class_wts = tf.constant([weight_class_0, weight_class_1])
    
  2. Loss function: requires labels to be sparse (integer class indices) and logits unscaled (no activation applied).

    # Example: logits=[[-3.2, 2.0], [1.2, 0.5], ...], sparse labels=[0, 1, ...]
    def weighted_sparse_categorical_crossentropy(labels, logits, weights):
        # Per-sample cross-entropy computed from the unscaled logits, shape (batch,)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, logits)
        # Look up each sample's class weight from its integer label
        class_weights = tf.gather(weights, labels)
        # Weighted mean over the batch
        return tf.reduce_mean(class_weights * loss)
    

You can supply this loss function to custom training loops.
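
For instance, a minimal tf.GradientTape training step using it could look like the sketch below (model, optimizer, x_batch, and y_batch are hypothetical placeholders, not part of the original answer).

    optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def train_step(model, x_batch, y_batch):
        with tf.GradientTape() as tape:
            # Model must output unscaled logits of shape (batch, num_classes)
            logits = model(x_batch, training=True)
            loss = weighted_sparse_categorical_crossentropy(y_batch, logits, class_wts)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss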

Karan Shah