
I am training image classification models in PyTorch and using its default data loader to load my training data. I have a very large training dataset, typically a couple thousand sample images per class. I've trained models with about 200k images total without issues in the past. However, I've found that when I have over a million images in total, the PyTorch data loader gets stuck.

I believe the code is hanging when I call datasets.ImageFolder(...). When I press Ctrl-C, this is consistently the output:

Traceback (most recent call last):
  File "main.py", line 412, in <module>
    main()
  File "main.py", line 122, in main
    run_training(args.group, args.num_classes)
  File "main.py", line 203, in run_training
    train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True)
  File "main.py", line 236, in create_dataloader
    dataset = datasets.ImageFolder(directory, trans)
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__
    is_valid_file=is_valid_file)
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__
    samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset
    for root, _, fnames in sorted(os.walk(d)):
  File "/usr/lib/python3.5/os.py", line 380, in walk
    is_dir = entry.is_dir()
KeyboardInterrupt

I thought there might be a deadlock somewhere; however, based on the stack output from Ctrl-C, it doesn't look like it's waiting on a lock. Then I thought the data loader was just slow because I was trying to load a lot more data. I let it run for about 2 days and it didn't make any progress, and in the last 2 hours of loading I checked that the RAM usage stayed the same. In the past I have also been able to load training datasets with over 200k images in less than a couple of hours. I also tried upgrading my GCP machine to 32 cores, 4 GPUs, and over 100 GB of RAM, but it seems that after a certain amount of memory is loaded the data loader just gets stuck.

I'm confused about how the data loader could be getting stuck while looping through the directory, and I'm still unsure whether it's stuck or just extremely slow. Is there some way I can change the PyTorch data loader so it can handle 1 million+ images for training? Any debugging suggestions are also appreciated!

Thank you!

swooders
  • Sounds like you might have a symlink to a folder which creates a nested loop and iterates over the same files forever. Try running os.walk(d, followlinks=True) manually and check whether any root + fname gets repeated. Also stop if you iterate over more than the number of images you have. – juvian Feb 11 '20 at 19:15
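
A minimal sketch of the check suggested in the comment above (the directory path and the expected image count are placeholders):

import os

d = "/path/to/train/class1"   # placeholder: one class subfolder
expected = 100000             # placeholder: roughly how many images this folder holds

seen = set()
count = 0
for root, _, fnames in os.walk(d, followlinks=True):
    for fname in fnames:
        # Resolve symlinks so the same physical file is detected even if it
        # shows up again under a different (nested) path
        real = os.path.realpath(os.path.join(root, fname))
        if real in seen:
            print("File visited twice (likely a symlink loop):", real)
        seen.add(real)
        count += 1
        if count > expected:
            raise SystemExit("Walked far more files than expected - probably looping")

print("Visited {} files ({} unique), no loop detected".format(count, len(seen)))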

1 Answer


It's not a problem with DataLoader; it's a problem with torchvision.datasets.ImageFolder and how it works (and why it works much, much worse the more data you have).

It hangs on this line, as indicated by your traceback:

for root, _, fnames in sorted(os.walk(d)): 

The source can be found here.

The underlying problem is that it keeps each path and its corresponding label in a giant list; see the code below (a few things removed for brevity):

def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    # Iterate over all subfolders which were found previously
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target) # Create path to this subfolder
        # Assuming it is a directory (which is usually the case)
        for root, _, fnames in sorted(os.walk(d, followlinks=True)):
            # Iterate over ALL files in this subdirectory
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                # Assuming it is correctly recognized as an image file
                item = (path, class_to_idx[target])
                # Add to the list of all images
                images.append(item)

    return images

Obviously, images will contain 1 million strings (quite lengthy as well) and a corresponding int for each class, which is definitely a lot; how much it hurts depends on your RAM and CPU.
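
For a rough sense of scale, here's a quick back-of-the-envelope sketch (the sample path is made up) of how much memory such a list of (path, class_index) tuples occupies:

import sys

# A typical (made-up) absolute path to one training image
sample_path = "/home/username/data/train/class1/some_fairly_long_image_name_0000001.png"
n_images = 10 ** 6

per_item = (
    sys.getsizeof(sample_path)           # the path string itself
    + sys.getsizeof((sample_path, 0))    # the (path, class_index) tuple
    + 8                                  # the list's pointer to that tuple
)
print("~{:.0f} MB for {} samples".format(n_images * per_item / 1e6, n_images))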

You can create your own datasets though (provided you rename your images beforehand) so that almost no memory will be occupied by the dataset itself.

Setup data structure

Your folder structure should look like this:

root
    class1
    class2
    class3
    ...

Use as many classes as you have/need.

Now each class folder should contain its images, named consecutively starting from 0:

class1
    0.png
    1.png
    2.png
    ...
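
If your images are not named this way yet, a small script can do the renaming. Below is only a sketch; the root path is a placeholder and it assumes every regular file inside a class folder is an image:

import pathlib

root = pathlib.Path("/path/to/root/with/images")  # placeholder
extension = "png"

for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
    files = sorted(p for p in class_dir.iterdir() if p.is_file())
    # First pass: move everything to temporary names so an existing
    # "0.png", "1.png", ... is never overwritten by the second pass
    for i, path in enumerate(files):
        path.rename(class_dir / "tmp_{}".format(i))
    # Second pass: give the files their final 0..n-1 names
    for i in range(len(files)):
        (class_dir / "tmp_{}".format(i)).rename(class_dir / "{}.{}".format(i, extension))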

Given that, you can move on to creating the datasets.

Create Datasets

The torch.utils.data.Dataset below uses PIL to open the images; you could do it in another way though:

import os
import pathlib

import torch
from PIL import Image


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, root: str, folder: str, klass: int, extension: str = "png"):
        self._data = pathlib.Path(root) / folder
        self.klass = klass
        self.extension = extension
        # Only calculate once how many files are in this folder
        # Could be passed as argument if you precalculate it somehow
        # e.g. ls | wc -l on Linux
        self._length = sum(1 for entry in os.listdir(self._data))

    def __len__(self):
        # No need to recalculate this value every time
        return self._length

    def __getitem__(self, index):
        # Image names always follow [0, n-1], so you can open them directly
        image = Image.open(self._data / "{}.{}".format(index, self.extension))
        # Return the image together with its label; in practice you would
        # usually also apply a transform (e.g. torchvision.transforms.ToTensor)
        # here so that DataLoader can collate the samples into batches
        return image, self.klass

Now you can create your datasets easily (folder structure assumed to be like the one above):

root = "/path/to/root/with/images"
dataset = (
    ImageDataset(root, "class0", 0)
    + ImageDataset(root, "class1", 1)
    + ImageDataset(root, "class2", 2)
)

You could add as many datasets with specified classes as you wish; do it in a loop or however you like.
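
For example, a loop-based construction might look like the sketch below (the "classN" folder names and the num_classes value are assumptions):

import functools
import operator

num_classes = 1000  # assumption: however many class folders you have

# Dataset.__add__ wraps datasets in torch.utils.data.ConcatDataset,
# so summing them in a loop yields one big dataset
dataset = functools.reduce(
    operator.add,
    (ImageDataset(root, "class{}".format(i), i) for i in range(num_classes)),
)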

Finally, use torch.utils.data.DataLoader as per usual, e.g.:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
Szymon Maszke
  • Thank you so much for the explanation, and for even writing out the code for ImageDataset! I was finally able to get the models to train on the large dataset, and the rest of my models are also able to load data much faster. Thank you! – swooders Feb 15 '20 at 05:11
  • This is an excellent answer. I can't help wondering... couldn't we just remove the 'sorted()' statements from make_dataset if we wanted to make ImageFolder faster? I can understand why you have to sort the class indices... but is it really important that the image files appear with the same indices every time we run the program? Especially since we're just going to shuffle them in a DataLoader in the vast majority of applications? – DM Relenzo Sep 28 '20 at 14:46
  • @DMRelenzo No, it wouldn't help in this case. Even if there were no sorting we'd still have to keep the paths in memory, which is problematic for large datasets. If we were to use generators we wouldn't have random access to samples (paths), which is needed by `DataLoader` (precisely, passing an index to `torch.utils.data.Dataset`'s `__getitem__`). The only feasible way I see is to structure your data so it can be accessed randomly by creating the correct paths to samples on the fly (as is shown in the example). – Szymon Maszke Sep 28 '20 at 15:09
  • @DMRelenzo You could also create an `IterableDataset`, but I think this approach is easier. BTW, yes, removing the sorting would speed it up, but it isn't a long-term solution to this problem. – Szymon Maszke Sep 28 '20 at 15:14
  • @Szymon Maszke But they're just strings. Even a million filenames shouldn't take up that much memory, should they? A research machine probably has at least 16 GB RAM, shouldn't that be enough? I would have thought it would be more about the sorting algorithm scaling worse-than-linearly as you add more filenames to sort. – DM Relenzo Sep 28 '20 at 17:14
  • @DMRelenzo Yes, you are right, sorting is the most expensive operation here and removing it would help tremendously __at the cost of reproducibility__, as `os.walk` returns files and directories in arbitrary order (see [this answer](https://stackoverflow.com/a/18282401/10886420)). Hence batches from `DataLoader` would never be guaranteed to be exactly the same, which would affect the results depending on when and where it was called. – Szymon Maszke Sep 29 '20 at 10:09
  • @SzymonMaszke I'm curious: why didn't you leverage datasets.ConcatDataset? I think it would perform as desired, given the presented code – stephenjfox Sep 15 '21 at 20:37
  • @stephenjfox I am, that's what + does in the case of Dataset (a lesser-known feature). Or maybe you meant something else? – Szymon Maszke Sep 16 '21 at 11:52
  • @SzymonMaszke I overlooked that line in datasets.py. My mistake. You've saved me a few dozen characters around my codebase h/t – stephenjfox Sep 16 '21 at 15:02