
I'm trying to obtain the gradient of the loss objective (in my case, categorical_crossentropy) w.r.t. NN parameters such as the weights and biases.

The reason for this is that I want to implement a callback function on top of it, with which I could debug the model while it's training.

So, here's the problem.

I'm currently using generator methods to fit, evaluate and predict on the dataset.

The categorical_cross_entropy loss function in Keras is implemented as follows:

from keras import backend as K

def categorical_crossentropy(y_true, y_pred):
    return K.categorical_crossentropy(y_true, y_pred)
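For intuition, this is what that backend op computes on plain arrays. A NumPy illustration (the function name and the clipping constant are my own assumptions, not the exact Keras implementation):

```python
import numpy as np

def categorical_crossentropy_np(y_true, y_pred, eps=1e-7):
    # Per-sample -sum(y_true * log(y_pred)); clip to avoid log(0)
    y_pred = np.clip(y_pred, eps, 1.0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)
```

So for a one-hot target, the loss is simply the negative log of the predicted probability of the true class.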

The only way I can get my hands on y_pred is if I evaluate/predict at the end of training my model.

So, what I'm asking is the following:

  • Is there a way for me to create a callback as mentioned above?
  • If anyone has already implemented a callback like this using categorical_crossentropy, please let me know how you made it work.
  • Lastly, how do I compute the numeric gradient of the loss w.r.t. these parameters?
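On the callback question: if you can move to tf.keras with eager execution (TF2), gradient access inside a callback is straightforward with tf.GradientTape. A minimal sketch, assuming the callback is handed a fixed debug batch (the class name and what gets logged are my own choices, not Keras API):

```python
import tensorflow as tf

class GradientLogger(tf.keras.callbacks.Callback):
    """Log the norm of the loss gradient w.r.t. each trainable weight
    at the end of every epoch (sketch; names are hypothetical)."""

    def __init__(self, x, y):
        super().__init__()
        self.x = tf.convert_to_tensor(x, dtype=tf.float32)
        self.y = tf.convert_to_tensor(y, dtype=tf.float32)
        self.history = []

    def on_epoch_end(self, epoch, logs=None):
        with tf.GradientTape() as tape:
            # Recompute the loss on the debug batch
            y_pred = self.model(self.x, training=False)
            loss = tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(self.y, y_pred))
        grads = tape.gradient(loss, self.model.trainable_weights)
        self.history.append([float(tf.norm(g)) for g in grads])
```

You would then pass it to training as `model.fit(x, y, callbacks=[GradientLogger(x_debug, y_debug)])` and inspect `history` afterwards (or print inside `on_epoch_end`).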

Currently, this is the code I'm using to calculate the gradient, but I have no clue whether it's right or wrong. Link.

import numpy as np
from keras import backend as K

def symbolic_gradients(model, input, output):
    # Symbolic gradients of the total loss w.r.t. all trainable weights
    grads = K.gradients(model.total_loss, model.trainable_weights)
    inputs = (model.model._feed_inputs + model.model._feed_targets +
              model.model._feed_sample_weights)
    fn = K.function(inputs, grads)

    return fn([input, output, np.ones(len(output))])

Ideally I'd like to make this model-agnostic, but even if it's not, it's okay.
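On the numeric-gradient part: a standard way to check a symbolic gradient is central finite differences, perturbing each parameter by ±eps and differencing the loss. A self-contained NumPy sketch for a single softmax layer with categorical cross-entropy (all function names here are mine, for illustration):

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cce_loss(W, b, x, y_true):
    # Mean categorical cross-entropy of a single softmax layer
    y_pred = softmax(x @ W + b)
    return -np.mean(np.sum(y_true * np.log(y_pred + 1e-12), axis=-1))

def numeric_grad(f, param, eps=1e-5):
    # Central finite differences over every entry of `param`;
    # f() re-evaluates the loss with the current (perturbed) values
    grad = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        old = param[idx]
        param[idx] = old + eps
        hi = f()
        param[idx] = old - eps
        lo = f()
        param[idx] = old
        grad[idx] = (hi - lo) / (2 * eps)
    return grad
```

Comparing `numeric_grad(lambda: cce_loss(W, b, x, y), W)` against the symbolic result is a good sanity check that code like the snippet above is computing the right thing.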

halfer
Jagan S

1 Answer


I can help with the gradient part. I am using this function to calculate the gradient of the loss function w.r.t. the output.

from keras import backend as K

def get_loss_grad(model, inputs, outputs):
    # Standardize the data into the lists Keras feeds internally
    x, y, sample_weight = model._standardize_user_data(inputs, outputs)
    # Gradient of the total loss w.r.t. the model output
    grad_ce = K.gradients(model.total_loss, model.output)
    func = K.function(model._feed_inputs + model._feed_targets +
                      model._feed_sample_weights, grad_ce)
    return func(x + y + sample_weight)
Damodharan_C