
I am trying to compute the cross-entropy loss manually in PyTorch for an encoder-decoder model.

I used the code posted here to compute it: Cross Entropy in PyTorch

I updated the code to discard padded tokens (label -100). The final code is this:

import torch

class compute_crossentropyloss_manual:
    """
    y0 is the logits tensor with shape (batch_size, C).
    x has shape (batch_size,); its entries are class indices from 0 to C-1,
    or ignore_index for padded positions.
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:  # <------ I added this if-statement
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
        loss = - loss/n_batch
        return loss

To verify that it works, I tested it on a text generation task and computed the loss both with the torch.nn implementation and with this code.

The loss values are not identical:

using nn.CrossEntropyLoss:

[screenshot of the loss values from nn.CrossEntropyLoss]

Using the code from the link above:

[screenshot of the loss values from the manual implementation]

Am I missing something?

I tried to find the source code of nn.CrossEntropyLoss but wasn't able to. In this link, nn/functional.py at line 2955, you will see that the function dispatches to another function called torch._C._nn.cross_entropy_loss; I can't find this function in the repo.
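(For what it's worth, the PyTorch docs describe cross_entropy as log_softmax followed by nll_loss, so even without the C++ source its behaviour can be cross-checked from Python; a minimal sketch, assuming random logits:)

import torch
import torch.nn.functional as F

# cross_entropy is documented to combine log_softmax and nll_loss,
# so the compiled kernel can be checked against the two-step version:
logits = torch.randn(8, 5)                    # (batch_size, C)
target = torch.randint(0, 5, (8,))
target[2] = -100                              # one ignored token

a = F.cross_entropy(logits, target, ignore_index=-100)
b = F.nll_loss(F.log_softmax(logits, dim=1), target, ignore_index=-100)
print(torch.allclose(a, b))                   # True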

Edit:

I noticed that the differences appear only when there are -100 tokens in the gold labels.

Demo example:

criterion = torch.nn.CrossEntropyLoss()         # PyTorch implementation
criterion2 = compute_crossentropyloss_manual()  # the class above

y = torch.randint(1, 50, (100, 50), dtype=float)
x = torch.randint(1, 50, (100,))

x[40:] = -100
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 25.55788695847976
> 10.223154783391905

and when we don't have -100:

x[40:] = 30 # any valid class index
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 24.684453267596453
> 24.684453267596453
Minions
  • `torch._C` is the C source. So you could look in [here](https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/LossNLL.cpp) – Chrispresso Dec 02 '21 at 18:23
  • Thanks @Chrispresso. I am not able to understand anything in C. – Minions Dec 02 '21 at 18:33

2 Answers


I solved the problem by updating the code. I was already discarding the -100 tokens (the if-statement above), but I forgot to also shrink the denominator (called n_batch in the code above) by the number of discarded tokens, so the sum was still divided by the full batch size. That is exactly the discrepancy in the demo: 10.2232 × 100/40 = 25.5579. After fixing it, the loss numbers are identical to the nn.CrossEntropyLoss values. The final code:

class CrossEntropyLossManual:
    """
    y0 is the logits tensor with shape (batch_size, C).
    x has shape (batch_size,); its entries are class indices from 0 to C-1,
    or ignore_index for padded positions.
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:
                n_batch -= 1   # shrink the denominator: ignored tokens don't count
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
        loss = - loss/n_batch
        return loss
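
For example, a quick sanity check reusing the demo tensors from the question (note that nn.CrossEntropyLoss already uses ignore_index=-100 by default):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()       # ignore_index defaults to -100
criterion2 = CrossEntropyLossManual()

y = torch.randint(1, 50, (100, 50), dtype=float)
x = torch.randint(1, 50, (100,))
x[40:] = -100

print(criterion(y, x).item())
print(criterion2(y, x).item())          # now prints the same value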
Minions
  • Note also that the [torch implementation for softmax](https://github.com/pytorch/pytorch/blob/7541c9f8be0d46444adf6073051b6c7d0da4ac65/torch/csrc/jit/tensorexpr/operators/softmax.cpp#L9) is a bit different, to avoid overflows. – null Jan 19 '23 at 06:23

I needed this too - thank you for the manual cross-entropy loss code. It matches the PyTorch results perfectly (with my data). I have one little fix to your fix above: in the end you need to divide by the final count of non-ignored rows (those whose label is not -100), so you need a counter:

class compute_crossentropyloss_manual:
    """
    y0 is the logits tensor with shape (batch_size, C).
    x has shape (batch_size,); its entries are class indices from 0 to C-1,
    or ignore_index for padded positions.
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        cnt = 0             # <----- I added this
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
            cnt += 1        # <----- I added this
        loss = - loss/cnt   # <---- I changed this from nbatch to 'cnt'
        return loss
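
As an aside, the same discard-and-average computation can be written without the Python loop. A vectorized sketch (not from the answers above; the helper name is made up), using torch.log_softmax, which also avoids the overflow risk of exponentiating raw logits mentioned in the comment under the first answer:

import torch

# Hypothetical vectorized helper, equivalent to the loop-based classes above.
def cross_entropy_manual_vectorized(y0, x, ignore_index=-100):
    keep = x != ignore_index                        # mask of non-ignored rows
    log_probs = torch.log_softmax(y0[keep], dim=1)  # stable log-probabilities
    gold = log_probs[torch.arange(log_probs.size(0)), x[keep]]
    return -gold.mean()                             # mean over kept rows only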