TL;DR
from datasets import load_dataset, DatasetDict

ds = load_dataset("alvations/xnli-15way")

# 80% train, 20% held out for validation + test
ds_train_devtest = ds['train'].train_test_split(test_size=0.2, seed=42)
# Split the held-out 20% in half: 10% validation, 10% test
ds_devtest = ds_train_devtest['test'].train_test_split(test_size=0.5, seed=42)

ds_splits = DatasetDict({
    'train': ds_train_devtest['train'],
    'valid': ds_devtest['train'],
    'test': ds_devtest['test']
})

print("Before:\n", ds)
print("After:\n", ds_splits)
[out]:
Before:
DatasetDict({
train: Dataset({
features: ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
num_rows: 20000
})
})
After:
DatasetDict({
train: Dataset({
features: ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
num_rows: 16000
})
valid: Dataset({
features: ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
num_rows: 2000
})
test: Dataset({
features: ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
num_rows: 2000
})
})
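As an aside, if an unshuffled, deterministic split is acceptable, the same proportions can be produced at load time with the datasets library's split-slicing syntax; note that, unlike train_test_split(seed=42), this does not shuffle the rows first. A minimal sketch:

from datasets import load_dataset

# Slice the single train split at load time; rows keep their original order,
# whereas train_test_split shuffles before splitting
ds_train = load_dataset("alvations/xnli-15way", split="train[:80%]")
ds_valid = load_dataset("alvations/xnli-15way", split="train[80%:90%]")
ds_test = load_dataset("alvations/xnli-15way", split="train[90%:]")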
In Long
Using the alvations/xnli-15way dataset, which comes with only a train split, as an example:
from datasets import load_dataset
ds = load_dataset("alvations/xnli-15way")
[out]:
DatasetDict({
train: Dataset({
features: ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
num_rows: 20000
})
})
You can first split the 20K rows of training data 80-20 with:
from datasets import DatasetDict
ds_train_devtest = ds['train'].train_test_split(test_size=0.2, seed=42)
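If you want a quick sanity check, the held-out size follows directly from 20000 × 0.2 = 4000:

# 20000 * 0.8 = 16000 train rows, 20000 * 0.2 = 4000 held out for valid + test
assert ds_train_devtest['train'].num_rows == 16000
assert ds_train_devtest['test'].num_rows == 4000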
Then split the 4K held-out rows 50-50 into validation and test sets:
ds_devtest = ds_train_devtest['test'].train_test_split(test_size=0.5, seed=42)
And finally put them together as a DatasetDict:
ds_splits = DatasetDict({
    'train': ds_train_devtest['train'],
    'valid': ds_devtest['train'],
    'test': ds_devtest['test']
})
[out]:
DatasetDict({
train: Dataset({
features: ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
num_rows: 16000
})
valid: Dataset({
features: ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
num_rows: 2000
})
test: Dataset({
features: ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
num_rows: 2000
})
})
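If you want to reuse exactly this partition later, without relying on re-running the split with the same seed, the splits can be persisted with save_to_disk and reloaded with load_from_disk (the directory name here is just an example):

from datasets import load_from_disk

# Persist all three splits to a local directory (arbitrary path)
ds_splits.save_to_disk("xnli-15way-splits")

# ... later, reload the exact same partition
ds_splits_reloaded = load_from_disk("xnli-15way-splits")
print(ds_splits_reloaded)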
Reference: https://huggingface.co/docs/datasets/v2.12.0/en/package_reference/main_classes#datasets.Dataset.train_test_split