I am trying to build a custom loss for a regression problem, following this answer: Keras Custom loss function to pass arguments other than y_true and y_pred
My function currently looks like this:
    def CustomLoss(model, X_valid, y_valid, batch_size):
        def Loss(y_true, y_pred):
            n_samples = 5
            mc_predictions = np.zeros((n_samples, 256, 256))
            for i in range(n_samples):
                y_p = model.predict(X_valid, verbose=1, batch_size=batch_size)
                # (Other operations...)
            return LossValue
        return Loss
When execution reaches this line:
    y_p = model.predict(X_valid, verbose=1, batch_size=batch_size)
I get the following error:
    Method requires being in cross-replica context, use get_replica_context().merge_call()
From what I gather, I cannot call model.predict inside a loss function. Is there a workaround or solution for this? Please let me know if anything in my question is unclear or if you need additional information. Thanks!
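One direction that may work, as a minimal sketch rather than a confirmed fix: call the model directly on a tensor (model(x, training=True)) instead of model.predict, and build the loss from TensorFlow ops rather than NumPy arrays. As far as I understand, model.predict tries to start its own prediction loop, which is not allowed while the loss is being evaluated inside the training step. Everything specific below is an assumption for illustration: the name make_mc_loss, the tiny dense model, the variance penalty, and the 0.1 weight stand in for the real model and the elided "other operations".

```python
import numpy as np
import tensorflow as tf


def make_mc_loss(model, x_ref, n_samples=5):
    """Return a Keras-compatible loss that also runs `model` on `x_ref`.

    Hypothetical helper: `x_ref` plays the role of X_valid, and the
    variance penalty is only a placeholder for the real computation.
    """
    x_ref = tf.convert_to_tensor(x_ref, dtype=tf.float32)

    def loss(y_true, y_pred):
        # Call the model directly so the result stays a tensor in the graph.
        # training=True keeps dropout active, so the passes differ.
        mc = tf.stack(
            [model(x_ref, training=True) for _ in range(n_samples)], axis=0
        )
        # Placeholder statistic: mean variance across the stochastic passes.
        mc_var = tf.reduce_mean(tf.math.reduce_variance(mc, axis=0))
        mse = tf.reduce_mean(tf.square(y_true - y_pred))
        return mse + 0.1 * mc_var  # 0.1 is an arbitrary, assumed weight

    return loss


# Toy model and data, assumed purely for demonstration.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),
])
x_ref = np.random.rand(16, 4).astype("float32")
y_ref = np.zeros((16, 1), dtype="float32")

loss_fn = make_mc_loss(model, x_ref)
model.compile(optimizer="adam", loss=loss_fn)
```

Note that gradients also flow through the extra forward passes here, and each training step pays for n_samples additional model calls on x_ref. The np.zeros preallocation from the original is replaced by tf.stack, since NumPy arrays cannot hold symbolic tensors.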