I have a huge dataset that does not fit in memory (150 GB) and I'm looking for the best way to work with it in PyTorch. The dataset is composed of several .npz files of 10k samples each. I tried to build a Dataset class:
    import os
    import numpy as np
    from torch.utils.data import Dataset

    class MyDataset(Dataset):
        def __init__(self, path):
            self.path = path
            self.files = os.listdir(self.path)
            self.file_length = {}
            for f in self.files:
                # Load file in as a memmap (or so I thought)
                d = np.load(os.path.join(self.path, f), mmap_mode='r')
                self.file_length[f] = len(d['y'])

        def __len__(self):
            # Total number of samples across all files
            return sum(self.file_length.values())

        def __getitem__(self, idx):
            # Find the file where idx belongs to
            count = 0
            f_key = ''
            local_idx = 0
            for k in self.file_length:
                # <= so the first sample of each file is reachable
                if count <= idx < count + self.file_length[k]:
                    f_key = k
                    local_idx = idx - count
                    break
                else:
                    count += self.file_length[k]
            # Open file as numpy.memmap
            d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
            # Actually fetch the data
            X = np.expand_dims(d['X'][local_idx], axis=1)
            y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
            return X, y
but when a sample is actually fetched, it takes more than 30s. It looks like the entire .npz file is opened, loaded into RAM, and only then is the right index accessed.
How can this be made more efficient?
EDIT
It appears to be a misunderstanding of how .npz files work, see this post, but is there a better approach?
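For reference, as far as I understand it, mmap_mode is simply ignored for the members of an .npz archive: accessing d['X'] reads (and, for compressed archives, decompresses) the whole array into memory. A minimal sketch illustrating the difference, with made-up file names:

    import numpy as np

    X = np.random.rand(1000, 64).astype(np.float32)
    np.savez('demo.npz', X=X)      # zip archive containing X.npy
    np.save('demo_X.npy', X)       # plain .npy file

    d = np.load('demo.npz', mmap_mode='r')
    print(type(d['X']))            # numpy.ndarray -- the whole member is read from the archive
    a = np.load('demo_X.npy', mmap_mode='r')
    print(type(a))                 # numpy.memmap -- only the slices you index are read from disk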
SOLUTION PROPOSAL
As proposed by @covariantmonkey, lmdb can be a good choice. For now, since the problem comes from the .npz format and not from memmap itself, I remodelled my dataset by splitting the .npz archives into several .npy files. I can now use the same logic, where memmap makes complete sense and is really fast (a few ms to load a sample).