
I have a function (see below) which modifies the loss so that it returns only the loss for the K samples in the minibatch with the lowest loss. The idea is to focus each optimization step on these samples.

So I first do a forward pass to get the loss value for each sample in the mini-batch, and then adapt the loss via the function `get_adapted_loss_for_minibatch`.

As the adapted loss takes into account only a certain fraction of the samples in the minibatch (I am currently using 60% of the samples), I was expecting to also get a measurable speedup during training, since the backward step only has to be done for a fraction of the samples in the minibatch.

But unfortunately this is not the case: training takes practically the same amount of time as when I use all samples in the minibatch (i.e. when I do not adapt the loss). I am using a 'densenet121' network, and training is done on CIFAR-100.

Am I doing something wrong? Should I disable autograd for some samples in the minibatch manually? I thought the 'topk' function would do that automatically.

import torch

def get_adapted_loss_for_minibatch(loss):
    # Returns the loss containing only the samples of the mini-batch with the _lowest_ loss.
    # Parameter 'loss' must be a vector containing the per-sample loss for all samples in the (original) minibatch.
    minibatch_size = loss.size(0)
    r = 0.6 * minibatch_size
    # Round r to an integer, with a safeguard in case r is 0
    r = max(round(r), 1)
    # 'topk' with largest=False returns the loss for the 'r' samples with the _lowest_ loss in the minibatch
    # See the documentation at https://pytorch.org/docs/stable/generated/torch.topk.html
    # Note that the 'topk' operation is differentiable, see https://stackoverflow.com/questions/67570529/derive-the-gradient-through-torch-topk
    # and https://math.stackexchange.com/questions/4146359/derivative-for-masked-matrix-hadamard-multiplication
    loss_adapted = torch.topk(loss, r, largest=False, sorted=False, dim=0)[0]
    return loss_adapted
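
For reference, here is a simplified sketch of the training step this function plugs into (`model` and `optimizer` are placeholders for my actual objects; the criterion is `torch.nn.CrossEntropyLoss(reduction='none')` so that it yields a per-sample loss vector):

criterion_no_reduce = torch.nn.CrossEntropyLoss(reduction='none')

def training_step(model, optimizer, x, y):
    optimizer.zero_grad()
    # Forward pass on the full minibatch; per-sample loss of shape (minibatch_size,)
    logits = model(x)
    loss = criterion_no_reduce(logits, y)
    # Keep only the ~60% lowest-loss samples, then reduce and backpropagate
    loss_adapted = get_adapted_loss_for_minibatch(loss)
    loss_adapted.mean().backward()
    optimizer.step()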
user2454869
  • Could you provide the type of loss that would precede the `get_adapted_loss_for_minibatch` call as well as the actual shape of the provided `loss` argument? Also what is your batch size? – Ivan Aug 25 '21 at 10:28
  • Of course. The (mini-)batch size is 128. I construct the loss function via ```self.criterion_no_reduce = torch.nn.CrossEntropyLoss(reduction='none') ``` and calculate the original loss via ```loss = self.criterion_no_reduce(logits, y)```. The shape of the provided loss is (128) and the shape of the returned loss is (77) – user2454869 Aug 25 '21 at 14:54
  • Ok the `reduction='none'` is important to note – Ivan Aug 25 '21 at 14:57
  • I first thought it was because of other parts in the training loop which are time-consuming, but the training is running with 80% GPU load/utilization, so most of the runtime actually comes from the forward/backward pass... Training is done on the GPU (using cuDNN), on an NVIDIA Quadro RTX 6000 on Ubuntu. PyTorch 1.9, CUDA Toolkit 11 – user2454869 Aug 25 '21 at 15:01
  • Do you have batch normalization in your model? – Ivan Aug 25 '21 at 15:13
  • Yes, the densenet model I use has a batch-normalization layer. See https://github.com/juntang-zhuang/Adabelief-Optimizer/blob/update_0.2.0/PyTorch_Experiments/classification_cifar10/models/densenet.py – user2454869 Aug 25 '21 at 15:29

1 Answer


The reason you see no difference in training speed is that you are using batch normalization. Because of it, your gradient still depends on the entire batch, even if you only use part of the batch's content to compute the final loss term and backpropagate.

Mathematically speaking, the batch statistics computed in each batch normalization layer during training involve all elements of the batch.

Take the mean computation as an example (batch norm of course also involves a standard-deviation measurement): intuitively, when you normalize a given vector by its average, every element of the resulting vector depends on all elements of the initial vector, since all of them were used to compute that average.

If you want to read more about this, see this post about backpropagating through x / x.mean(0).
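
To make this concrete, here is a small illustration (not part of your training code): send a batch through a single BatchNorm2d layer in training mode, build a loss from only the first half of the samples, and the gradient w.r.t. the input is nevertheless non-zero for every sample:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)                             # training mode: uses batch statistics
x = torch.randn(8, 3, 4, 4, requires_grad=True)
out = bn(x)

# Build a loss from only the first 4 samples of the batch ...
loss = out[:4].sum()
loss.backward()

# ... yet every sample ends up with a non-zero gradient, because the batch mean
# and variance used for the normalization depend on all 8 samples.
print(x.grad.abs().sum(dim=(1, 2, 3)))             # all 8 entries are > 0

If you swap bn for a per-sample normalization such as nn.GroupNorm(1, 3), or remove the normalization entirely, the gradient of the last four samples is exactly zero.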


Following our discussion, here is a way to replace your BatchNorm2d layers with GroupNorm. Go through the sub-modules of your network and replace all instances of BatchNorm2d with newly initialized instances of GroupNorm. Here is an example with Bottleneck:

>>> net = Bottleneck(10, 2)

>>> for name, module in net.named_children():
...   if isinstance(module, nn.BatchNorm2d):
...     setattr(net, name,
...        nn.GroupNorm(num_channels=module.num_features, num_groups=2))

This way net will look like:

>>> net
Bottleneck(
  (bn1): GroupNorm(2, 10, eps=1e-05, affine=True)
  (conv1): Conv2d(10, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn2): GroupNorm(2, 8, eps=1e-05, affine=True)
  (conv2): Conv2d(8, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
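
Note that named_children only iterates over the direct sub-modules. For the full DenseNet from your link, where the BatchNorm2d layers are nested inside the dense blocks, you can apply the same replacement recursively; a sketch along the same lines (choose num_groups so that it divides every layer's num_features):

>>> def replace_bn_with_gn(module, num_groups):
...   for name, child in module.named_children():
...     if isinstance(child, nn.BatchNorm2d):
...       setattr(module, name,
...          nn.GroupNorm(num_channels=child.num_features, num_groups=num_groups))
...     else:
...       replace_bn_with_gn(child, num_groups)

>>> replace_bn_with_gn(net, num_groups=2)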
Ivan
  • Many thanks for your answer ! Is it possible to adapt either the loss or the (batch-norm layers in the) model so that the batch-normalization is done only from the "top-K" samples in the batch with the lowest loss ? I do not want to get rid of the batch-normalization layers in DenseNet completely, as I suppose without it it would perform much worse. – user2454869 Aug 25 '21 at 15:56
  • You can only infer the top-k samples after having computed the forward pass on your model. So no, you won't know the top-k at the point in time when you are computing the statistics in your batchnorms. You could, however, compute some forward passes with the full batch, find out the top-k, then exclude the 'non' top-k samples from the following forward passes... – Ivan Aug 25 '21 at 16:01
  • So to reap runtime benefits and still have some kind of normalization in my model, I might have to replace the batch-norm layer with e.g. "group normalization" (https://arxiv.org/pdf/1803.08494.pdf) or another alternative (https://analyticsindiamag.com/alternatives-batch-normalization-deep-learning/, https://towardsdatascience.com/an-alternative-to-batch-normalization-2cee9051e8bc) ... – user2454869 Aug 25 '21 at 16:05
  • That would be worth a shot; at least, any normalization technique that does not involve the whole batch should yield speed improvements. – Ivan Aug 25 '21 at 16:08
  • Is there an 'easy' way to replace all the 'BatchNorm2d' layers in Densenet implementation (https://github.com/juntang-zhuang/Adabelief-Optimizer/blob/update_0.2.0/PyTorch_Experiments/classification_cifar10/models/densenet.py) with a Pytorch 'GroupNorm' layer ? – user2454869 Aug 25 '21 at 16:16
  • Update: See also https://github.com/pytorch/pytorch/issues/74604 – user2454869 Mar 29 '22 at 14:20