I have an image classification dataset with 6 categories, which I load with torchvision's `ImageFolder` class. I wrote the code below to split it into three stratified sets (train, validation and test):
```python
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

# `dataset` is the torchvision ImageFolder loaded earlier.

# First split: hold out 10% of the samples as the test set.
train_indices, test_indices, _, _ = train_test_split(
    range(len(dataset)),
    dataset.targets,
    stratify=dataset.targets,
    test_size=0.1,
    random_state=1,
)
train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

# Second split: take ~1/9 of the remaining 90% (~10% overall) as the validation set.
train_targets = [label for _, label in train_dataset]
train_indices, val_indices, _, _ = train_test_split(
    range(len(train_dataset)),
    train_targets,
    stratify=train_targets,
    test_size=0.111,
    random_state=1,
)
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
```
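For reference, `torch.utils.data.Subset` does not copy samples: position `i` in a subset returns `dataset[indices[i]]`, where `indices` are interpreted relative to the dataset the subset wraps. Labels can therefore also be read from `dataset.targets` instead of iterating the subset, which loads every image. The sketch below illustrates this in plain Python so it runs without torch; `FakeDataset`, `FakeSubset` and `subset_targets` are hypothetical stand-ins, not PyTorch or torchvision APIs.

```python
# Hypothetical stand-ins so the sketch runs without torch/torchvision:
# FakeDataset mimics ImageFolder (labels exposed in .targets),
# FakeSubset mimics torch.utils.data.Subset (position i -> dataset[indices[i]]).


class FakeDataset:
    def __init__(self, targets):
        self.targets = list(targets)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        # A real ImageFolder would load and transform an image here.
        return f"image_{idx}", self.targets[idx]


class FakeSubset:
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # Position i in the subset is index self.indices[i] in the wrapped dataset.
        return self.dataset[self.indices[i]]


def subset_targets(dataset, indices):
    """Read labels straight from dataset.targets, without loading any samples."""
    return [dataset.targets[i] for i in indices]


dataset = FakeDataset([0, 0, 0, 1, 1, 1, 2, 2, 2])
first_indices = [8, 1, 4, 6, 3, 0]  # indices into `dataset`
train_subset = FakeSubset(dataset, first_indices)

# Iterating the subset and reading dataset.targets give the same labels.
print([label for _, label in train_subset])      # [2, 0, 1, 2, 1, 0]
print(subset_targets(dataset, first_indices))    # [2, 0, 1, 2, 1, 0]

# Positions 0 and 2 *within train_subset* correspond to dataset indices 8 and 4.
positions = [0, 2]
print(subset_targets(dataset, [first_indices[p] for p in positions]))  # [2, 1]
print(subset_targets(dataset, positions))                              # [0, 0]
```

The last two lines show the difference between treating `positions` as positions inside the subset (mapped back through `first_indices`) and treating them directly as indices of the original dataset.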
This produces the expected 80/10/10% split: the second `test_size=0.111` takes about 1/9 of the remaining 90%, which is roughly 10% of the full dataset. All classes follow the same distribution across the three sets except the 6th class: all of its samples end up in the test set.
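To inspect how each class is spread across the splits without loading any images, per-class counts can be tallied from the `targets` list. A minimal sketch: `class_counts` is a hypothetical helper, the `targets` list and split indices are made up for illustration (grouped by class, the way `ImageFolder` orders its samples), and every index is assumed to be an index into the original dataset.

```python
from collections import Counter


def class_counts(targets, indices):
    """Count samples per class label among `indices` (indices into `targets`)."""
    counts = Counter(targets[i] for i in indices)
    return {label: counts[label] for label in sorted(counts)}


# Made-up labels, grouped by class like ImageFolder's samples.
targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]

# Made-up splits; each list holds indices into `targets`.
splits = {
    "train": [0, 1, 4, 5, 8],
    "val": [2, 6],
    "test": [3, 7, 9],
}

for name, indices in splits.items():
    print(name, class_counts(targets, indices))
# train {0: 2, 1: 2, 2: 1}
# val {0: 1, 1: 1}
# test {0: 1, 1: 1, 2: 1}
```

With real data, `targets` would be `dataset.targets` and each list would be the `indices` attribute of the corresponding `Subset(dataset, ...)`, so the counts reflect what each subset actually contains.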
What am I doing wrong?
Thanks.