
I ran into the following problem while working on a reinforcement learning task, when calculating the loss. My neural network outputs 4 Q-values: given a state as input, it outputs the Q-value of each action that could be taken in that state (4 possible actions). For the DQN algorithm, I want to calculate the loss between y_true, which is basically the discounted reward target and therefore only a scalar per sample, and the Q-value of the action my agent actually took (so just one of the four Q-values). Because of that, I can't use the built-in MSE loss available in Keras. I think I need something with the following structure:

import keras.backend as kb

def custom_loss_function(batch_action_taken):
    def loss(y_true, y_pred):
        q_value = ? #need to extract the q_values from y_pred according to the action taken in batch_action_taken
        return kb.mean(kb.square(q_value - y_true), axis=-1)
    return loss

But since the values inside the function are Keras tensors, I am not sure how to write the line marked with the comment.

I hope I explained myself well. I googled a lot but could not find an answer.
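For reference, the standard DQN objective this custom loss is meant to implement can be written as follows. The notation is an assumption, not taken from the code: N is the batch size (36 here), r_i the reward, gamma the discount factor, s'_i the next state, and a_i the action taken. Only Q(s_i, a_i) enters the loss, which is why exactly one entry per row of y_pred has to be selected.

```latex
y_i =
\begin{cases}
  r_i, & \text{if } s'_i \text{ is terminal},\\
  r_i + \gamma \max_{a'} Q(s'_i, a'), & \text{otherwise},
\end{cases}
\qquad
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \bigl( y_i - Q(s_i, a_i) \bigr)^2
```

In terms of the variables in the question, y_i corresponds to y_true[i], a_i to batch_action_taken[i], and Q(s_i, a_i) to y_pred[i, batch_action_taken[i]].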

EDIT: For clarification purposes:

batch_action_taken: vector of length 36; each entry is 0, 1, 2 or 3 (the index of the action taken)

y_true: vector of length 36

y_pred: shape (36, 4) (36 data points, 4 possible actions each)

q_value: should be a vector of length 36 whose i-th entry is y_pred[i, batch_action_taken[i]], i.e. the value in row i of y_pred selected by the i-th entry of batch_action_taken
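To make the expected shapes concrete, here is a small NumPy sketch. NumPy is an assumption used only to illustrate the semantics, and the helper names are hypothetical. It builds random data with the shapes listed above and extracts q_value in two equivalent ways: integer indexing, and a one-hot mask multiplied with y_pred and summed over the action axis. The mask form is the one that maps directly onto tensor ops.

```python
import numpy as np


def select_taken_q_values(y_pred, batch_action_taken):
    """Return y_pred[i, batch_action_taken[i]] for every row i (hypothetical helper)."""
    y_pred = np.asarray(y_pred)
    actions = np.asarray(batch_action_taken)
    # Integer ("fancy") indexing: pair row i with column actions[i]
    return y_pred[np.arange(len(actions)), actions]


def select_with_mask(y_pred, batch_action_taken, num_actions=4):
    """Same result via a one-hot mask; num_actions is a hypothetical parameter."""
    # Rows of the identity matrix act as one-hot vectors: shape (batch, num_actions)
    mask = np.eye(num_actions)[np.asarray(batch_action_taken)]
    # Zero out every action except the one taken, then sum over the action axis
    return np.sum(np.asarray(y_pred) * mask, axis=1)


rng = np.random.default_rng(0)
y_pred = rng.normal(size=(36, 4))                    # 36 samples, 4 Q-values each
batch_action_taken = rng.integers(0, 4, size=36)     # one action index per sample
q_value = select_taken_q_values(y_pred, batch_action_taken)
print(q_value.shape)  # (36,)
```

Both helpers return a vector of length 36, one Q-value per sample, which can then be compared with y_true.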

Peter

1 Answer


This might help:

# action_size: the number of possible actions (4 in your case)
actions = kb.placeholder(shape=(None, action_size), dtype='float32')

q_value = kb.sum(kb.dot(batch_action_taken, actions), axis=1)
return kb.mean(kb.square(q_value - y_true), axis=1)
  • Sadly, I get the AttributeError: 'tuple' object has no attribute 'rank' for the second line of your code – Peter Nov 17 '20 at 16:30
  • Sorry to hear that. I'm not sure why that is, but maybe it is related to this problem: https://stackoverflow.com/questions/62744659/attributeerror-tuple-object-has-no-attribute-rank-when-calling-fit-on-a-ker –  Nov 17 '20 at 16:41
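For comparison, here is a minimal sketch of the one-hot masking idea written directly with TensorFlow ops. It assumes TensorFlow 2.x with eager execution, and `num_actions` is a hypothetical parameter name. It keeps the closure structure from the question and fills the `?` line with a one-hot mask of the taken actions, so neither a placeholder nor `kb.dot` is needed.

```python
import tensorflow as tf


def custom_loss_function(batch_action_taken, num_actions=4):
    """Build a loss comparing y_true with the Q-value of the action taken.

    num_actions is a hypothetical parameter (4 actions in the question).
    """
    # Integer action indices, shape (batch,)
    actions = tf.convert_to_tensor(batch_action_taken, dtype=tf.int32)

    def loss(y_true, y_pred):
        # One-hot mask of the taken action, shape (batch, num_actions)
        mask = tf.one_hot(actions, depth=num_actions, dtype=y_pred.dtype)
        # Keep only the taken action's Q-value and sum over the action axis -> (batch,)
        q_value = tf.reduce_sum(y_pred * mask, axis=-1)
        # Accept y_true as shape (batch,) or (batch, 1)
        y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
        return tf.reduce_mean(tf.square(q_value - y_true))

    return loss
```

An equivalent way to select the values is `tf.gather(y_pred, actions, batch_dims=1)`. Note that the closure captures one fixed batch of actions when the loss is created, so if it is passed to `model.compile` it only matches the batch it was built for. In practice you would either rebuild the loss for each batch, compute the loss in a custom training step (for example with `tf.GradientTape`), or encode the action inside y_true.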