
I'm using the Omniglot dataset, which is a set of 19,280 grayscale images, each of which is 105 x 105.

I defined a custom Dataset class and the following transform:

import torch
from torch.utils.data import Dataset
from torchvision import transforms

class OmniglotDataset(Dataset):

    def __init__(self, X, transform=None):
        self.X = X
        self.transform = transform

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img = self.X[idx]
        if self.transform:
            img = self.transform(img)
        return img

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

X_train.shape
(19280, 105, 105)
train_dataset = OmniglotDataset(X_train, transform=img_transform)

When I index a single image, it returns the right dimensions:

train_dataset[0].shape
torch.Size([1, 105, 105])

But when I index several images, it returns the dimensions in the wrong order (I expect 3 x 105 x 105):

train_dataset[[1,2,3]].shape
torch.Size([105, 3, 105])
doctopus

1 Answer


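You get the wrong shape because the transform is meant for a single image, but you apply it to a whole stack of images. Since X_train is a NumPy array, self.X[[1, 2, 3]] returns an array of shape (3, 105, 105), and transforms.ToTensor() treats any 3-D array as H x W x C and moves the last axis to the front. The 3 is therefore read as the height and the last 105 as the channel count, which gives (105, 3, 105).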
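To see where (105, 3, 105) comes from, here is a NumPy-only sketch. mimic_to_tensor_axes is a hypothetical helper that reproduces only the axis handling ToTensor applies to a NumPy array (a 2-D array gets a trailing channel axis, then H x W x C is transposed to C x H x W); it deliberately skips the dtype conversion and the scaling to [0, 1].

```python
import numpy as np


def mimic_to_tensor_axes(arr):
    """Hypothetical helper: copies only the axis handling of
    transforms.ToTensor for NumPy input (no dtype conversion or scaling)."""
    if arr.ndim == 2:
        arr = arr[:, :, None]        # (H, W) -> (H, W, 1)
    return arr.transpose(2, 0, 1)    # (H, W, C) -> (C, H, W)


single = np.zeros((105, 105), dtype=np.uint8)     # like self.X[0]
batch = np.zeros((3, 105, 105), dtype=np.uint8)   # like self.X[[1, 2, 3]]

print(mimic_to_tensor_axes(single).shape)  # (1, 105, 105)
print(mimic_to_tensor_axes(batch).shape)   # (105, 3, 105): 3 read as height, 105 as channels

# Applying the per-image logic to each image separately and stacking gives the intended layout
per_image = np.stack([mimic_to_tensor_axes(img) for img in batch])
print(per_image.shape)                     # (3, 1, 105, 105)
```

The last two lines hint at the fix: transform one image at a time, then stack along a new batch axis.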

A more convenient way to get a batch of any size is to use a DataLoader, which applies the transform to each sample separately and then stacks the results:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

omniglot = datasets.Omniglot(root='./data', background=True, download=True, transform=img_transform)

data_loader = DataLoader(omniglot, shuffle=False, batch_size=8)
for image_batch, label_batch in data_loader:
    # torchvision's Omniglot yields (image, label) pairs, so each batch is a pair too;
    # image_batch now contains the first eight images
    print(image_batch.shape)  # torch.Size([8, 1, 105, 105])
    break

If you really need to get images in arbitrary order:

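from operator import itemgetter

indexes = [1, 3, 5]
# Returns the tuple (omniglot[1], omniglot[3], omniglot[5]); each element is an (image, label) pair
selected_samples = itemgetter(*indexes)(omniglot)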
Anton Ganichev
  • The problem is I want to index a specific list of images. Can I do it with a dataloader? I was under the impression it batches images randomly – doctopus Apr 13 '20 at 02:21
  • With shuffle=False the DataLoader will not mix images (https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader). But if you really need to get images in arbitrary order, you can query the dataset as a list: https://stackoverflow.com/questions/18272160/access-multiple-elements-of-list-knowing-their-index P.S. Added examples to the answer. – Anton Ganichev Apr 13 '20 at 07:17
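Regarding the question in the comments about fetching a specific list of images through a DataLoader: one way is to wrap the dataset in torch.utils.data.Subset, which serves only the given indices, in the given order. In this minimal sketch, images is a hypothetical stand-in dataset of ten 1 x 105 x 105 tensors, each filled with its own index so the order of the batch is easy to check. Any map-style dataset, such as OmniglotDataset or torchvision's Omniglot, should work the same way.

```python
import torch
from torch.utils.data import DataLoader, Subset

# Hypothetical stand-in dataset: image i is a 1 x 105 x 105 tensor filled with the value i
images = [torch.full((1, 105, 105), float(i)) for i in range(10)]

indexes = [1, 3, 5]

# Subset exposes only the chosen indices, in the order given;
# shuffle=False keeps that order when the DataLoader batches them
loader = DataLoader(Subset(images, indexes), batch_size=len(indexes), shuffle=False)
batch = next(iter(loader))

print(batch.shape)        # torch.Size([3, 1, 105, 105])
print(batch[:, 0, 0, 0])  # tensor([1., 3., 5.])
```

Alternatively, you can pass sampler=indexes to a DataLoader over the full dataset, since sampler accepts any iterable of indices that implements __len__. In that case leave shuffle unset, because the two options are mutually exclusive.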