
Problem

I am training a deep learning model in PyTorch for binary classification, and I have a dataset containing unbalanced class proportions. My minority class makes up about 10% of the given observations. To avoid the model learning to just predict the majority class, I want to use the WeightedRandomSampler from torch.utils.data in my DataLoader.

Let's say I have 1000 observations (900 in class 0, 100 in class 1), and a batch size of 100 for my dataloader.

Without weighted random sampling, I would expect each training epoch to consist of 10 batches.

Questions

  • Will only 10 batches be sampled per epoch when using this sampler - and consequently, would the model 'miss' a large portion of the majority class during each epoch, since the minority class is now overrepresented in the training batches?
  • Will using the sampler result in more than 10 batches being sampled per epoch (meaning the same minority class observations may appear many times, and also that training would slow down)?
  • Does this answer your question? [Pytorch - how to undersample using weightedrandomsampler](https://stackoverflow.com/questions/60320232/pytorch-how-to-undersample-using-weightedrandomsampler) – iacob Jun 02 '21 at 09:31
  • Also related: https://stackoverflow.com/questions/62878940/how-to-create-a-balancing-cycling-iterator-in-pytourch – iacob Jun 02 '21 at 09:33

2 Answers


A small snippet showing how to use WeightedRandomSampler.

First, define a function that gives each sample a weight inversely proportional to the size of its class:

def make_weights_for_balanced_classes(images, nclasses):
    # images: list of (path, class_index) pairs, e.g. ImageFolder.imgs
    n_images = len(images)
    count_per_class = [0] * nclasses
    for _, image_class in images:
        count_per_class[image_class] += 1
    # weight for class c = N / count[c], so every class carries the same total weight
    weight_per_class = [0.] * nclasses
    for i in range(nclasses):
        weight_per_class[i] = float(n_images) / float(count_per_class[i])
    # each sample inherits the weight of its class
    weights = [0] * n_images
    for idx, (image, image_class) in enumerate(images):
        weights[idx] = weight_per_class[image_class]
    return weights
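The same per-sample weights can also be computed without Python loops. The sketch below assumes the class labels are already available as a 1-D integer tensor; the names `balanced_weights` and `targets` are hypothetical and not part of the original answer. It uses `torch.bincount` to count samples per class and then indexes the per-class weights with each sample's label, which yields the same N / count[c] values as the function above.

```python
import torch


def balanced_weights(targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-sample weights equal to N / count[class].

    Equivalent to make_weights_for_balanced_classes, but vectorized.
    """
    counts = torch.bincount(targets)                # number of samples in each class
    class_weights = len(targets) / counts.double()  # N / count[c] for every class c
    return class_weights[targets]                   # look up each sample's class weight


# Toy labels matching the question: 900 samples of class 0, 100 of class 1
targets = torch.tensor([0] * 900 + [1] * 100)
w = balanced_weights(targets)
print(w[0].item(), w[-1].item())  # 1000/900 for class 0, 1000/100 for class 1
```

Because each class's weights sum to N, the sampler picks either class with roughly equal probability on every draw.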

And after this, use it in the following way:

import torch
from torchvision import datasets

dataset_train = datasets.ImageFolder(traindir)  # traindir: path to your training images

# For an unbalanced dataset, create a weighted sampler
weights = make_weights_for_balanced_classes(dataset_train.imgs, len(dataset_train.classes))
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

# Do not pass shuffle=True together with a sampler: DataLoader raises ValueError,
# and the sampler already randomizes the order. args: your parsed command-line options.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size,
                                           sampler=sampler, num_workers=args.workers, pin_memory=True)
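Since `ImageFolder` needs image files on disk, here is a self-contained sketch of the same wiring that can be run directly. It assumes a toy `TensorDataset` with random features in place of `ImageFolder` (an illustrative substitution, not the answer's data), computes the inverse-frequency weights inline, and confirms that `DataLoader` rejects `shuffle=True` when a sampler is also given.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy stand-in for ImageFolder (assumption): 900 class-0 rows, 100 class-1 rows
labels = torch.tensor([0] * 900 + [1] * 100)
dataset = TensorDataset(torch.randn(1000, 8), labels)

# Inverse-frequency weights, same formula as make_weights_for_balanced_classes
weights = (len(labels) / torch.bincount(labels).double())[labels]
sampler = WeightedRandomSampler(weights, num_samples=len(weights))

# The sampler decides the order, so shuffle stays at its default (False)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

try:
    DataLoader(dataset, batch_size=100, shuffle=True, sampler=sampler)
    shuffle_conflict = False
except ValueError:
    shuffle_conflict = True  # sampler and shuffle=True are mutually exclusive

print(len(loader), shuffle_conflict)
```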

It depends on what you're after; see the torch.utils.data.WeightedRandomSampler documentation for details.

The num_samples argument specifies how many indices the sampler draws per epoch, i.e. how many samples the torch.utils.data.DataLoader yields from your Dataset in one pass (assuming you weighted the samples correctly):

  • If you set it to len(dataset) (1000), you get the first case: 10 batches of 100 per epoch
  • If you set it to 1800 (in your case), you get the second case: 18 batches of 100 per epoch
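A minimal sketch of how num_samples alone controls the epoch length, assuming the same toy 900/100 dataset as in the question (random features, batch size 100). The value 1800 is 2 × 900, so that on average the minority class is drawn about as often as the full majority count. `batches_per_epoch` is a hypothetical helper for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy data from the question (assumption): 900 class-0 and 100 class-1 samples
labels = torch.tensor([0] * 900 + [1] * 100)
dataset = TensorDataset(torch.randn(1000, 8), labels)
weights = (len(labels) / torch.bincount(labels).double())[labels]


def batches_per_epoch(num_samples: int) -> int:
    """Hypothetical helper: number of batches one epoch yields for a given num_samples."""
    sampler = WeightedRandomSampler(weights, num_samples=num_samples, replacement=True)
    loader = DataLoader(dataset, batch_size=100, sampler=sampler)
    return len(loader)  # ceil(num_samples / batch_size)


print(batches_per_epoch(len(dataset)))  # first case: 1000 draws
print(batches_per_epoch(1800))          # second case: 1800 draws
```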

Will only 10 batches be sampled per epoch when using this sampler - and consequently, would the model 'miss' a large portion of the majority class during each epoch [...]

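Yes, within a single epoch. However, the sampler draws a fresh set of indices every epoch, so across several epochs the model still sees most of the majority class.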
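To illustrate that each epoch draws new indices, the sketch below iterates the sampler for five epochs with num_samples=len(dataset) and tracks what fraction of the 900 majority samples has been seen so far. It assumes the toy 900/100 labels from the question; the fixed generator seed is only there to make the run reproducible, and the exact fractions will vary with the seed.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Toy labels from the question (assumption): 900 majority, 100 minority
labels = torch.tensor([0] * 900 + [1] * 100)
weights = (len(labels) / torch.bincount(labels).double())[labels]

g = torch.Generator().manual_seed(0)  # fixed seed only for reproducibility
sampler = WeightedRandomSampler(weights, num_samples=len(labels), generator=g)

seen_majority = set()
coverage = []
for epoch in range(5):
    # every pass over the sampler draws a new set of indices
    for idx in sampler:
        if labels[idx] == 0:
            seen_majority.add(idx)
    coverage.append(len(seen_majority) / 900)

print([round(c, 2) for c in coverage])  # fraction of majority samples seen so far
```

After one epoch only part of the majority class has been drawn, but the cumulative coverage keeps rising with each additional epoch.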

Will using the sampler result in more than 10 batches being sampled per epoch (meaning the same minority class observations may appear many times, and also that training would slow down)?

Training would not slow down overall: each epoch would take longer, but convergence should take roughly the same time, since fewer epochs are needed when each one contains more data.

    "If you set it to len(dataset) you will get the first case." Meaning I will be oversampling my minority class? – Joana Rocha Apr 26 '22 at 12:33
    @JoanaRocha yes, you would be in both cases. In the first one, you would draw 1000 samples and ~500 would be a minority (oversampled 5 times) and the majority would be ~500 (under-sampled a little bit). In the second case, you would probably get 900 from the majority (not under-sampled at all) and 900 from the minority (oversampled 9 times). Multiply it by many epochs and, on average, you will see every majority sample and most minority samples multiple times. – Szymon Maszke Apr 26 '22 at 12:48
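The class counts in the comment above can be checked empirically. This sketch assumes the toy 900/100 labels from the question and a fixed seed for reproducibility; it counts how many draws land in each class for num_samples=1000 and num_samples=1800. The results are random, so expect values near, not exactly equal to, 500/500 and 900/900.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Toy labels from the question (assumption): 900 majority, 100 minority
labels = torch.tensor([0] * 900 + [1] * 100)
weights = (len(labels) / torch.bincount(labels).double())[labels]
g = torch.Generator().manual_seed(0)  # fixed seed only for reproducibility

counts = {}
for num_samples in (1000, 1800):
    sampler = WeightedRandomSampler(weights, num_samples=num_samples, generator=g)
    drawn = labels[torch.tensor(list(sampler))]  # class label of every drawn index
    counts[num_samples] = (int((drawn == 0).sum()), int((drawn == 1).sum()))
    print(num_samples, counts[num_samples])  # (majority draws, minority draws)
```

These are numbers of draws, not distinct items: sampling is with replacement by default, so some majority samples repeat within an epoch while others are skipped.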