
I'm training a model adapted from Matterport's implementation of Mask R-CNN on unbalanced data. As suggested previously, I have modified the classification loss function to apply class weights using a one-hot representation of the true class, but what I would really like is the more general case: a matrix of weights that assigns a different penalty to each predicted class given the actual label class, as alluded to in the last paragraph of the previous answer.

def mrcnn_class_loss_graph(target_class_ids, pred_class_logits,
                           active_class_ids):
    """Loss for the classifier head of Mask RCNN.
    batch = 1
    num_classes = 8  
    num_rois = variable
    target_class_ids: [batch, num_rois]. Integer class IDs. Uses zero padding to fill in the array.
    pred_class_logits: [batch, num_rois, num_classes]
    active_class_ids: [batch, num_classes]. Has a value of 1 for classes that are in the dataset of 
        the image, and 0 for classes that are not in the dataset.
    """
    # During model building, Keras calls this function with target_class_ids of type float32. 
    # Unclear why. Cast it to int to get around it.
    target_class_ids = tf.cast(target_class_ids, 'int64')

    # Find predictions of classes that are not in the dataset.
    pred_class_ids = tf.argmax(input=pred_class_logits, axis=2)
    pred_active = tf.gather(active_class_ids[0], pred_class_ids)

    # penalty_matrix = tf.constant([
    #         [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    #         [1.0, 0.0, 1.0, 0.6, 1.0, 0.2, 1.5, 1.0],
    #         [1.0, 1.0, 0.0, 0.2, 1.0, 0.2, 0.8, 1.0],
    #         [1.0, 0.4, 0.2, 0.0, 1.0, 0.2, 1.0, 1.0],
    #         [1.0, 1.2, 1.2, 1.2, 0.0, 1.5, 0.8, 0.8],
    #         [1.0, 0.5, 0.5, 0.5, 1.0, 0.0, 1.0, 1.0],
    #         [1.0, 2.0, 1.0, 1.0, 1.0, 0.5, 0.0, 1.0],
    #         [1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.8, 0.0]
    #     ])

    class_weights = tf.constant([[1.0, 1.0, 0.5, 1.0, 2.0, 2.0, 2.0, 2.0]])

    one_hot = tf.one_hot(target_class_ids, depth = class_weights.shape[1], on_value=1.0, off_value=0.0)
    # deduce weights for batch samples based on their true label
    weights = tf.reduce_sum(class_weights * one_hot, axis=2)
    # compute your (unweighted) softmax cross entropy loss
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=one_hot, logits=pred_class_logits)

    # apply the weights, relying on broadcasting of the multiplication
    loss = loss * weights

    # Erase losses of predictions of classes that are not in the active classes of the image.
    loss = loss * pred_active

    # Compute loss mean. Use only predictions that contribute to the loss to get a correct mean.
    loss = tf.reduce_sum(input_tensor=loss) / tf.reduce_sum(input_tensor=pred_active)

    return loss

In the above code, I've commented out a candidate 8x8 penalty matrix in which the entry in row i, column j is the penalty for misidentifying an object of true class i as class j; the zeros on the diagonal correspond to correct classifications.

I was wondering how to correctly (and efficiently) get the weights from such a penalty matrix based on the true and predicted classes. I'm using TensorFlow 2.5 and Python 3.8.


1 Answer


An implementation in PyTorch:

import torch

batch_size = 2
num_classes = 3

# Row i = true class, column j = predicted class; the zero diagonal means
# a correct prediction adds no extra penalty.
penalty_matrix = torch.tensor([[0.0, 2.0, 3.0], [1.0, 0.0, 2.0], [4.0, 5.0, 0.0]])

# reduction='none' keeps one loss value per sample so each can be weighted individually.
criterion = torch.nn.CrossEntropyLoss(reduction='none')

logits = torch.randn(batch_size, num_classes, requires_grad=True)
target = torch.empty(batch_size, dtype=torch.long).random_(num_classes)

# Per-sample, unweighted cross-entropy.
output = criterion(logits, target)

# Look up one penalty per sample, indexed by (true class, predicted class).
predicted = logits.argmax(dim=-1)
loss = output * penalty_matrix[target, predicted]

# Reduce to a scalar before backpropagating.
loss.mean().backward()
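
For the TensorFlow 2.x case the question targets, here is a minimal sketch of the same idea, assuming the 8x8 penalty_matrix from the question (rows indexed by the true class, columns by the predicted class): stack the true and predicted class IDs per ROI and look up one penalty with tf.gather_nd. The input tensors below are toy values for illustration only; in the real loss they come from the function arguments.

import tensorflow as tf

num_classes = 8
penalty_matrix = tf.constant([
    [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    [1.0, 0.0, 1.0, 0.6, 1.0, 0.2, 1.5, 1.0],
    [1.0, 1.0, 0.0, 0.2, 1.0, 0.2, 0.8, 1.0],
    [1.0, 0.4, 0.2, 0.0, 1.0, 0.2, 1.0, 1.0],
    [1.0, 1.2, 1.2, 1.2, 0.0, 1.5, 0.8, 0.8],
    [1.0, 0.5, 0.5, 0.5, 1.0, 0.0, 1.0, 1.0],
    [1.0, 2.0, 1.0, 1.0, 1.0, 0.5, 0.0, 1.0],
    [1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.8, 0.0]])

target_class_ids = tf.constant([[1, 3, 0]], dtype=tf.int64)        # [batch, num_rois]
pred_class_logits = tf.random.normal([1, 3, num_classes])          # [batch, num_rois, num_classes]

pred_class_ids = tf.argmax(pred_class_logits, axis=2)              # [batch, num_rois], int64

# Pair each ROI's true class with its predicted class and gather one penalty per ROI.
indices = tf.stack([target_class_ids, pred_class_ids], axis=-1)    # [batch, num_rois, 2]
weights = tf.gather_nd(penalty_matrix, indices)                    # [batch, num_rois]

one_hot = tf.one_hot(target_class_ids, depth=num_classes)
loss = tf.nn.softmax_cross_entropy_with_logits(labels=one_hot, logits=pred_class_logits)
loss = loss * weights                                              # per-ROI weighted loss

Inside mrcnn_class_loss_graph, the tf.stack / tf.gather_nd pair would replace the class_weights * one_hot reduction; the pred_active masking and the final mean can stay unchanged.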