
I'm creating my own Dataset to feed into a PyTorch DataLoader, and I have designed a way to inherit from a class programmatically: I extend a class that will be used as a Dataset so I can add 'custom' functionality to it. The dynamic extension works nicely. However, PyTorch doesn't like it: when I start iterating over a DataLoader built on the extended class, it raises an error.

Here is a toy example:

```python
# utils.py: mock dataset. This has to be in a separate file for some reason.
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second
```

```python
# Main script
from torch.utils.data import DataLoader

from utils import MyDataset


def extend_class(base_class):
    # B is created inside the function, so its qualified name is 'extend_class.<locals>.B'
    class B(base_class):
        def hello(self):
            print('Yo!')

    return B


if __name__ == '__main__':
    a = MyDataset()
    dataloader = DataLoader(a, batch_size=4, shuffle=True, num_workers=1)

    iterator = iter(dataloader)
    first, second = next(iterator)  # this works OK

    extended_class = extend_class(MyDataset)

    b = extended_class()
    b.hello()  # this works!

    dataloader = DataLoader(b, batch_size=4, shuffle=True, num_workers=1)

    # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'
    iterator = iter(dataloader)

    first, second = next(iterator)
```
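The failure can be reproduced without PyTorch. The traceback shows the dataset being pickled when the worker processes start (this happens with the `spawn` and `forkserver` start methods; with `fork` the dataset is inherited instead). `pickle` stores a class by reference, i.e. its `__module__` plus `__qualname__`, so a class created inside a function, whose qualified name contains `<locals>`, cannot be looked up again. A minimal stdlib-only sketch, assuming `MyDataset` is reduced to a plain class:

```python
import pickle


class MyDataset:
    """Plain stand-in for the question's torch Dataset subclass."""

    def __init__(self):
        self.first = 1
        self.second = 2


def extend_class(base_class):
    # Defined inside a function, so its qualified name contains '<locals>'
    class B(base_class):
        def hello(self):
            print('Yo!')

    return B


b = extend_class(MyDataset)()
print(type(b).__qualname__)  # extend_class.<locals>.B

# A module-level class pickles fine...
pickle.dumps(MyDataset())

# ...but the locally defined one cannot be found again by name.
try:
    pickle.dumps(b)
except (AttributeError, pickle.PicklingError) as exc:
    # The exception type has varied across Python versions
    print(type(exc).__name__, exc)
```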

Any workaround to do this is appreciated!
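If the class really has to come out of a factory, one possible workaround (a sketch under assumptions, not something confirmed in this thread) is to make the generated class findable by name: call the factory at module scope, bind the result to a module-level name, and set the class's `__qualname__` to that same name. `extend_class(base_class, name)` is a hypothetical variant of the question's factory, and `MyDataset` is a plain stand-in for the torch dataset.

```python
import pickle


class MyDataset:
    """Plain stand-in for the question's torch Dataset subclass."""

    def __init__(self):
        self.first = 1
        self.second = 2

    def __getitem__(self, item):
        return self.first, self.second


def extend_class(base_class, name):
    """Hypothetical variant of the question's factory that also takes a class name."""
    class B(base_class):
        def hello(self):
            print('Yo!')

    # pickle finds a class via __module__ + __qualname__, so replace
    # 'extend_class.<locals>.B' with a plain top-level name.
    B.__name__ = name
    B.__qualname__ = name
    return B


# Assumption: bound in the same module that defines extend_class, because
# B.__module__ is that module's name. Binding at module scope (outside the
# __main__ guard) also means spawned workers that re-import this module
# recreate the same name.
ExtendedDataset = extend_class(MyDataset, 'ExtendedDataset')

if __name__ == '__main__':
    b = ExtendedDataset()
    restored = pickle.loads(pickle.dumps(b))
    print(type(restored).__name__, restored[0])  # ExtendedDataset (1, 2)
```

The trade-off is that every generated class needs its own module-level name chosen up front, so this only helps when the set of extensions is known at import time.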

Vaaal88
  • This is a limitation of `pickle` in general, see the linked dupe for some workarounds. Is there any particular reason for extending dynamically? As long as the class is visible in the module namespace you won't have issues to begin with. – tzaman Jun 26 '20 at 13:27
  • I have an abstract class Generator, and other classes that inherit from it, let's say GenA and GenB. Generator has some abstract methods, for example _get_labels_, and GenA and GenB implement their own versions of them. I would like to use this extended class to inject some special behaviour, overriding the standard behaviour of some of the abstract methods. – Vaaal88 Jun 26 '20 at 14:45
  • Is there any reason you can't do that with standard inheritance or mixins? If you define `class B` at module scope and do something like `class Extended(MyDataset, B)` it should just work. – tzaman Jun 26 '20 at 15:04
  • This is the approach I ended up with, but I don't really like it. The problem is that normally B would extend MyDataset, so it would often call `super().some_fun()` before doing other work. This also makes it simple to chain extensions so they are called one after the other. With the `class Extended(MyDataset, B)` approach I can't actually do that, as I couldn't execute both `MyDataset.some_fun()` and `B.some_fun()`, afaik. – Vaaal88 Jun 27 '20 at 04:44
  • You could flip the order around to `Extended(B, MyDataset)` and then have `B.some_fun()` also call `super().some_fun()` at the appropriate time. That way `Extended` will call `B` which will call `MyDataset`. – tzaman Jun 28 '20 at 21:38
  • Could you make this an answer? – Vaaal88 Jun 29 '20 at 07:41
  • That's okay, reopening is a hassle and the answer isn't really addressing the question as originally asked anymore anyway. Glad I could help, though! – tzaman Jun 29 '20 at 09:52
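tzaman's suggestion of putting the mixins first can be sketched as follows. This is an illustration under assumptions, not code from the thread: `MyDataset` is a plain stand-in for the torch `Dataset` subclass, `get_labels` is named after the method mentioned in the comments, and `ExtraLabelMixin` / `UpperLabelMixin` are hypothetical extensions. Because the dataset comes last in the base list, each mixin's `super().get_labels()` call goes to the next class in the method resolution order, so extensions chain one after the other and `MyDataset.get_labels()` is still reached at the end of the chain.

```python
import pickle


class MyDataset:
    """Plain stand-in for the question's torch Dataset subclass."""

    def get_labels(self):
        return ['base']


class ExtraLabelMixin:
    """Hypothetical extension: runs the next get_labels in the MRO, then appends."""

    def get_labels(self):
        return super().get_labels() + ['extra']


class UpperLabelMixin:
    """Hypothetical extension: post-processes whatever the next class returns."""

    def get_labels(self):
        return [label.upper() for label in super().get_labels()]


# Mixins first, dataset last:
# Extended -> UpperLabelMixin -> ExtraLabelMixin -> MyDataset -> object
class Extended(UpperLabelMixin, ExtraLabelMixin, MyDataset):
    pass


if __name__ == '__main__':
    e = Extended()
    print(e.get_labels())  # ['BASE', 'EXTRA']
    print([cls.__name__ for cls in Extended.__mro__])
    # Every class lives at module scope, so pickling by reference works
    restored = pickle.loads(pickle.dumps(e))
    print(restored.get_labels())  # ['BASE', 'EXTRA']
```

Since every class is defined at module scope, instances pickle by reference, so a DataLoader with `num_workers > 0` should not run into the `<locals>` error.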
