I have implemented focal loss in PyTorch following this paper, and ran into a problem: the loss function returns nan.
This is my implementation of focal loss:
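```python
import torch

def focal_loss(y_real, y_pred, gamma=2):
    # y_pred holds the raw network outputs (logits); map them to probabilities
    y_pred = torch.sigmoid(y_pred)
    return -torch.sum((1 - y_pred)**gamma * y_real * torch.log(y_pred) +
                      y_pred**gamma * (1 - y_real) * torch.log(1 - y_pred))
```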
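Written out, the function computes the following sum over all pixels $i$, where $x_i$ is the raw network output (logit), $y_i$ the target and $\sigma$ the sigmoid:

$$
\mathrm{FL} = -\sum_i \Big[ (1 - p_i)^{\gamma}\, y_i \log p_i \;+\; p_i^{\gamma}\,(1 - y_i) \log(1 - p_i) \Big],
\qquad p_i = \sigma(x_i) = \frac{1}{1 + e^{-x_i}}
$$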
The training loop and my SegNet seem to be working, since I have tested them with Dice and BCE losses.
I suspect the error occurs during backprop. Why could this happen? Is my implementation wrong?
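One plausible cause, offered as a hedged sketch rather than a confirmed diagnosis: in float32, `torch.sigmoid` returns exactly `1.0` for large positive logits (and `0.0` for very negative ones), so `torch.log(1 - y_pred)` or `torch.log(y_pred)` evaluates to `-inf`. Multiplied by a zero factor such as `(1 - y_real)` this gives `0 * -inf = nan` already in the forward pass, and the gradient of `log` at 0 is infinite too. The function below, `focal_loss_stable` (a hypothetical name), computes the same formula as `focal_loss` but in log space, using `1 - sigmoid(x) = sigmoid(-x)`, `log(sigmoid(x)) = F.logsigmoid(x)` and `log(1 - sigmoid(x)) = F.logsigmoid(-x)`, so every term stays finite.

```python
import torch
import torch.nn.functional as F


def focal_loss_stable(y_real, logits, gamma=2.0):
    """Sum-reduced binary focal loss computed directly from raw logits.

    Hypothetical helper: same formula as focal_loss, but log(p) and log(1 - p)
    come from F.logsigmoid, which stays finite for any logit, instead of
    torch.log(torch.sigmoid(...)), which hits log(0) = -inf once sigmoid saturates.
    """
    log_p = F.logsigmoid(logits)      # log(p)
    log_1mp = F.logsigmoid(-logits)   # log(1 - p), because 1 - sigmoid(x) = sigmoid(-x)
    # Modulating factors (1 - p)**gamma and p**gamma, also computed in log space
    w_pos = torch.exp(gamma * log_1mp)
    w_neg = torch.exp(gamma * log_p)
    loss = w_pos * y_real * log_p + w_neg * (1 - y_real) * log_1mp
    return -torch.sum(loss)


if __name__ == "__main__":
    # Extreme logits of the kind that make the sigmoid + log version return nan
    logits = torch.tensor([100.0, -100.0, 0.0], requires_grad=True)
    y = torch.tensor([1.0, 0.0, 1.0])
    loss = focal_loss_stable(y, logits)
    loss.backward()
    print("loss:", loss.item())
    print("all gradients finite:", torch.isfinite(logits.grad).all().item())
```

For moderate logits the two versions agree; they differ only where the sigmoid saturates. If torchvision is available, `torchvision.ops.sigmoid_focal_loss` is a ready-made alternative that also works on logits (via `binary_cross_entropy_with_logits`), but note that it applies an `alpha` class weighting (default 0.25) that the formula above does not include.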