I wrote a class which has `__init__`, `__getitem__` and `__len__` methods. After making the `__getitem__` method more complex, I noticed that iterating over the object doesn't know when to stop (although `__len__` works as intended). So what is the default `__iter__` doing, then?
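For reference, here is a minimal sketch of the fallback, using hypothetical class names. When a class defines no `__iter__`, `iter()` falls back to the older sequence protocol. It calls `__getitem__` with 0, 1, 2, ... and stops when a call raises `IndexError`. `__len__` is not used to decide where to stop, as the second class shows:

```python
class Squares:
    """Defines only __getitem__ and __len__, no __iter__."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Not consulted to decide when iteration ends.
        return self.n

    def __getitem__(self, index):
        # The fallback iterator calls this with 0, 1, 2, ... and treats
        # IndexError as the end of the sequence.
        if index >= self.n:
            raise IndexError(index)
        return index * index


class LyingLength(Squares):
    """Reports a length of 1 but still yields three items."""

    def __len__(self):
        return 1


print(list(Squares(3)))      # [0, 1, 4]
print(list(LyingLength(3)))  # [0, 1, 4]: __len__ did not stop the loop
```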
My code, for example (although it shouldn't matter much):
```python
from torch.utils.data import Dataset  # assuming PyTorch's Dataset


class SubsetDataset(Dataset):
    def __init__(self, source_dataset: Dataset, desired_classes: list):
        self.source_dataset = source_dataset
        self.desired_classes = desired_classes
        # Maps an index in the subset to the matching index in source_dataset.
        self.index_to_sub = dict()
        i = 0
        for j, d in enumerate(self.source_dataset):
            sample, label = d
            if label in self.desired_classes:
                self.index_to_sub[i] = j
                i += 1

    def __getitem__(self, index):
        sample, label = self.source_dataset[self.index_to_sub[index]]
        return sample, label

    def __len__(self):
        return len(self.index_to_sub)
```
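Because `index_to_sub` is a dict, the lookup `self.index_to_sub[index]` in the `__getitem__` above raises `KeyError` for an index past the end, not `IndexError`. The fallback iterator does not treat `KeyError` as the end of the sequence. One possible fix is to have `__getitem__` translate that case into `IndexError`. The sketch below uses `SubsetLike`, a hypothetical stdlib-only stand-in for the PyTorch class, so that it runs on its own:

```python
class SubsetLike:
    """Hypothetical stand-in for SubsetDataset, without PyTorch."""

    def __init__(self, source, desired_classes):
        self.source = source
        # Subset position -> index in source, like index_to_sub above.
        self.index_to_sub = {}
        for j, (_, label) in enumerate(source):
            if label in desired_classes:
                self.index_to_sub[len(self.index_to_sub)] = j

    def __len__(self):
        return len(self.index_to_sub)

    def __getitem__(self, index):
        try:
            j = self.index_to_sub[index]
        except KeyError:
            # IndexError is what the fallback iteration protocol treats as "end".
            raise IndexError(index) from None
        return self.source[j]


data = [("x0", 0), ("x1", 1), ("x2", 0), ("x3", 2)]
subset = SubsetLike(data, desired_classes=[0])
print(list(subset))  # [('x0', 0), ('x2', 0)]
```

Storing the mapping in a list instead (appending each `j`) should also work, because list indexing already raises `IndexError` past the end.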
My `__iter__` and `__next__` methods, which solved the problem (but which I didn't want to have to implement):
```python
def __iter__(self):
    self.n = 0
    return self

def __next__(self):
    if self.n >= len(self):
        raise StopIteration
    # Named "item" rather than "next" to avoid shadowing the built-in next().
    item = self.source_dataset[self.index_to_sub[self.n]]
    self.n += 1
    return item
```
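The `__iter__`/`__next__` pair above makes the object its own iterator, so every loop shares the single position `self.n`. As a result, nested loops over the same object interfere with each other. A common alternative is to write `__iter__` as a generator function, which gives every loop its own independent iterator. The sketch below compares the two approaches on hypothetical minimal classes:

```python
class _Base:
    """Hypothetical minimal container shared by both variants."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


class SelfIterating(_Base):
    """Mirrors the pattern above: the object is its own iterator."""

    def __iter__(self):
        self.n = 0  # shared position, reset by every new loop
        return self

    def __next__(self):
        if self.n >= len(self):
            raise StopIteration
        item = self.items[self.n]
        self.n += 1
        return item


class GeneratorIterating(_Base):
    """Same container, but __iter__ is a generator function."""

    def __iter__(self):
        # Each call creates a new generator with its own loop variable,
        # so concurrent loops do not share state.
        for i in range(len(self)):
            yield self[i]


a = SelfIterating("ab")
b = GeneratorIterating("ab")
print([(x, y) for x in a for y in a])  # [('a', 'a'), ('a', 'b')]
print([(x, y) for x in b for y in b])  # [('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]
```

In the first result, the inner loop reset and then exhausted the shared `self.n`, which cut the outer loop short.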