
I am hoping to build a model with a custom loss function, and came across this post: Make a custom loss function in keras

            from keras import backend as K
            from keras.models import Sequential
            from keras.layers import Dense
            from keras import metrics
            import tensorflow as tf

            def dice_coef(y_true, y_pred, smooth, thresh):
                y_pred = y_pred > thresh  # line with question
                y_true_f = K.flatten(y_true)
                y_pred_f = K.flatten(tf.cast(y_pred, tf.float32))
                intersection = K.sum(y_true_f * y_pred_f)
                return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

            def dice_loss(smooth, thresh):
                def dice(y_true, y_pred):
                    return dice_coef(y_true, y_pred, smooth, thresh)
                return dice

            model = Sequential()
            model.add(Dense(1, activation='sigmoid', input_dim=X_train_vectorized.shape[1]))
            model.compile(optimizer='adam', loss=dice_loss(smooth=1e-5, thresh=0.5),
                          metrics=[metrics.mae, metrics.categorical_accuracy])
            model.fit(X_train_vectorized, y_train, epochs=5, validation_data=(X_test_vectorized, y_test))

When I run the lines above, y_pred = y_pred > thresh throws an error because no gradient is defined for the comparison operation. I do not have enough reputation to comment on the original post.

How should I convert the predicted probabilities to binary outputs? Thanks.

May Y

1 Answer


You can just gather the predictions that satisfy your condition:

y_pred = tf.gather(y_pred, tf.where(y_pred>thresh))

Since tf.gather is a differentiable operation (it behaves like a multiplication by a sparse matrix), you should be able to compute your loss while affecting only the values that satisfied the condition when backpropagating the error.

nessuno
  • I followed your answer here and am getting `ResourceExhaustedError: OOM when allocating tensor with shape[662885,4,144,144,1]`. Any idea why that is? – SamAtWork May 16 '19 at 00:44
  • The input tensor is too huge. Split it and loop through the splits. Btw it's better to ask a new question – nessuno May 16 '19 at 06:18
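The splitting suggestion from the comment could be sketched as below, assuming TensorFlow 2.x; the chunk count and the MSE loss used here are illustrative assumptions, not from the original post:

```python
import tensorflow as tf

def chunked_loss(loss_fn, y_true, y_pred, num_chunks):
    # Split along the batch axis and evaluate the loss per chunk,
    # avoiding one enormous intermediate tensor. tf.split requires
    # the batch size to divide evenly by num_chunks.
    parts = zip(tf.split(y_true, num_chunks), tf.split(y_pred, num_chunks))
    losses = [loss_fn(t, p) for t, p in parts]
    return tf.reduce_mean(tf.stack(losses))

y_true = tf.constant([1., 0., 1., 1., 0., 0., 1., 0.])
y_pred = tf.constant([0.9, 0.1, 0.8, 0.6, 0.3, 0.2, 0.7, 0.4])
mse = lambda t, p: tf.reduce_mean(tf.square(t - p))
chunked = chunked_loss(mse, y_true, y_pred, num_chunks=4)
full = mse(y_true, y_pred)  # equal-size chunks give the same mean
```

With equal-size chunks and a mean-style loss, the chunked average matches the full-batch value, so the split only changes peak memory, not the result.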