I am confused about how to reduce the dimensions in a Dice loss function for segmentation.
The input shape is (B, C, H, W),
and there are two ways to do the reduction:
Option 1: sum over the spatial dimensions only, then take the mean over the batch and channels.
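reduce_axis = [2, 3]  # H, W
intersection = torch.sum(true * pred, dim=reduce_axis)  # shape (B, C)
denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)  # shape (B, C)
dice = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)  # Dice loss per (sample, channel)
torch.mean(dice)  # average over batch and channels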
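For reference, here is option 1 wrapped as a self-contained function, as a minimal sketch. It assumes `pred` already holds probabilities (e.g. after a sigmoid or softmax) and `true` is a one-hot mask of the same (B, C, H, W) shape; the function name `dice_loss_per_channel` and the toy tensors are hypothetical.

```python
import torch


def dice_loss_per_channel(pred: torch.Tensor, true: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Option 1: Dice loss per (sample, channel), then averaged over batch and channels.

    pred and true have shape (B, C, H, W); pred holds probabilities in [0, 1].
    """
    reduce_axis = (2, 3)  # spatial dimensions only
    intersection = torch.sum(true * pred, dim=reduce_axis)  # shape (B, C)
    denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)  # shape (B, C)
    dice = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)  # shape (B, C)
    return torch.mean(dice)  # scalar: mean over the B * C values


# Toy target: 2 samples, 3 channels, 4x4 images.
target = torch.zeros(2, 3, 4, 4)
target[:, 0, :2, :] = 1.0  # channel 0 covers the top half
target[:, 1, 2:, :] = 1.0  # channel 1 covers the bottom half; channel 2 stays empty

perfect = target.clone()
print(dice_loss_per_channel(perfect, target).item())  # 0.0 for a perfect prediction
```

Note that with `smooth = 1.0` an empty channel that is also predicted empty contributes a loss of 0, since the ratio becomes (0 + 1) / (0 + 1).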
Option 2: sum over all dimensions except the batch dimension (channels and spatial), then take the mean over the batch.
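reduce_axis = [1, 2, 3]  # C, H, W
intersection = torch.sum(true * pred, dim=reduce_axis)  # shape (B,)
denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)  # shape (B,)
dice = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)  # Dice loss per sample
torch.mean(dice)  # average over the batch only (channels were already summed)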
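The two reductions are not equivalent. The sketch below runs both on a toy example, assuming a single sample with one large and one small object and a model that misses the small object entirely; the helper name `dice_loss` and the tensors are hypothetical.

```python
import torch


def dice_loss(pred, true, reduce_axis, smooth=1.0):
    """Dice loss summed over `reduce_axis`, then averaged over the remaining dims."""
    intersection = torch.sum(true * pred, dim=reduce_axis)
    denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)
    dice = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
    return torch.mean(dice)


# One sample, two channels on a 10x10 grid.
true = torch.zeros(1, 2, 10, 10)
true[0, 0, :9, :] = 1.0   # large object: 90 pixels
true[0, 1, 9, :4] = 1.0   # small object: 4 pixels

pred = true.clone()
pred[0, 1] = 0.0          # the model misses the small object completely

per_channel = dice_loss(pred, true, reduce_axis=(2, 3))   # option 1
pooled = dice_loss(pred, true, reduce_axis=(1, 2, 3))     # option 2
print(f"option 1: {per_channel.item():.4f}")
print(f"option 2: {pooled.item():.4f}")
```

This prints `option 1: 0.4000` and `option 2: 0.0216`. Option 1 gives each channel equal weight, so missing the small object costs a lot. Option 2 pools pixels across channels before dividing, so the large object dominates and the miss barely registers.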