
I want to create a loss function where the MSE is only calculated on a subset of the outputs, and the subset depends on the input data. I used the answer to this question to figure out how to create a custom loss function based on the input data:

Custom loss function in Keras based on the input data

However, I'm having trouble getting the custom loss function to work.

Here is what I've put together.

from keras import backend as K
from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam


def custom_loss(input_tensor):
    def loss(y_true, y_pred):
        board = input_tensor[:81]
        answer_vector = board == .5
        #assert np.sum(answer_vector) > 0

        return K.mean(K.square(y_pred * answer_vector - y_true), axis=-1)
    return loss


def build_model(input_size, output_size):
    learning_rate = .001
    a = Input(shape=(input_size,))
    b = Dense(60, activation='relu')(a)
    b = Dense(60, activation='relu')(b)
    b = Dense(60, activation='relu')(b)
    b = Dense(output_size, activation='linear')(b)
    model = Model(inputs=a, outputs=b)
    model.compile(loss=custom_loss(a), optimizer=Adam(lr=learning_rate))

    return model

model = build_model(83, 81)

I want the MSE to treat the output as 0 wherever the board is not equal to 0.5. (The true value is one-hot encoded, with the one lying within the subset.) For some reason my output is treated as always zero; in other words, the custom loss function doesn't seem to be finding any places where the board is equal to 0.5.

I can't tell if I'm misinterpreting the dimensions, if the comparisons are failing because these are tensors, or if there is simply a much easier way to do what I'm trying to do.

Allen
  • Does ```answer_vector = board == .5``` create a tensor? Did you try using ```tf.where```? You can add a print(answer_vector) to show during model build time whether this is a tensor (or a constant). I would expect that the comparison yields "tensor == 0.5" => False. – Pedro Marques Jun 27 '19 at 14:03
  • As a debugging tool, I try to create a batch with one or two elements and then use ```var = K.print_tensor(var)``` inside the loss function to track how the function transforms the values. – Pedro Marques Jun 27 '19 at 14:06
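A minimal sketch of that K.print_tensor debugging approach applied to the loss from the question (standalone Keras backend assumed; the message string is only illustrative):

from keras import backend as K

def custom_loss(input_tensor):
    def loss(y_true, y_pred):
        board = input_tensor[:81]
        masked_pred = y_pred * (board == .5)   # the suspect expression from the question
        # print_tensor passes the value through, so it stays in the graph and prints at run time
        masked_pred = K.print_tensor(masked_pred, message='masked_pred = ')
        return K.mean(K.square(masked_pred - y_true), axis=-1)
    return loss

Fitting on a batch of one or two samples with this loss prints the intermediate values, which makes it easy to spot that masked_pred is all zeros.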

1 Answer


The problem is that answer_vector = board == .5 is not what you think it is. It is not a tensor but the plain Python boolean False: board is a tensor object and 0.5 is a number, so == falls back to ordinary object comparison:

import tensorflow as tf

a = tf.constant([0.5, 0.5])
print(a == 0.5)  # False

Now, a * False is a vector of zeros, because False is converted to 0.0:

with tf.Session() as sess:
    print(sess.run(a * False))  # [0.0, 0.0]
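For comparison, here is what an element-wise comparison looks like, reusing the a from above (a sketch; the printed value is what a TF 1.x session returns):

with tf.Session() as sess:
    print(sess.run(tf.equal(a, 0.5)))  # [ True  True]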

You need to use tf.equal instead of ==. Another possible pitfall is that comparing floats for exact equality is dangerous; see e.g. What's wrong with using == to compare floats in Java?
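Putting this together, a minimal sketch of a corrected loss (my own adaptation, not part of the original answer): it assumes the board is the first 81 features of each sample, so it slices with [:, :81] rather than [:81], and it casts the boolean mask to float before multiplying, since y_pred is a float tensor:

from keras import backend as K

def custom_loss(input_tensor):
    def loss(y_true, y_pred):
        board = input_tensor[:, :81]                    # assumed layout: first 81 features per sample
        mask = K.cast(K.equal(board, 0.5), K.floatx())  # 1.0 where the board equals 0.5, else 0.0
        return K.mean(K.square(y_pred * mask - y_true), axis=-1)
    return loss

Given the float-equality caveat, a tolerance-based mask such as K.cast(K.less(K.abs(board - 0.5), 1e-6), K.floatx()) may be more robust than an exact comparison.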

tomkot