I am porting a keras model over to torch and I'm having trouble replicating the exact behavior of keras/tensorflow's 'categorical_crossentropy' after a softmax layer. I have some workarounds for this problem, so I'm only interested in understanding what exactly tensorflow computes when it calculates categorical cross entropy.
As a toy problem, I set up label and prediction vectors:
>>> import tensorflow as tf
>>> from tensorflow.keras import backend as K
>>> import numpy as np
>>> true = np.array([[0.0, 1.0], [1.0, 0.0]])
>>> pred = np.array([[0.0, 1.0], [0.0, 1.0]])
And calculate the Categorical Cross Entropy with:
>>> loss = tf.keras.losses.CategoricalCrossentropy()
>>> print(loss(true, pred).eval(session=K.get_session()))
8.05904769897461
This differs from the analytical result
>>> loss_analytical = -1*K.sum(true*K.log(pred))/pred.shape[0]
>>> print(loss_analytical.eval(session=K.get_session()))
nan
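The nan itself makes sense: pred contains exact zeros, so true * np.log(pred) produces a 0 * (-inf) = nan term for the first sample and a -inf term for the second, and summing those gives nan. A quick plain-numpy check of the element-wise terms (reusing the arrays above):
>>> with np.errstate(divide='ignore', invalid='ignore'):
...     print(true * np.log(pred))  # nan and -inf terms poison the sum
[[ nan   0.]
 [-inf   0.]]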
I dug into the source code for keras/tf's cross entropy (see Softmax Cross Entropy implementation in Tensorflow Github Source Code) and found the C++ kernel at https://github.com/tensorflow/tensorflow/blob/c903b4607821a03c36c17b0befa2535c7dd0e066/tensorflow/compiler/tf2xla/kernels/softmax_op.cc, line 116. In that function, there is a comment:
// sum(-labels *
// ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
// along classes
// (The subtraction broadcasts along the batch dimension.)
And implementing that, I tried:
>>> max_logits = K.max(pred, axis=0)
>>> xent = K.sum(-true * ((pred - max_logits) - K.log(K.sum(K.exp(pred - max_logits)))))/pred.shape[0]
>>> print(xent.eval(session=K.get_session()))
1.3862943611198906
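For reference, my reading of that comment is that the max and both reductions should run along the class axis rather than the batch axis, and that the kernel expects raw logits. Here is a plain-numpy sketch under that reading (the axis=-1 and keepdims choices are my assumption about the intended broadcasting):
>>> max_logits = pred.max(axis=-1, keepdims=True)      # per-sample max over classes
>>> log_sum_exp = np.log(np.exp(pred - max_logits).sum(axis=-1, keepdims=True))
>>> per_sample = (-true * ((pred - max_logits) - log_sum_exp)).sum(axis=-1)  # sum along classes
>>> print(round(float(per_sample.mean()), 6))          # neither 1.386 nor 8.059
0.813262
That still doesn't reproduce 8.059, which makes me suspect the default CategoricalCrossentropy path never treats pred as logits at all.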
I also tried to print the trace for xent.eval(session=K.get_session()), but the trace is ~95,000 lines long. So the question is: what exactly is keras/tf doing when it calculates 'categorical_crossentropy'? It makes sense that it doesn't return nan, since that would cause training issues, but where does the 8 come from?
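For what it's worth, 8.059 is suspiciously close to -log(1e-7) / 2, and 1e-7 is the default backend fuzz factor K.epsilon(). My current guess, which I have not verified against the source, is that with the default from_logits=False the loss treats pred as probabilities, clips them into [epsilon, 1 - epsilon], and averages -sum(true * log(clipped_pred)) over the batch. The epsilon value and the clipping step are assumptions on my part, but they reproduce the number:
>>> eps = 1e-7                                  # assumed: default K.epsilon()
>>> clipped = np.clip(pred, eps, 1.0 - eps)     # assumed: probabilities are clipped, not re-normalized
>>> guess = (-(true * np.log(clipped)).sum(axis=-1)).mean()
>>> print(round(float(guess), 6))               # matches 8.05904769897461 up to float32 precision
8.059048
I'd still like confirmation of whether that clipping is really what happens inside 'categorical_crossentropy'.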