
I am confused about how to reduce dimensions in a Dice loss function for segmentation.
The input shape is (B, C, H, W), and there are two ways to reduce it.
Sum over the spatial dimensions only, then take the mean over batch and channels:

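import torch
reduce_axis = [2, 3]  # sum over H and W only
intersection = torch.sum(true * pred, dim=reduce_axis)
denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)
dice_loss = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)  # shape (B, C)
loss = torch.mean(dice_loss)  # average over batch and channels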

Sum over all dimensions except batch, then take the mean over the batch:

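reduce_axis = [1, 2, 3]  # sum over C, H and W
intersection = torch.sum(true * pred, dim=reduce_axis)
denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)
dice_loss = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)  # shape (B,)
loss = torch.mean(dice_loss)  # average over batch only; channels were already summed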
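A minimal, self-contained sketch of both reductions, assuming `pred` holds softmax probabilities, `true` is a one-hot mask of the same (B, C, H, W) shape, and the intersection is the usual `sum(pred * true)`. The helper name `dice_loss` and the default `smooth=1.0` are illustrative choices, not from the original. The only difference between the two options is `reduce_axis`, which decides what shape is left before the final mean:

```python
import torch


def dice_loss(pred: torch.Tensor, true: torch.Tensor,
              reduce_axis: tuple, smooth: float = 1.0) -> torch.Tensor:
    """Soft Dice loss (hypothetical helper), summing over `reduce_axis` first.

    pred, true: tensors of shape (B, C, H, W); pred holds probabilities,
    true is one-hot. Returns the loss for every axis NOT in reduce_axis.
    """
    intersection = torch.sum(pred * true, dim=reduce_axis)
    denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)


torch.manual_seed(0)
B, C, H, W = 2, 3, 8, 8
pred = torch.softmax(torch.randn(B, C, H, W), dim=1)      # per-pixel class probabilities
labels = torch.randint(0, C, (B, H, W))                   # integer class map
true = torch.nn.functional.one_hot(labels, C).permute(0, 3, 1, 2).float()  # (B, C, H, W)

per_class = dice_loss(pred, true, reduce_axis=(2, 3))     # option 1: shape (B, C)
pooled = dice_loss(pred, true, reduce_axis=(1, 2, 3))     # option 2: shape (B,)

print(per_class.shape, pooled.shape)  # torch.Size([2, 3]) torch.Size([2])
loss_1 = per_class.mean()  # mean over batch and channels
loss_2 = pooled.mean()     # mean over batch only
```

Option 1 leaves one loss per (sample, class) pair; option 2 leaves one loss per sample, because the channel axis has already been summed away before the ratio is taken.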
Shai
Talha Anwar

1 Answer


It depends on the meaning of the different dimensions.
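If your channel dimension holds the segmentation masks of different classes (i.e., semantic segmentation), then computing Dice per channel and averaging over channels and batch gives you the mean per-class Dice. Every class contributes equally, no matter how many pixels it covers.
Alternatively, if you also sum over the channels (the 2nd option), you pool all classes into a single multi-class Dice per image (I'm not sure this is even a standard "thing"; it amounts to micro-averaging over classes) and report its average over the batch. With this pooling, large classes such as the background dominate the score.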
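To make the difference concrete, here is a toy sketch (assumed data, not from the question): a single 4x4 image with 15 background pixels and 1 foreground pixel, and a prediction that labels everything as background. The `dice_loss` helper is hypothetical and uses a tiny `smooth=1e-6` so the numbers stay close to the exact ratios. Per-class averaging (option 1) punishes the missed foreground heavily, while pooling (option 2) barely notices it:

```python
import torch


def dice_loss(pred, true, reduce_axis, smooth=1e-6):
    # Soft Dice loss (hypothetical helper); sums over reduce_axis,
    # returns one value per remaining index.
    intersection = torch.sum(pred * true, dim=reduce_axis)
    denominator = torch.sum(true, dim=reduce_axis) + torch.sum(pred, dim=reduce_axis)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)


# One-hot target, shape (B=1, C=2, H=4, W=4).
true = torch.zeros(1, 2, 4, 4)
true[0, 0] = 1.0          # channel 0 (background) everywhere...
true[0, 0, 0, 0] = 0.0    # ...except the top-left pixel,
true[0, 1, 0, 0] = 1.0    # which is the only foreground pixel

# "Lazy" prediction: background everywhere, foreground nowhere.
pred = torch.zeros(1, 2, 4, 4)
pred[0, 0] = 1.0

per_class = dice_loss(pred, true, reduce_axis=(2, 3))    # shape (1, 2)
pooled = dice_loss(pred, true, reduce_axis=(1, 2, 3))    # shape (1,)

print(per_class)          # roughly [[0.032, 1.0]]: background fine, foreground missed
print(per_class.mean())   # roughly 0.516: the missed class dominates
print(pooled.mean())      # roughly 0.0625: one wrong pixel out of 16
```

Which behaviour you want depends on whether small classes should weigh as much as large ones, which is why the question of what you want to compute matters.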

What is it that you want to compute?

Shai