In the context of creating my own Dataset to feed into a PyTorch DataLoader, I have designed a way to inherit from a class programmatically. In other words, I dynamically extend a class that is going to be used as a Dataset in order to add custom functionality to it. The dynamic extension itself works nicely. However, PyTorch doesn't like it: as soon as I start iterating over a DataLoader built on the extended dataset, it raises an error.
Here is a toy example. The mock dataset lives in `utils.py`:

```python
# utils.py: mock dataset. This has to be in a different file for some reason.
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second
```
```python
import pickle
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from utils import MyDataset


def extend_class(base_class):
    class B(base_class):
        def hello(self):
            print('Yo!')
    return B


if __name__ == '__main__':
    a = MyDataset()
    dataloader = DataLoader(a, batch_size=4, shuffle=True, num_workers=1)
    iterator = iter(dataloader)
    first, second = next(iterator)  # this works ok

    extended_class = extend_class(MyDataset)
    b = extended_class()
    b.hello()  # this works!
    dataloader = DataLoader(b, batch_size=4, shuffle=True, num_workers=1)
    iterator = iter(dataloader)  # error here: AttributeError: Can't pickle local object 'extend_class.<locals>.B'
    first, second = next(iterator)
```
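The same error can be reproduced without PyTorch. A likely explanation (an assumption based on the traceback) is that with `num_workers=1` the worker process is started with the `spawn` method (the default on Windows and macOS), so the DataLoader has to pickle the dataset to send it to the worker. Pickle stores classes by reference, i.e. as module plus qualified name, and a class created inside a function has a qualified name containing `<locals>`, which cannot be looked up again. The sketch below uses a plain stand-in for `MyDataset` so that it runs without torch:

```python
import pickle


class MyDataset:
    """Plain stand-in for the torch Dataset from the question (no torch needed)."""

    def __init__(self):
        self.first = 1
        self.second = 2


def extend_class(base_class):
    # B is created inside the function, so its __qualname__ is
    # 'extend_class.<locals>.B' and pickle cannot find it by name.
    class B(base_class):
        def hello(self):
            print('Yo!')
    return B


if __name__ == '__main__':
    pickle.dumps(MyDataset())  # module-level class: pickles fine
    b = extend_class(MyDataset)()
    try:
        pickle.dumps(b)  # local class: fails like the DataLoader does
    except (AttributeError, pickle.PicklingError) as err:
        print(type(err).__name__, err)
```

The exact exception type may differ between Python versions, which is why both are caught.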
Any workaround to do this is appreciated!
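One possible workaround, sketched under the assumption that the dataset's whole state lives in its instance `__dict__`: give the dynamic subclass a `__reduce__` method so that instances are pickled as "call a module-level function with the base class and the state" rather than as a reference to the local class. `_rebuild` is a hypothetical helper name, and `extend_class` is wrapped in `functools.lru_cache` so that the same base class always yields the same subclass. A plain stand-in for `MyDataset` is used again so the sketch runs without torch:

```python
import functools
import pickle


class MyDataset:
    """Plain stand-in for the torch Dataset from the question (no torch needed)."""

    def __init__(self):
        self.first = 1
        self.second = 2

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        return self.first, self.second


def _rebuild(base_class, state):
    # Hypothetical module-level helper, so pickle can find it by name.
    # Recreates the dynamic subclass and restores the instance state
    # without calling __init__ again.
    cls = extend_class(base_class)
    obj = cls.__new__(cls)
    obj.__dict__.update(state)
    return obj


@functools.lru_cache(maxsize=None)
def extend_class(base_class):
    # Cached: repeated calls with the same base return the same class.
    class B(base_class):
        def hello(self):
            print('Yo!')

        def __reduce__(self):
            # Pickle as _rebuild(base_class, state) instead of as a
            # reference to the unpicklable local class B.
            return (_rebuild, (base_class, self.__dict__))
    return B


if __name__ == '__main__':
    b = extend_class(MyDataset)()
    restored = pickle.loads(pickle.dumps(b))
    restored.hello()
    print(restored[0])
```

In the question's script, `_rebuild` and `extend_class` would need to be defined at module level (outside the `if __name__ == '__main__':` guard) so a spawned worker can import them, and the base class must itself be importable, as `utils.MyDataset` is. If worker processes are not needed, `num_workers=0` sidesteps pickling altogether.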