I want to create several subsets of the MNIST dataset provided in PyTorch. Each subset should contain different classes. Here is what I tried:

def split_MNIST(mnist_set, digits):
    dset = mnist_set
    classes = []
    indices = dset.targets == digits[0]
    classes.append(dset.classes[digits[0]])
    if len(digits) > 1:
        for digit in digits[1:]:
            idx = dset.targets == digit
            indices = indices + idx
            classes.append(dset.classes[digit])
    dset.targets = dset.targets[indices]
    dset.data = dset.data[indices]
    dset.classes = classes
    return dset


train = datasets.MNIST("../data", train=True, download=True,
                        transform=transforms.Compose([transforms.ToTensor()]))

test = datasets.MNIST("../data", train=False, download=True,
                      transform=transforms.Compose([transforms.ToTensor()]))

tr = split_MNIST(train, [1,2,3])

trainset = torch.utils.data.DataLoader(tr, batch_size=16, shuffle=True)

This works, but instead of creating a new dataset, it actually changes the original train variable. Is there a way to create a clone of the dataset instead to preserve the original one?

Martin_s
  • Possibly the most straightforward is to use torch.utils.data.Subset and provide indices of the samples you want. Each subset keeps a reference to the original dataset and samples the corresponding element provided in its index list. – jodag Feb 10 '20 at 23:44
  • You can make a [`copy.deepcopy`](https://docs.python.org/3/library/copy.html#copy.deepcopy) of the object ([example](https://stackoverflow.com/a/3975388/6075699)). In your case, `dset = copy.deepcopy(mnist_set)` will work fine. – Dishin H Goyani Feb 11 '20 at 05:06
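
The `torch.utils.data.Subset` suggestion from the comments can be sketched roughly as follows. This is a minimal illustration, not a definitive implementation: the helper name `subset_by_class` is hypothetical, and a tiny `TensorDataset` stands in for MNIST so the snippet runs without a download. With the real dataset one would presumably pass `train` and `train.targets` instead.

```python
import torch
from torch.utils.data import Subset, TensorDataset


def subset_by_class(dataset, targets, digits):
    """Return a Subset view of `dataset` limited to the labels in `digits`.

    Hypothetical helper: `targets` is the 1-D label tensor of `dataset`.
    """
    targets = torch.as_tensor(targets)
    mask = torch.zeros_like(targets, dtype=torch.bool)
    for digit in digits:
        mask |= targets == digit
    # Subset only stores indices; the original dataset is not modified.
    indices = torch.nonzero(mask).flatten().tolist()
    return Subset(dataset, indices)


# Stand-in for MNIST (assumption: 8 blank 28x28 images with known labels).
labels = torch.tensor([0, 1, 2, 3, 4, 1, 2, 3])
images = torch.zeros(len(labels), 1, 28, 28)
full = TensorDataset(images, labels)

sub = subset_by_class(full, labels, [1, 2, 3])
loader = torch.utils.data.DataLoader(sub, batch_size=4, shuffle=True)
```

Because a `Subset` only holds a reference plus an index list, several subsets can share one copy of the data. The trade-off is that the result does not carry the `targets`, `data`, or `classes` attributes of the wrapped dataset, and the labels keep their original values.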

1 Answer

Just put the dataset instantiation inside the `split_MNIST` function, so that each call builds a fresh dataset and no existing object is modified.

import torch
from torchvision import datasets, transforms


def split_MNIST(path2data, train, download, transform, digits):
    # Build a fresh dataset on every call, so no shared object is mutated.
    dset = datasets.MNIST(path2data, train=train, download=download, transform=transform)
    # Boolean mask that is True for every sample whose label is in `digits`.
    indices = torch.zeros_like(dset.targets, dtype=torch.bool)
    for digit in digits:
        indices |= dset.targets == digit
    # Collect the class names before `dset.classes` is overwritten below.
    classes = [dset.classes[digit] for digit in digits]
    dset.targets = dset.targets[indices]
    dset.data = dset.data[indices]
    dset.classes = classes
    return dset


# Named `transform` so that the `transforms` module is not shadowed.
transform = transforms.Compose([transforms.ToTensor()])
tr = split_MNIST('../data', train=True, download=True, transform=transform, digits=[1,2,3])

trainset = torch.utils.data.DataLoader(tr, batch_size=16, shuffle=True)
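
A variant, assuming one would rather keep the original `split_MNIST(mnist_set, digits)` signature, is to combine the filtering with the `copy.deepcopy` suggestion from the comments. The sketch below is only an illustration: `split_copy` is a hypothetical name, and a small `SimpleNamespace` with `data`, `targets`, and `classes` attributes stands in for the MNIST object so that it runs without a download.

```python
import copy
from types import SimpleNamespace

import torch


def split_copy(dset, digits):
    """Filter a deep copy of `dset` to `digits`, leaving `dset` untouched."""
    out = copy.deepcopy(dset)
    mask = torch.zeros_like(out.targets, dtype=torch.bool)
    for digit in digits:
        mask |= out.targets == digit
    out.targets = out.targets[mask]
    out.data = out.data[mask]
    out.classes = [out.classes[d] for d in digits]
    return out


# Stand-in object (assumption): six one-element "images" with known labels.
toy = SimpleNamespace(
    data=torch.arange(6).reshape(6, 1),
    targets=torch.tensor([0, 1, 2, 3, 1, 0]),
    classes=["0 - zero", "1 - one", "2 - two", "3 - three"],
)

part = split_copy(toy, [1, 2])
```

Note that a deep copy duplicates the whole dataset in memory before filtering, whereas re-instantiating reads it from disk again; which cost matters more depends on the setup.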
trsvchn