
I'm a little confused about how the class StratifiedShuffleSplit from sklearn works.

The code below is from Géron's book "Hands On Machine Learning", chapter 2, where he does stratified sampling.

from sklearn.model_selection import StratifiedShuffleSplit

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]

In particular, what is split.split doing?

Thanks!

Rafael Higa
  • Does this answer your question? [difference between StratifiedKFold and StratifiedShuffleSplit in sklearn](https://stackoverflow.com/questions/45969390/difference-between-stratifiedkfold-and-stratifiedshufflesplit-in-sklearn) – PV8 Jan 10 '20 at 07:43

2 Answers


Since you did not provide a dataset, I use the sample data from the sklearn documentation to answer this question.

Prepare dataset

# generate data
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
data = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
group_label = np.array([0, 0, 0, 1, 1, 1])

This generates a dataset data with 6 observations and 2 variables. group_label has 2 values, meaning group 0 and group 1. In this case, group 0 contains 3 samples, and so does group 1. In general, the group sizes do not need to be the same.
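If you want to double-check the group sizes, a quick way (reusing the data and group_label arrays defined above) is:

# count how many samples fall into each group
print(np.bincount(group_label))  # [3 3] -> 3 samples in group 0, 3 in group 1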

Create a StratifiedShuffleSplit object instance

sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
sss.get_n_splits(data, group_label)

Out:

5

In this step, you create an instance of StratifiedShuffleSplit and tell it how to split the data (with random_state=0, split the data 5 times, and each time put 50% of the data into the test set). However, the data is only split when you call split in the next step.
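As a small check (still using the sss object from above): get_n_splits only reports the configured number of splits, it does not compute them, so it even works without passing the data.

# get_n_splits just returns the configured n_splits; no splitting happens here
print(sss.get_n_splits())                   # 5
print(sss.get_n_splits(data, group_label))  # 5, the arguments are ignored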

Call split on the instance to split the data.

# split() returns a generator
type(sss.split(data, group_label))

# split the data
for train_index, test_index in sss.split(data, group_label):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = data[train_index], data[test_index]
    y_train, y_test = group_label[train_index], group_label[test_index]

Out:

TRAIN: [5 2 3] TEST: [4 1 0]
TRAIN: [5 1 4] TEST: [0 2 3]
TRAIN: [5 0 2] TEST: [4 3 1]
TRAIN: [4 1 0] TEST: [2 3 5]
TRAIN: [0 5 1] TEST: [3 4 2]

In this step, the splitter you defined in the previous step generates the 5 splits of the data one by one. For instance, in the first split the original data is shuffled and samples 5, 2, 3 are selected as the train set (this is also stratified sampling by group_label); in the second split the data is shuffled again and samples 5, 1, 4 are selected as the train set; and so on.
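To see the stratification explicitly, you can count the group labels inside each split (a small check reusing sss, data and group_label from above); every train and test set contains samples from both groups, in roughly the original 50/50 proportion:

# count group labels in each train/test split
for train_index, test_index in sss.split(data, group_label):
    print("train groups:", np.bincount(group_label[train_index]),
          "test groups:", np.bincount(group_label[test_index]))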

Travis

The split.split() function returns indexes for the train samples and test samples. It iterates for the number of splits you specified and, each time, yields train and test sample indexes, which you can use to filter the whole dataset into a train set and a test set.
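A minimal, self-contained sketch (using a toy DataFrame instead of Géron's housing data, which isn't shown here) of how those yielded indexes are used to build the train and test sets:

import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# toy stand-in for the housing DataFrame, with a categorical column to stratify on
df = pd.DataFrame({"value": range(10),
                   "income_cat": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2]})

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(df, df["income_cat"]):
    strat_train_set = df.iloc[train_index]  # 8 rows, 4 from each income_cat
    strat_test_set = df.iloc[test_index]    # 2 rows, 1 from each income_cat

# the yielded indexes are positional; with a default RangeIndex (as here and in the
# book's housing data) df.loc[train_index] selects the same rows as df.iloc[train_index]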