Using TensorFlow 2.10, I am trying to implement a custom loss function that takes the model inputs as well as the y_true and y_pred arguments:
inputs = tf.keras.layers.Input(shape=X.shape[-1:], batch_size=16)
dense1 = tf.keras.layers.Dense(8, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(dense1)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
def custom_loss(i):
    def loss(y_true, y_pred):
        # x1: fraction of samples the model currently misclassifies.
        mask = tf.equal(y_true, tf.cast(tf.round(y_pred), 'float32'))
        selec = tf.where(mask, 1., 0.)
        x1 = 1 - tf.reduce_mean(selec)
        # x2: mean of the last input feature over correctly classified samples.
        mask = tf.reshape(selec, (-1,))
        m = tf.boolean_mask(i, tf.cast(mask, tf.bool))  # boolean_mask needs a bool mask
        x2 = tf.reduce_mean(m[:, -1])
        return x1 * x2
    return loss
model.compile(loss=custom_loss(inputs),
              optimizer='adam',
              metrics=['accuracy',
                       tf.keras.metrics.Precision(name='precision')])
h = model.fit(X_train,
              y_train,
              validation_data=(X_test, y_test),
              epochs=5,
              batch_size=16,
              verbose=2)
When trying to fit the model, I got the following error:
TypeError: You are passing KerasTensor(type_spec=TensorSpec(shape=(), dtype=tf.float32, name=None), name='Placeholder:0', description="created by layer 'tf.cast_17'"), an intermediate Keras symbolic input/output, to a TF API that does not allow registering custom dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, or `tf.map_fn`. Keras Functional model construction only supports TF API calls that *do* support dispatching, such as `tf.math.add` or `tf.reshape`. Other APIs cannot be called directly on symbolic Keras inputs/outputs. You can work around this limitation by putting the operation in a custom Keras layer `call` and calling that layer on this symbolic input/output.
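The error message itself points at a workaround: move the input-dependent computation into a custom layer's `call` and register the result with `add_loss`, feeding y_true into the model as an extra Input. A minimal sketch of that pattern (assuming 13 input features to match the `(13, 8)` kernel in the second error, and using a hypothetical differentiable stand-in loss rather than the original one) looks like this:

```python
import numpy as np
import tensorflow as tf

# Sketch of the custom-layer workaround suggested by the error message.
# The loss here (an input-weighted squared error) is a hypothetical,
# differentiable stand-in, not the loss from the question.
class InputAwareLoss(tf.keras.layers.Layer):
    def call(self, x, y_true, y_pred):
        w = x[:, -1:]  # weight each sample by its last input feature
        self.add_loss(tf.reduce_mean(w * tf.square(y_true - y_pred)))
        return y_pred

n_features = 13  # assumed; matches the (13, 8) kernel in the second error
inputs = tf.keras.layers.Input(shape=(n_features,))
targets = tf.keras.layers.Input(shape=(1,))  # y_true enters as an extra Input
dense1 = tf.keras.layers.Dense(8, activation='relu')(inputs)
preds = tf.keras.layers.Dense(1, activation='sigmoid')(dense1)
outputs = InputAwareLoss()(inputs, targets, preds)

model = tf.keras.Model(inputs=[inputs, targets], outputs=outputs)
model.compile(optimizer='adam')  # loss comes from add_loss, so loss=None

X_demo = np.random.rand(64, n_features).astype('float32')
y_demo = np.random.randint(0, 2, (64, 1)).astype('float32')
history = model.fit([X_demo, y_demo], epochs=1, batch_size=16, verbose=0)
```

Because the loss is attached inside the model, no loss argument (and no separate y) is passed to compile/fit.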
Similar posts suggested disabling eager execution:
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
But this leads to another error in the loss function:
ValueError: Variable <tf.Variable 'dense_14/kernel:0' shape=(13, 8) dtype=float32> has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.