130

I need to split my data into a training set (75%) and test set (25%). I currently do that with the code below:

X, Xt, userInfo, userInfo_train = sklearn.cross_validation.train_test_split(X, userInfo)   

However, I'd like to stratify my training dataset. How do I do that? I've been looking into the StratifiedKFold method, but it doesn't let me specify the 75%/25% split or stratify only the training dataset.

pir
  • 5,513
  • 12
  • 63
  • 101

9 Answers

230

[update for 0.17]

See the docs of sklearn.model_selection.train_test_split:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    stratify=y, 
                                                    test_size=0.25)

[/update for 0.17]

There is a pull request here. But you can simply do train, test = next(iter(StratifiedKFold(...))) and use the train and test indices if you want.
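On current scikit-learn versions the same index-based trick works through the split() method; a minimal sketch on toy data (the data, fold count, and names here are illustrative, not from the original answer):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy data: 8 samples, two balanced classes.
X = np.arange(16).reshape(8, 2)
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])

# A 4-fold stratified splitter gives ~75%/25% train/test per fold;
# taking just the first fold yields a single stratified split.
skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
train_idx, test_idx = next(iter(skf.split(X, y)))

X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]
```

Each fold's test set holds one sample per class here, so both halves stay balanced.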

Martin Thoma
  • 124,992
  • 159
  • 614
  • 958
Andreas Mueller
  • 27,470
  • 8
  • 62
  • 74
  • 1
    @AndreasMueller Is there an easy way to stratify regression data? – Jordan Sep 14 '16 at 09:53
  • 4
    @Jordan nothing is implemented in scikit-learn. I don't know of a standard way. We could use percentiles. – Andreas Mueller Sep 14 '16 at 20:15
  • @AndreasMueller Have you ever seen the behavior where this method is considerably slower than the StratifiedShuffleSplit? I was using the MNIST dataset. – psiyumm Oct 15 '17 at 06:06
  • @activatedgeek that seems very weird, as train_test_split(...stratify=) is just calling StratifiedShuffleSplit and taking the first split. Feel free to open an issue on the tracker with a reproducible example. – Andreas Mueller Oct 15 '17 at 13:58
  • @AndreasMueller I actually didn't open an issue because I have a strong feeling I am doing something wrong (even though it is just 2 lines). But if I am still able to reproduce it today multiple times, I'll do that! – psiyumm Oct 15 '17 at 18:54
  • Missed this completely. Made life easier. – Mohith7548 Mar 24 '21 at 07:27
38

You can simply do it with the train_test_split() function available in scikit-learn:

from sklearn.model_selection import train_test_split 
train, test = train_test_split(X, test_size=0.25, stratify=X['YOUR_COLUMN_LABEL']) 

I have also prepared a short GitHub Gist which shows how stratify option works:

https://gist.github.com/SHi-ON/63839f3a3647051a180cb03af0f7d0d9

Shayan Amani
  • 5,787
  • 1
  • 39
  • 40
33

TL;DR : Use StratifiedShuffleSplit with test_size=0.25

Scikit-learn provides two modules for Stratified Splitting:

  1. StratifiedKFold : This module is useful as a direct k-fold cross-validation operator: it will set up n_folds training/testing sets such that classes are equally balanced in both.

Here's some code (directly from the above documentation):

>>> skf = cross_validation.StratifiedKFold(y, n_folds=2) #2-fold cross validation
>>> len(skf)
2
>>> for train_index, test_index in skf:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
...    #fit and predict with X_train/test. Use accuracy metrics to check validation performance
  2. StratifiedShuffleSplit : This module creates a single training/testing set with equally balanced (stratified) classes. Essentially, this is what you want with n_iter=1. You can specify the test size here, same as in train_test_split.

Code:

>>> sss = StratifiedShuffleSplit(y, n_iter=1, test_size=0.5, random_state=0)
>>> len(sss)
1
>>> for train_index, test_index in sss:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
>>> # fit and predict with your classifier using the above X/y train/test
Josh Noe
  • 2,664
  • 2
  • 35
  • 37
tangy
  • 3,056
  • 2
  • 25
  • 42
  • 6
    Note that as of `0.18.x`, `n_iter` should be `n_splits` for `StratifiedShuffleSplit ` - and that there's a slightly different API for it: http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html – lollercoaster Oct 31 '16 at 23:27
  • 3
    If `y` is a Pandas Series, use `y.iloc[train_index], y.iloc[test_index]` – Owlright Aug 01 '18 at 10:44
  • 1
    @Owlright I tried using a pandas dataframe and the indices that StratifiedShuffleSplit returns is not the indices in the dataframe. `dataframe index: 2,3,5` `the first split in sss:[(array([2, 1]), array([0]))]` :( – Meghna Natraj Aug 31 '18 at 22:52
  • 2
    @tangy why is this a for loop? isn't it the case that when a line `X_train, X_test = X[train_index], X[test_index]` is invoked it overrides `X_train` and `X_test`? Why then not just a single `next(sss)`? – Bartek Wójcik Sep 03 '18 at 13:39
  • If you encounter "TypeError: 'StratifiedShuffleSplit' object is not iterable", perhaps this post can help: https://stackoverflow.com/questions/53899066/what-could-be-the-reason-for-typeerror-stratifiedshufflesplit-object-is-not – DnVS Jun 07 '21 at 08:40
20

Here's an example for continuous/regression data (until this issue on GitHub is resolved).

import numpy as np
from sklearn.model_selection import train_test_split

y_min = np.amin(y)
y_max = np.amax(y)

# 5 bins may be too few for larger datasets.
bins     = np.linspace(start=y_min, stop=y_max, num=5)
y_binned = np.digitize(y, bins, right=True)

X_train, X_test, y_train, y_test = train_test_split(
    X, 
    y, 
    stratify=y_binned
)
  • Where start is the min and stop is the max of your continuous target.
  • If you don't set right=True, it will more or less make your max value a separate bin, and your split will always fail because too few samples will be in that extra bin.
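To see what the right=True bullet is describing, here is a small standalone np.digitize demo (the values and bin edges are toy examples, not from the original answer):

```python
import numpy as np

y = np.array([0.0, 1.0, 2.5, 5.0, 7.5, 9.0, 10.0])
bins = np.linspace(y.min(), y.max(), num=5)  # edges: [0, 2.5, 5, 7.5, 10]

# Default (right=False): the max value lands alone in an extra bin (index 5).
left = np.digitize(y, bins)
print(left)   # [1 1 2 3 4 4 5]

# right=True: the max value falls into the last regular bin instead.
right = np.digitize(y, bins, right=True)
print(right)  # [0 1 1 2 3 4 4]
```

A bin containing a single sample is what makes a stratified split fail, since that "class" cannot be represented in both train and test sets.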
Kermit
  • 4,922
  • 4
  • 42
  • 74
Jordan
  • 1,003
  • 2
  • 12
  • 24
6

In addition to the accepted answer by @Andreas Mueller, just want to add that, as @tangy mentioned above:

StratifiedShuffleSplit most closely resembles train_test_split(stratify=y), with the added features of:

  1. stratifying by default
  2. repeatedly splitting the data when you specify n_splits
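A minimal sketch of point 2 on toy data (names and sizes are illustrative): each split produced by n_splits is an independent shuffled, stratified partition.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Toy data: 10 samples, two balanced classes.
X = np.arange(20).reshape(10, 2)
y = np.array([0] * 5 + [1] * 5)

# Three independent stratified 80/20 splits.
sss = StratifiedShuffleSplit(n_splits=3, test_size=0.2, random_state=0)
for train_index, test_index in sss.split(X, y):
    # Every test set keeps the 50/50 class balance: one sample per class.
    print(sorted(y[test_index]))
```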
Max
  • 385
  • 4
  • 11
2

StratifiedShuffleSplit is used after we choose the column that should be evenly represented in all the smaller datasets we are about to generate. 'The folds are made by preserving the percentage of samples for each class.'

Suppose we've got a dataset 'data' with a column 'season', and we want an even representation of 'season'; then it looks like this:

from sklearn.model_selection import StratifiedShuffleSplit
sss=StratifiedShuffleSplit(n_splits=1,test_size=0.25,random_state=0)

for train_index, test_index in sss.split(data, data["season"]):
    sss_train = data.iloc[train_index]
    sss_test = data.iloc[test_index]
Itay Guy
  • 111
  • 1
  • 6
1

As such, it is desirable to split the dataset into train and test sets in a way that preserves the same proportions of examples in each class as observed in the original dataset.

This is called a stratified train-test split.

We can achieve this by setting the “stratify” argument to the y component of the original dataset. This will be used by the train_test_split() function to ensure that both the train and test sets have the same proportion of examples in each class as is present in the provided “y” array.
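For instance, a small self-contained sketch (toy, imbalanced labels with an assumed 4:1 class ratio, chosen so the stratified counts come out exact):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# 100 toy samples: 80 of class 0, 20 of class 1.
X = np.arange(200).reshape(100, 2)
y = np.array([0] * 80 + [1] * 20)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Both halves keep the original 4:1 class proportions.
print(np.bincount(y_train))  # [60 15]
print(np.bincount(y_test))   # [20  5]
```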

dev guy
  • 75
  • 7
0
from sklearn.model_selection import train_test_split

# Note: the second split's test fraction applies to the remaining 85%,
# so the final train fraction is (1 - vld_size) * (1 - tst_size), about 0.72.
tst_size = 0.15
vld_size = 0.15

# First carve out the validation set, then split the remainder into train/test.
# Assumes the target column is named 'y'; pass stratify= to keep class proportions.
X_train_test, X_valid, y_train_test, y_valid = train_test_split(
    df.drop('y', axis=1), df['y'], test_size=vld_size,
    stratify=df['y'], random_state=13903)

X_train, X_test, y_train, y_test = train_test_split(
    X_train_test, y_train_test, test_size=tst_size,
    stratify=y_train_test, random_state=13903)
0

Updating @tangy's answer from above to the current version of scikit-learn: 0.23.2 (StratifiedShuffleSplit documentation).

from sklearn.model_selection import StratifiedShuffleSplit

n_splits = 1  # We only want a single split in this case
sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=0)

for train_index, test_index in sss.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
Roei Bahumi
  • 3,433
  • 2
  • 20
  • 19