11

I have built a Dataset where I do various checks on the images I'm loading. I then pass this Dataset to a DataLoader.

In my Dataset class I return the sample as None if a picture fails my checks, and I have a custom collate_fn which removes all Nones from the retrieved batch and returns the remaining valid samples.

However, at this point the returned batch can be of varying size. Is there a way to tell the collate_fn to keep sourcing data until it reaches the desired batch size?

import torch
from torch.utils.data import DataLoader

class DataSet(torch.utils.data.Dataset):
    def __init__(self, csv_file, image_dir):
        # initialise dataset:
        # load csv file and image directory
        self.csv_file = csv_file
        self.image_dir = image_dir

    def __getitem__(self, idx):
        # load one sample;
        # if the image is too dark return None,
        # else return the image and its corresponding label
        ...

def my_collate(batch):
    # a full batch of size 4 looks like [(image, label), (image, label), (image, label), (image, label)],
    # but it could come back as e.g. G = [None, (image, label), (image, label), (image, label)]
    batch = list(filter(lambda x: x is not None, batch))  # gets rid of the Nones; for the example above len(G) becomes 3
    # I want len(G) = 4
    # so how do I sample another dataset entry?
    return torch.utils.data.dataloader.default_collate(batch)

dataset = DataSet(csv_file='../', image_dir='../../')

dataloader = DataLoader(dataset, batch_size=4,
                        shuffle=True, num_workers=1, collate_fn=my_collate)
Brian Formento

6 Answers

9

There are two workarounds for this problem; choose one:

Fast option: reuse samples from the original batch:

def my_collate(batch):
    len_batch = len(batch) # original batch length
    batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
    if len_batch > len(batch): # if there are samples missing just use existing members, doesn't work if you reject every sample in a batch
        diff = len_batch - len(batch)
        for i in range(diff):
            batch = batch + batch[:diff]
    return torch.utils.data.dataloader.default_collate(batch)

Better option: load another sample from the dataset at random:

def my_collate(batch):
    len_batch = len(batch) # original batch length
    batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
    if len_batch > len(batch): # source all the required samples from the original dataset at random
        diff = len_batch - len(batch)
        for i in range(diff):
            batch.append(dataset[np.random.randint(0, len(dataset))])

    return torch.utils.data.dataloader.default_collate(batch)
Brian Formento
  • How would you construct the dataloader's collate_fn argument so that dataset is in scope? – bw4sz May 14 '21 at 19:48
  • Thanks for the code! I think the "Better option" should also handle the case where the newly sampled examples are None, so something like a while loop should be there, I guess. – Artem S Sep 03 '21 at 19:24
  • https://stackoverflow.com/a/67583699/8878627 solves @ArtemS's problem – S P Sharan Feb 10 '22 at 14:04
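A minimal sketch of one way to get dataset into scope, as asked in the first comment (this is not from the original answer; the class name CollateSkipNone is hypothetical): wrap the collate logic in a small callable class that holds a reference to the dataset. An instance can be passed directly as collate_fn, and it also stays picklable if your platform spawns worker processes.

import random
from torch.utils.data.dataloader import default_collate

class CollateSkipNone:
    """Hypothetical callable collate: keeps a reference to the dataset so it
    can draw random replacements for samples that came back as None."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __call__(self, batch):
        target_len = len(batch)                      # desired batch size
        batch = [x for x in batch if x is not None]  # drop rejected samples
        while len(batch) < target_len:               # refill up to the target size
            candidate = self.dataset[random.randint(0, len(self.dataset) - 1)]
            if candidate is not None:                # replacements may also be None
                batch.append(candidate)
        return default_collate(batch)

# usage: loader = DataLoader(dataset, batch_size=4, collate_fn=CollateSkipNone(dataset))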
8

This worked for me; sometimes even the randomly drawn replacement samples are None, so the while loop keeps drawing until every missing slot is filled with a valid sample.

def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones

    if len_batch > len(batch):
        db_len = len(dataset)
        diff = len_batch - len(batch)
        # keep drawing random samples until every None has been replaced with a valid one
        while diff != 0:
            a = dataset[np.random.randint(0, db_len)]
            if a is None:
                continue
            batch.append(a)
            diff -= 1

    return torch.utils.data.dataloader.default_collate(batch)
Mavic More
5

[Edit] An updated version of the code snippet below can be found here: https://github.com/project-lighter/lighter/blob/main/lighter/utils/collate.py

Thanks to Brian Formento both for asking and for the ideas on how to solve it. As mentioned already, the Better option that replaces bad examples with new ones has two problems:

  1. The newly sampled examples could also be corrupted;
  2. The dataset wasn't in scope.

Here's a solution to both of them - issue 1 is solved with a recursive call, and issue 2 by creating a partial function of the collate function with dataset fixed in place.

import random
import torch


def collate_fn_replace_corrupted(batch, dataset):
    """Collate function that allows replacing corrupted examples in the
    dataloader. It expects the dataset to return `None` when an example
    is corrupted. The `None`s in the batch are replaced with other examples
    sampled at random.

    Args:
        batch (torch.Tensor): batch from the DataLoader.
        dataset (torch.utils.data.Dataset): dataset which the DataLoader is loading.
            Specify it with functools.partial and pass the resulting partial function
            that only requires the `batch` argument to DataLoader's `collate_fn` option.

    Returns:
        torch.Tensor: batch with new examples instead of corrupted ones.
    """
    # Idea from https://stackoverflow.com/a/57882783

    original_batch_len = len(batch)
    # Filter out all the Nones (corrupted examples)
    batch = list(filter(lambda x: x is not None, batch))
    filtered_batch_len = len(batch)
    # Num of corrupted examples
    diff = original_batch_len - filtered_batch_len
    if diff > 0:
        # Replace corrupted examples with other randomly sampled examples
        batch.extend([dataset[random.randint(0, len(dataset)-1)] for _ in range(diff)])
        # Recursive call to replace the replacements if they are corrupted
        return collate_fn_replace_corrupted(batch, dataset)
    # Finally, when the whole batch is fine, return it
    return torch.utils.data.dataloader.default_collate(batch)

However, you can't pass this straight to the DataLoader since a collate function should only have a single argument - batch. To achieve that, we make a partial function with the dataset specified, and pass the partial function to the DataLoader.

import functools
from torch.utils.data import DataLoader


collate_fn = functools.partial(collate_fn_replace_corrupted, dataset=dataset)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        pin_memory=pin_memory,
                        collate_fn=collate_fn)
Ibrahim Hadzic
4

For anyone who wishes to reject training examples on the fly: instead of using tricks in the dataloader's collate_fn, you can simply use an IterableDataset and write its __iter__ and __next__ methods as follows

def __iter__(self):
    return self
def __next__(self):
    # load the next non-None example
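A minimal generator-based sketch of that idea (the wrapper name ValidSamplesDataset and the shuffling logic are assumptions, not part of the original answer): wrap the existing map-style dataset and yield only the samples that pass the checks. One caveat: with num_workers > 0, each worker iterates over its own copy of the indices unless they are sharded with torch.utils.data.get_worker_info().

import torch
from torch.utils.data import IterableDataset, DataLoader

class ValidSamplesDataset(IterableDataset):
    """Hypothetical wrapper: iterates over a map-style dataset that may
    return None and yields only the valid samples."""

    def __init__(self, base_dataset, shuffle=True):
        self.base_dataset = base_dataset
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            indices = torch.randperm(len(self.base_dataset)).tolist()
        else:
            indices = list(range(len(self.base_dataset)))
        for idx in indices:
            sample = self.base_dataset[idx]
            if sample is not None:  # skip rejected images
                yield sample

# usage: loader = DataLoader(ValidSamplesDataset(dataset), batch_size=4)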
Yash Sharma
0

Why not just solve this inside the dataset class, in the __getitem__ method? Instead of returning None when the data is no good, you can just ask for a different random index recursively.

import numpy as np

class DataSet():
    def __getitem__(self, idx):
        sample = load_sample(idx)
        if is_no_good(sample):
            idx = np.random.randint(0, len(self))  # draw a new random index
            sample = self[idx]  # recurse until a good sample is found
        return sample

This way you don't have to deal with batches of different sizes.

-1

There is a bug in the Fast option: the loop appends batch[:diff] on every iteration, so the batch grows well beyond the original batch size. Below is the fixed version.

def my_collate(batch):
    len_batch = len(batch) # original batch length
    batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
    if len_batch > len(batch): # if there are samples missing just use existing members, doesn't work if you reject every sample in a batch
        diff = len_batch - len(batch)
        batch = batch + batch[:diff] # assume diff < len(batch)
    return torch.utils.data.dataloader.default_collate(batch)
SONG