The comment from @abdou_dev led me onto the right track. He posted a link to this page. Posted there is the following solution for a self-made mean absolute error loss function:
```python
import numpy as np

MAE = np.average(np.abs(y_true - y_pred), weights=sample_weight, axis=0)
```
However, this DOES NOT work. `y_true` and `y_pred` are symbolic tensors and therefore cannot be passed to a NumPy function. Even their difference is a symbolic tensor, so it does not work either.
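For illustration, here is a minimal sketch of the failure mode, assuming the loss runs in graph mode (the function name and toy values are hypothetical):

```python
import numpy as np
import tensorflow as tf

@tf.function  # traces the function, so y_true and y_pred become symbolic tensors
def numpy_mae(y_true, y_pred):
    # np.abs tries to convert the symbolic tensor to a NumPy array, which
    # typically fails during tracing with an error like
    # "NotImplementedError: Cannot convert a symbolic Tensor to a numpy array"
    return np.average(np.abs(y_true - y_pred), axis=0)

# numpy_mae(tf.constant([1.0, 2.0]), tf.constant([1.5, 1.0]))  # fails when traced
```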
SOLUTION
You need to replace `np.abs()` with `tf.math.abs()` and `np.average()` with `tf.math.reduce_mean()`. Those are TensorFlow functions and work with symbolic tensors too. Unfortunately `tf.math.reduce_mean()` does not provide a parameter for weights. After calculating the absolute error you need to multiply it by the weights, and for this operation the weights need to be a tensor. With a `tf.lookup.KeyValueTensorInitializer()` and a `tf.lookup.StaticHashTable()` you can replace the labels with their weights in a copy of `y_true`. I also needed to cast `y_true` to `tf.float32` because my labels are integers but the model outputs are floats.
The full solution is as follows:
```python
import tensorflow as tf

def weighted_mean_absolute_error(class_weights):
    def loss(y_true, y_pred):
        # Build a lookup table that maps each label to its weight.
        init = tf.lookup.KeyValueTensorInitializer(
            list(class_weights.keys()), list(class_weights.values())
        )
        table = tf.lookup.StaticHashTable(init, default_value=1.0)
        # Replace every label in y_true with its weight.
        weights = table.lookup(y_true)
        # Mean of the weighted absolute errors; y_true is cast because the
        # labels are integers while the model outputs are floats.
        return tf.math.reduce_mean(
            tf.math.multiply(tf.math.abs(tf.cast(y_true, tf.float32) - y_pred), weights)
        )
    return loss
```
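The `class_weights` argument is a dict mapping each label to the reciprocal of its relative occurrence. If you do not have it yet, here is a minimal sketch of how it could be built from label counts (the labels below are hypothetical placeholders):

```python
from collections import Counter

# Hypothetical training labels; use your own label list here.
train_labels = [0, 0, 0, 1, 1, 2]

counts = Counter(train_labels)
total = len(train_labels)
# Weight = 1 / (relative occurrence) = total count / label count.
class_weights = {label: total / count for label, count in counts.items()}
# -> {0: 2.0, 1: 3.0, 2: 6.0}
```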
Assuming you have such a dict `class_weights`, you use the function like this:
```python
from tensorflow import keras
from tensorflow.keras import optimizers, metrics

wmae = weighted_mean_absolute_error(class_weights)

# your model definition comes here, like:
# model = keras.Sequential()
# model.add(...)
# ...

model.compile(
    optimizer=optimizers.Adam(),              # or another optimizer of your choosing
    loss=wmae,
    metrics=[metrics.CategoricalAccuracy()],  # or other metrics
)
```
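As a quick sanity check, the loss can also be called directly on toy tensors in eager mode (hypothetical values, assuming integer labels that match the table keys):

```python
y_true = tf.constant([[0], [1], [2]])        # integer labels
y_pred = tf.constant([[0.2], [0.9], [1.5]])  # float model outputs
print(wmae(y_true, y_pred).numpy())          # scalar weighted MAE
```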
I do not know why Google does not implement a `losses.WeightedMeanAbsoluteError()` themselves, but this should work.