
I wrote a class that has __init__, __getitem__ and __len__ methods.
After making __getitem__ more complex, I noticed that iterating over the object doesn't know when to stop (although __len__ works as intended).

So what does the default __iter__ actually do?

My code, for example (although it shouldn't matter much):

class SubsetDataset(Dataset):
    def __init__(self, source_dataset: Dataset, desired_classes: list):
        self.source_dataset = source_dataset
        self.desired_classes = desired_classes
        self.index_to_sub = dict()
        i = 0
        for j, d in enumerate(self.source_dataset):
            sample, label = d
            if label in self.desired_classes:
                self.index_to_sub[i] = j
                i += 1

    def __getitem__(self, index):
        sample, label = self.source_dataset[self.index_to_sub[index]]
        return sample, label     

    def __len__(self):
        return len(self.index_to_sub)
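For reference, here is a minimal standalone sketch of the fallback Python uses when a class defines __getitem__ but no __iter__ (plain Python, no PyTorch; Squares and DictBacked are hypothetical classes for illustration). iter() then calls __getitem__(0), __getitem__(1), __getitem__(2) and so on, and stops only when IndexError is raised. __len__ is never used as the stopping condition. A dict lookup such as self.index_to_sub[index] in the __getitem__ above raises KeyError for a missing index, and that is not treated as the end of iteration.

```python
class Squares:
    """Hypothetical class: defines __getitem__ and __len__, but no __iter__."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        if index >= self.n:
            # IndexError is the signal the fallback iterator stops on.
            raise IndexError(index)
        return index * index


class DictBacked:
    """Hypothetical class whose __getitem__ does a dict lookup, like index_to_sub."""

    def __init__(self, mapping):
        self.mapping = mapping

    def __len__(self):
        return len(self.mapping)

    def __getitem__(self, index):
        # A missing key raises KeyError, which does NOT end iteration.
        return self.mapping[index]


print(list(Squares(4)))  # [0, 1, 4, 9]

try:
    list(DictBacked({0: "a", 1: "b"}))
except KeyError as exc:
    # Iteration asked for index 2 even though len() == 2, and the KeyError escaped.
    print("KeyError:", exc)  # KeyError: 2
```

So raising IndexError from __getitem__ once the index is out of range is enough for a plain for loop to stop, without a separate __iter__.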

My __iter__ and __next__ methods that solved the problem (though I would rather not have had to implement them):

    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n >= len(self):
            raise StopIteration
        item = self.source_dataset[self.index_to_sub[self.n]]
        self.n += 1
        return item
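
An alternative sketch follows. It has the same structure as SubsetDataset but uses a plain list of (sample, label) pairs as a stand-in source so it runs without PyTorch (ListSubset is a hypothetical name). __iter__ is written as a generator bounded by len(self), so each loop gets its own position instead of sharing self.n. __getitem__ also translates the missing dict key into IndexError, so the default fallback would stop as well.

```python
class ListSubset:
    """Hypothetical stand-in for SubsetDataset over a list of (sample, label) pairs."""

    def __init__(self, source, desired_classes):
        self.source = source
        self.desired_classes = desired_classes
        self.index_to_sub = {}
        i = 0
        for j, (_sample, label) in enumerate(source):
            if label in desired_classes:
                self.index_to_sub[i] = j
                i += 1

    def __len__(self):
        return len(self.index_to_sub)

    def __getitem__(self, index):
        try:
            j = self.index_to_sub[index]
        except KeyError:
            # Report out-of-range indices the way sequences do.
            raise IndexError(index) from None
        return self.source[j]

    def __iter__(self):
        # A generator: every call to iter() gets a fresh, independent position.
        for index in range(len(self)):
            yield self[index]


data = [("x0", 0), ("x1", 1), ("x2", 0), ("x3", 2)]
subset = ListSubset(data, desired_classes=[0])
print(len(subset))   # 2
print(list(subset))  # [('x0', 0), ('x2', 0)]
```

With __iter__ returning self and a shared self.n, starting an inner loop over the same object resets the counter, so the outer loop ends after its first item. The generator version does not have that problem.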
Mogi