
I am trying to build a weighted MAE loss function because my data is highly imbalanced. losses.MeanAbsoluteError accepts a sample_weight argument, but only in its __call__() method. So I am trying to wrap it in my own function:

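import tensorflow as tf
from tensorflow.keras import losses

def weighted_mean_absolute_error(class_weights):
    def loss(y_true, y_pred):
        # Map each label to its weight; class_weights is a Python dict,
        # but t here is a tensor, not a Python int
        weights = tf.map_fn(fn=lambda t: class_weights.get(t, 1.0), elems=y_true)
        mae = losses.MeanAbsoluteError()
        # The keyword argument is sample_weight (singular)
        v = mae(y_true, y_pred, sample_weight=weights)
        return v
    return loss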

class_weights is a dict mapping the integer labels produced by a tokenizer to the reciprocal relative occurrence of each label.

But this function fails at training time because y_true is a symbolic tensor. I expected actual values to be available at the latest once training runs.

Is it possible to wrap the function in this way and I am just missing something? Or do I need to implement a weighted mean absolute error myself?

mlinke-ai
  • Does this https://stackoverflow.com/questions/56401346/mean-absolute-error-in-tensorflow-without-built-in-functions/56401550 solve your problem? – abdou_dev Sep 29 '21 at 09:52
  • I will try to implement the given solution with the weights and come back later with the results. Thanks – mlinke-ai Sep 29 '21 at 10:14
  • As I suspected: `y_true` and `y_pred` are symbolic tensors. Therefore the difference between them is also a symbolic tensor, and a symbolic tensor cannot be passed to a NumPy function. So replace `np.abs()` with `tf.math.abs()` and `np.average()` with `tf.math.reduce_mean()`. I assume the class weights also need to be a tensor and have to be multiplied in after applying `tf.math.abs()`. Hopefully this solves my issue. – mlinke-ai Sep 29 '21 at 10:35
  • I hope so. Does the link I posted solve your issue now, or do you still have a problem? – abdou_dev Sep 29 '21 at 10:53
  • I did it. You are the best. Thanks – mlinke-ai Sep 29 '21 at 11:02
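A minimal sketch of the TensorFlow-only translation of the NumPy formula discussed in these comments, assuming float32 inputs and one weight per element; `tf_weighted_average_abs` and the toy arrays are hypothetical. It keeps np.average's weighted-average semantics, sum(w * |e|) / sum(w), and can be compared against NumPy in eager mode:

```python
import numpy as np
import tensorflow as tf

def tf_weighted_average_abs(y_true, y_pred, weights):
    # Same semantics as np.average(np.abs(y_true - y_pred), weights=weights, axis=0),
    # but built only from TensorFlow ops, so it also works on symbolic tensors.
    abs_err = tf.math.abs(y_true - y_pred)
    return tf.math.reduce_sum(weights * abs_err, axis=0) / tf.math.reduce_sum(weights, axis=0)

# Hypothetical toy data
y_true = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y_pred = np.array([1.5, 2.0, 1.0], dtype=np.float32)
w = np.array([2.0, 1.0, 0.5], dtype=np.float32)

expected = np.average(np.abs(y_true - y_pred), weights=w, axis=0)
result = tf_weighted_average_abs(tf.constant(y_true), tf.constant(y_pred), tf.constant(w))
print(float(result), float(expected))
```

Both values are (2.0 * 0.5 + 0.5 * 2.0) / 3.5 for this data.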


The comment from @abdou_dev led me onto the right track. He posted a link to another question, where the following self-made mean absolute error loss function is given:

import numpy as np
MAE = np.average(np.abs(y_true - y_pred), weights=sample_weight, axis=0)

However, this DOES NOT work inside a loss function.

y_true and y_pred are symbolic tensors and therefore cannot be passed to a NumPy function. Their difference is also a symbolic tensor, so computing it first does not help either.

SOLUTION

You need to replace np.abs() with tf.math.abs() and np.average() with tf.math.reduce_mean(). These are TensorFlow functions and also work with symbolic tensors. Unfortunately, tf.math.reduce_mean() has no weights argument, so after taking the absolute value you have to multiply it by the weights yourself (reduce_mean() then divides by the number of elements, not by the sum of the weights as np.average() does). For this multiplication the weights must be a tensor as well. With a tf.lookup.KeyValueTensorInitializer() and a tf.lookup.StaticHashTable() you can look up the weight of every label in y_true, which gives a weight tensor of the same shape. I also needed to cast y_true to tf.float32, because my labels are integers while the model outputs are floats.
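The table lookup on its own can be tried eagerly. This is a small sketch with a hypothetical `class_weights` dict; the keys are stored as int64, so the labels passed to `lookup()` must be int64 too (the table rejects keys of a different dtype). Labels missing from the dict fall back to `default_value`:

```python
import tensorflow as tf

# Hypothetical weights: label -> reciprocal relative occurrence
class_weights = {0: 1.25, 1: 5.0, 2: 2.5}

init = tf.lookup.KeyValueTensorInitializer(
    keys=tf.constant(list(class_weights.keys()), dtype=tf.int64),
    values=tf.constant(list(class_weights.values()), dtype=tf.float32),
)
# Labels that are not in the dict get a weight of 1.0
table = tf.lookup.StaticHashTable(init, default_value=1.0)

labels = tf.constant([[0], [2], [7]], dtype=tf.int64)
weights = table.lookup(labels)  # same shape as labels, dtype float32
print(weights.numpy())
```

Label 7 is not in the dict, so its weight is the default 1.0.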

The full solution is as follows:

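import tensorflow as tf

def weighted_mean_absolute_error(class_weights):
    # Build the label -> weight table once instead of on every batch
    init = tf.lookup.KeyValueTensorInitializer(
        tf.constant(list(class_weights.keys()), dtype=tf.int64),
        tf.constant(list(class_weights.values()), dtype=tf.float32),
    )
    # Labels that are not in class_weights get a weight of 1.0
    table = tf.lookup.StaticHashTable(init, default_value=1.0)

    def loss(y_true, y_pred):
        # Lookup keys must have the table's key dtype (int64)
        weights = table.lookup(tf.cast(y_true, tf.int64))
        abs_err = tf.math.abs(tf.cast(y_true, tf.float32) - y_pred)
        return tf.math.reduce_mean(tf.math.multiply(abs_err, weights))
    return loss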
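If the loss should instead be normalized the way np.average() does it, i.e. divided by the total weight in the batch rather than by the number of elements, a variant could look like the sketch below. `weighted_mean_absolute_error_normalized` and the toy tensors are hypothetical, and the sketch assumes integer labels and strictly positive weights (otherwise the denominator could be zero):

```python
import tensorflow as tf

def weighted_mean_absolute_error_normalized(class_weights):
    """Hypothetical variant: divides by the sum of the weights, like np.average."""
    init = tf.lookup.KeyValueTensorInitializer(
        tf.constant(list(class_weights.keys()), dtype=tf.int64),
        tf.constant(list(class_weights.values()), dtype=tf.float32),
    )
    table = tf.lookup.StaticHashTable(init, default_value=1.0)

    def loss(y_true, y_pred):
        weights = table.lookup(tf.cast(y_true, tf.int64))
        abs_err = tf.math.abs(tf.cast(y_true, tf.float32) - y_pred)
        # Weighted sum of errors divided by the total weight of the batch
        return tf.math.reduce_sum(weights * abs_err) / tf.math.reduce_sum(weights)

    return loss

# Eager check on toy data
class_weights = {0: 1.25, 1: 5.0}
loss_fn = weighted_mean_absolute_error_normalized(class_weights)
y_true = tf.constant([[0], [1], [1]], dtype=tf.int64)
y_pred = tf.constant([[0.5], [1.0], [0.0]], dtype=tf.float32)
value = loss_fn(y_true, y_pred)
print(float(value))
```

Here the weighted error sum is 1.25 * 0.5 + 5.0 * 1.0 = 5.625 and the total weight is 11.25, so the result is 0.5; the reduce_mean() version would divide by 3 instead.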

Assuming you have a dict class_weights that maps labels to the reciprocal relative occurrence of each label, you use the function like this:

from tensorflow.keras import metrics, optimizers

wmae = weighted_mean_absolute_error(class_weights)
# your model definition comes here like:
# model = keras.Sequential()
# model.add(...)
# ...
model.compile(
    optimizer=optimizers.Adam(), # or another optimizer of your choosing
    loss=wmae,
    metrics=[metrics.CategoricalAccuracy()], # or other metrics
)

I do not know why Google does not provide a losses.WeightedMeanAbsoluteError() themselves, but this should work.
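Since losses.MeanAbsoluteError already accepts a sample_weight argument in __call__(), the wrapper from the question can probably also be made to work once the Python dict lookup is replaced by a table lookup. The following is a sketch under that assumption; `weighted_mae_via_sample_weight` and the toy data are hypothetical. With the default reduction, Keras multiplies the per-sample errors by the weights, sums them and divides by the batch size, so for single-output targets this should match the reduce_mean() version:

```python
import tensorflow as tf

def weighted_mae_via_sample_weight(class_weights):
    """Sketch: reuse the built-in Keras MAE by passing per-sample weights."""
    init = tf.lookup.KeyValueTensorInitializer(
        tf.constant(list(class_weights.keys()), dtype=tf.int64),
        tf.constant(list(class_weights.values()), dtype=tf.float32),
    )
    table = tf.lookup.StaticHashTable(init, default_value=1.0)
    mae = tf.keras.losses.MeanAbsoluteError()

    def loss(y_true, y_pred):
        # Tensor-friendly replacement for class_weights.get(t, 1.0)
        weights = table.lookup(tf.cast(y_true, tf.int64))
        # Note the keyword is sample_weight (singular)
        return mae(tf.cast(y_true, y_pred.dtype), y_pred, sample_weight=weights)

    return loss

# Eager check on toy data
loss_fn = weighted_mae_via_sample_weight({0: 1.25, 1: 5.0})
y_true = tf.constant([[0], [1], [1]], dtype=tf.int64)
y_pred = tf.constant([[0.5], [1.0], [0.0]], dtype=tf.float32)
value = loss_fn(y_true, y_pred)
print(float(value))
```

For this data the weighted errors sum to 5.625, and dividing by the batch size of 3 gives 1.875.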

mlinke-ai