
I am trying to use train_test_split from the scikit-learn package, but I am having trouble with the stratify parameter. Here is the code:

from sklearn import cross_validation, datasets

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

However, I keep getting the following error:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

Does anyone have an idea what is going on? Below is the relevant part of the function documentation.

[...]

stratify : array-like or None (default is None)

If not None, data is split in a stratified fashion, using this as the labels array.

New in version 0.17: stratify splitting

[...]

Daneel Olivaw

6 Answers


The stratify parameter makes a split so that the proportion of each value in the resulting subsets is the same as its proportion in the array passed to stratify.

For example, if the variable y is a binary categorical variable with values 0 and 1, and 25% of the values are 0 and 75% are 1, stratify=y will make sure that both the training set and the test set from your random split have (up to rounding) 25% 0s and 75% 1s.
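A minimal sketch of that 25%/75% example, assuming a recent scikit-learn where train_test_split lives in sklearn.model_selection. The toy arrays are made up for illustration:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 100 samples, 25 labelled 0 and 75 labelled 1
y = np.array([0] * 25 + [1] * 75)
X = np.arange(100).reshape(-1, 1)  # one dummy feature per sample

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)

# Fraction of zeros in each subset
print(np.mean(y_train == 0), np.mean(y_test == 0))  # 0.25 0.25
```

Both fractions come out at exactly 0.25 here because 80 training and 20 test samples divide evenly; with other sizes, the per-class counts are rounded to whole samples.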

Fazzolini
  • This doesn't really answer the question but is super useful for just understanding how it works. Thanks a ton. – Reed Jessen Mar 01 '18 at 20:49
  • I still struggle to understand why this stratification is necessary: if there's class imbalance in the data, wouldn't it be preserved on average when doing a random split of the data? – Holger Brandl Jun 22 '18 at 09:04
  • @HolgerBrandl It will be preserved on average; with stratify, it will be preserved for sure. – Yonatan Oct 14 '18 at 11:22
  • @HolgerBrandl With very small or very imbalanced data sets, it's quite possible that the random split could completely eliminate a class from one of the splits. – cddt Oct 20 '19 at 02:00
  • @HolgerBrandl Nice question! I'd add that, first, you split into training and test sets using `stratify`. Second, to correct the imbalance you may need to oversample or undersample the training set; many scikit-learn classifiers also have a `class_weight` parameter that you can set to `'balanced'`. Finally, you can use a more appropriate metric than accuracy for imbalanced datasets, such as F1 or the area under the ROC curve. – Claude COULOMBE Oct 27 '19 at 19:02
  • Doesn't it violate the temporal order of time-series data? – Amin Saqi Jun 15 '20 at 16:58
  • So what counts as a small or large dataset? I have a fairly well-balanced dataset with shape (130000, 23). Should I be using stratify? If stratify preserves the state of a split, why isn't it on all the time? If it's enabled on a balanced or large set, we are simply preserving that state. So what's the drawback to using stratify on a large or balanced dataset? – Edison Jul 02 '22 at 09:02
  • So if using stratify, it isn't strictly necessary to use `shuffle = True`, right? – Sh.A Sep 27 '22 at 07:18
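To make the point about small or imbalanced datasets concrete, here is a rough sketch (assuming scikit-learn 0.18 or later; the 27/3 dataset and the count_missing_minority helper are hypothetical) that repeats a split with many seeds and counts how often the test set ends up with no minority-class sample at all:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical small, imbalanced dataset: 27 samples of class 0, 3 of class 1
y = np.array([0] * 27 + [1] * 3)
X = np.arange(30).reshape(-1, 1)


def count_missing_minority(stratify, n_trials=200):
    """Count the seeds whose test set contains no sample of class 1."""
    missing = 0
    for seed in range(n_trials):
        _, _, _, y_test = train_test_split(
            X, y, test_size=0.2, random_state=seed, stratify=stratify
        )
        if not np.any(y_test == 1):
            missing += 1
    return missing


print("plain random split:", count_missing_minority(stratify=None))
print("stratified split:  ", count_missing_minority(stratify=y))
```

With a plain random split, around half of the 6-sample test sets miss class 1 entirely; with stratify=y, every test set receives one of the three minority samples.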

For my future self who comes here via Google:

train_test_split is now in model_selection, hence:

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)

is the way to use it. Setting the random_state is desirable for reproducibility.
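As a quick check of the reproducibility point (a sketch with made-up xs and ys), two calls with the same inputs and the same random_state return identical splits:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical features and balanced binary labels
xs = np.arange(20).reshape(10, 2)
ys = np.array([0, 1] * 5)

split_a = train_test_split(xs, ys, test_size=0.3, random_state=0, stratify=ys)
split_b = train_test_split(xs, ys, test_size=0.3, random_state=0, stratify=ys)

# Compare x_train, x_test, y_train, y_test pairwise
same = all(np.array_equal(a, b) for a, b in zip(split_a, split_b))
print(same)  # True
```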

Martin Thoma

Scikit-Learn is just telling you it doesn't recognise the argument "stratify", not that you're using it incorrectly. This is because the parameter was added in version 0.17 as indicated in the documentation you quoted.

So you just need to update Scikit-Learn.
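If you are not sure which version you have, a sketch like the following prints it and imports train_test_split from whichever module provides it. It assumes the usual history: sklearn.model_selection appeared in 0.18, and the older cross_validation module was removed in 0.20.

```python
import sklearn

# Print the installed version; the stratify argument needs 0.17 or later
print(sklearn.__version__)

try:
    # scikit-learn 0.18 and later
    from sklearn.model_selection import train_test_split
except ImportError:
    # Older releases only ship the (now removed) cross_validation module
    from sklearn.cross_validation import train_test_split
```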

Borja
  • I'm getting the same error, although I have version 0.21.2 of scikit-learn: `scikit-learn 0.21.2 py37h2a6a0b8_0 conda-forge` – KareemJ May 26 '19 at 23:47

In this context, stratification means that train_test_split returns training and test subsets with the same proportions of class labels as the input dataset.

X. Wang

Stratifying preserves the proportions of the values in the target column and reproduces those same proportions in the output of train_test_split. For example, take a binary classification problem where the target column is 80% 'yes' and 20% 'no'. Since there are four times as many 'yes' as 'no' rows, splitting into train and test sets without stratifying could, in the worst case, put only 'yes' rows into the training set and all the 'no' rows into the test set (i.e., the training set might have no 'no' in its target column at all).

With stratification, the target column of the training set has 80% 'yes' and 20% 'no', and the target column of the test set likewise has 80% 'yes' and 20% 'no'.

In short, stratify distributes the target (label) across the train and test sets in the same proportions as in the original dataset.

from sklearn.model_selection import train_test_split

# features: feature matrix, target: label column (both defined elsewhere)
X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.25, stratify=target, random_state=43)
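Note that train_test_split returns a (train, test) pair for each array you pass, in the same order as the inputs. A sketch with hypothetical 80%/20% 'yes'/'no' data and an extra, made-up weights array shows the unpacking:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data mirroring the answer: 80 'yes' and 20 'no' labels
target = np.array(["yes"] * 80 + ["no"] * 20)
features = np.arange(200).reshape(100, 2)  # row i is [2*i, 2*i + 1]
weights = np.linspace(0.5, 1.5, 100)       # hypothetical per-sample weights

# Outputs come back as (train, test) pairs, one pair per input array, in input order
X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
    features, target, weights, test_size=0.25, stratify=target, random_state=43
)

print(X_train.shape, X_test.shape)  # (75, 2) (25, 2)
print(np.unique(y_test, return_counts=True))
```

Because all arrays are split with the same indices, each row of X_test still lines up with its label in y_test and its weight in w_test.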
shafee
The_Data_Guy

Try running this code, it "just works":

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])
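For readers on scikit-learn 0.20 or later, where cross_validation no longer exists, here is a sketch of the same experiment with sklearn.model_selection. The random_state=0 is an arbitrary choice added for reproducibility, and the class counts make the stratification visible:

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target  # 50 samples in each of the 3 classes

x_train, x_test, y_train, y_test = train_test_split(
    X, y, train_size=0.8, stratify=y, random_state=0
)

# Number of samples per class in each subset
print(np.bincount(y_train))  # [40 40 40]
print(np.bincount(y_test))   # [10 10 10]
```

Since iris has 50 samples per class, an 80/20 stratified split puts exactly 40 of each class in the training set and 10 of each in the test set, which matches the ten 0s, ten 1s and ten 2s in the y_test printed above.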
Sergey Bushmanov
  • @user5767535 As you can see, it works on my Ubuntu machine with `sklearn` version 0.17 (Anaconda distribution for Python 3.5). I can only suggest checking once more that you entered the code correctly, and updating your software. – Sergey Bushmanov Jan 17 '16 at 21:07
  • @user5767535 BTW, "New in version 0.17: stratify splitting" makes me almost certain that you have to update your `sklearn`... – Sergey Bushmanov Jan 17 '16 at 21:11