I am confused about the calculation of cross entropy in PyTorch. If I want to calculate the cross entropy between two tensors and the target tensor is not a one-hot label, which loss should I use? It is quite common to compute the cross entropy between two probability distributions rather than between a prediction and a fixed one-hot label.
The basic loss function CrossEntropyLoss
forces the target to be an index integer, so it is not applicable in this case. BCELoss
seems to work, but it gives an unexpected result. The expected formula for the cross entropy is
-Σi yi*log(pi)
But BCELoss
calculates the BCE of each dimension, which is expressed as
-yi*log(pi)-(1-yi)*log(1-pi)
Compared with the first equation, the term -(1-yi)*log(1-pi)
should not be involved. Here is an example using BCELoss
and we can see the second term is included in each dimension's result, which makes the result differ from the correct one.
import torch
import torch.nn as nn
from math import log

a = torch.Tensor([0.1, 0.2, 0.7])  # predicted distribution
y = torch.Tensor([0.2, 0.2, 0.6])  # soft target distribution
L = nn.BCELoss(reduction='none')
# manual BCE of the first element: -y0*log(p0) - (1-y0)*log(1-p0)
y1 = -0.2 * log(0.1) - 0.8 * log(0.9)
print(L(a, y))
print(y1)
And the result is
tensor([0.5448, 0.5004, 0.6956])
0.5448054311250702
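To make the discrepancy concrete, the per-dimension BCE values above can be decomposed into the expected cross entropy plus the extra term (a quick numeric check with the same numbers, using only the standard library):

```python
from math import log

p = [0.1, 0.2, 0.7]  # predicted distribution
y = [0.2, 0.2, 0.6]  # target distribution
# expected cross entropy: -Σ yi*log(pi)
ce = -sum(yi * log(pi) for yi, pi in zip(y, p))
# extra term BCELoss adds in each dimension: -(1-yi)*log(1-pi)
extra = -sum((1 - yi) * log(1 - pi) for yi, pi in zip(y, p))
# their sum reproduces the summed BCELoss output (0.5448 + 0.5004 + 0.6956)
print(ce, extra, ce + extra)
```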
If we sum the results over all dimensions, the final value does not match the expected cross entropy, because each dimension contributes the extra -(1-yi)*log(1-pi)
term. In contrast, TensorFlow calculates the correct cross entropy value with CategoricalCrossentropy
. Here is an example with the same setting, where we can see the cross entropy is calculated exactly as in the first formula.
import tensorflow as tf
from math import log

L = tf.losses.CategoricalCrossentropy()
a = tf.convert_to_tensor([0.1, 0.2, 0.7])  # predicted distribution
y = tf.convert_to_tensor([0.2, 0.2, 0.6])  # target distribution
# manual cross entropy: -Σ yi*log(pi)
y_ = -0.2 * log(0.1) - 0.2 * log(0.2) - 0.6 * log(0.7)
print(L(y, a), y_)
tf.Tensor(0.9964096, shape=(), dtype=float32) 0.9964095674488687
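For now, the only way I see is to implement the first formula by hand in PyTorch (a manual sketch, not a built-in loss), which does reproduce the TensorFlow value:

```python
import torch

a = torch.tensor([0.1, 0.2, 0.7])  # predicted distribution
y = torch.tensor([0.2, 0.2, 0.6])  # soft target distribution
# cross entropy by the first formula: -Σ yi*log(pi)
ce = -(y * torch.log(a)).sum()
print(ce)  # ≈ 0.9964
```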
Is there any function in PyTorch that can calculate the correct cross entropy using the first formula, just like CategoricalCrossentropy
in TensorFlow?