I use TensorFlow, but I'm writing documentation for users who work across a variety of deep learning frameworks.
When working with datasets that don't fit on the local filesystem (TB+), I sample data from a remote data store and write the samples locally in TensorFlow's standard tfrecords format.
During the first epoch of training I will have sampled only a few values, so an epoch of local data is very small; I train on it anyway. At the start of epoch 2 I re-examine which data files my sampling subprocesses have produced (there are more by now) and train on the expanded set of local data files. I repeat this process each epoch. In this way I build up a local cache of samples, and I can evict older samples as local storage fills up. The local sample cache grows at about the time the model needs the variance the most (towards the latter part of training).
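To make the per-epoch rescan and eviction concrete, here is a minimal sketch of how it could look. It is not my actual implementation: the function names, the `.tfrecords` suffix, and the use of file modification time as the eviction order are all assumptions for illustration. It also assumes the sampling subprocesses write to a temporary name and rename the file once it is complete, so that a file with the final suffix is always safe to read.

```python
from pathlib import Path


def list_epoch_files(cache_dir, suffix=".tfrecords"):
    """Return the completed sample files in the cache, oldest first.

    Assumption: samplers rename a file to its final suffix only when it
    is fully written, so partial files are never picked up here.
    """
    files = [p for p in Path(cache_dir).iterdir() if p.suffix == suffix]
    # Sort by modification time, with the name as a tie-breaker.
    return sorted(files, key=lambda p: (p.stat().st_mtime, p.name))


def evict_oldest(cache_dir, max_bytes, suffix=".tfrecords"):
    """Delete the oldest files until the cache size is at most max_bytes."""
    files = list_epoch_files(cache_dir, suffix)
    total = sum(p.stat().st_size for p in files)
    evicted = []
    for path in files:
        if total <= max_bytes:
            break
        total -= path.stat().st_size
        path.unlink()
        evicted.append(path)
    return evicted
```

At the start of each epoch the training loop would call `evict_oldest(...)` and then `list_epoch_files(...)`, and build that epoch's dataset from the returned file list.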
In Python/TensorFlow it's crucial that I not deserialize the data in the Python training loop process: the Python GIL can't support the data transfer rates (300-600 MB/sec; the data is raw, incompressible scientific data), and GPU performance suffers when the training loop can't be serviced fast enough.
Writing the samples to a tfrecords file from subprocesses (Python multiprocessing) allows TensorFlow's native TFRecordDataset to do the deserialization outside of Python. This sidesteps the Python GIL issues, and I can saturate a GPU with high IO data rates.
I would like to know how I would address this issue in PyTorch. I'm writing about the sampling strategy being used and want to provide specific recommendations to users of both TensorFlow and PyTorch, but I don't know the PyTorch preprocessing ecosystem well enough to write in sufficient detail.
Side note: the only purely Python-based solution that supports these data transfer rates may come in Python 3.8 with System V shared memory and multiprocessing, but I haven't tried it yet because support for it isn't quite sufficient (it soon will be). Existing multiprocessing solutions aren't sufficient because they require deserialization in the training loop process and thus hold the GIL during deserialization at high IO rates.
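For reference, the Python 3.8 feature I have in mind is the `multiprocessing.shared_memory` module. The sketch below is untested at the data rates above and only shows the mechanism: a producer copies raw bytes into a named shared block, and a consumer attaches by name and reads them without any pickling. The helper names `publish_sample` and `read_sample` are hypothetical, and a real pipeline would also need a way to pass block names and sizes between processes (for example a queue), which is not shown.

```python
from multiprocessing import shared_memory


def publish_sample(payload: bytes) -> shared_memory.SharedMemory:
    """Producer side: copy raw sample bytes into a new shared memory block.

    The payload must be non-empty; the caller owns the returned block and
    is responsible for close() and unlink().
    """
    shm = shared_memory.SharedMemory(create=True, size=len(payload))
    shm.buf[:len(payload)] = payload
    return shm


def read_sample(name: str, nbytes: int) -> bytes:
    """Consumer side: attach to an existing block by name and read it.

    nbytes is passed explicitly because the block may be rounded up to a
    larger size than requested on some platforms.
    """
    shm = shared_memory.SharedMemory(name=name)
    try:
        # bytes() copies, which keeps this demo simple; a real consumer
        # could wrap shm.buf in an array view instead of copying.
        return bytes(shm.buf[:nbytes])
    finally:
        shm.close()
```

In the training process, the copy in `read_sample` could be avoided by constructing an array directly over `shm.buf` (for example with `numpy.ndarray(..., buffer=shm.buf)`), at the cost of keeping the block attached for as long as the array is in use.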