
I am using StratifiedShuffleSplit in Python with n_splits=1.

I do not understand why I still need a for loop to get the output. Why doesn't the following code work?

```python
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_index, test_index = split.split(housing, housing["income_cat"])
```

Here is the original code:

```python
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]
```
Asked by Ehsan; edited by Silly Freak
  • could you add the code you'd need for `n_splits=2`, or the working code for `n_splits=1`? That may help us answer. – Silly Freak Dec 01 '18 at 00:14
  • My feeling says that the result is an array, and that a destructuring assignment (`split, = ...` - note the comma) will be the way to go – Silly Freak Dec 01 '18 at 00:16
  • Because `split()` is a generator function in Python (it uses `yield`), even when you have only a single split its result needs to be iterated. See [this question](https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do) – Vivek Kumar Dec 01 '18 at 05:45
  • @SillyFreak: I added the original code to my problem statement. – Ehsan Dec 24 '18 at 15:08

1 Answer

As @Vivek Kumar commented, the split.split() call in the second line of your code returns a generator (an iterable that produces its items lazily), not a list and not a single pair of index arrays. Your non-working example uses the return value as if it were that pair.
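To see what that means without installing scikit-learn, here is a minimal sketch built on a hypothetical stand-in, `toy_split`. Like `split.split()`, it is a generator function that uses `yield`; the rotation inside it is purely illustrative and is not how scikit-learn shuffles or stratifies.

```python
import inspect


def toy_split(n_samples, n_splits=1, test_size=0.2):
    """Hypothetical stand-in for StratifiedShuffleSplit.split().

    Yields one (train_index, test_index) pair per split. Rotating the
    positions stands in for real (stratified) shuffling.
    """
    n_test = int(n_samples * test_size)
    for i in range(n_splits):
        positions = list(range(n_samples))
        positions = positions[i:] + positions[:i]
        yield positions[n_test:], positions[:n_test]


result = toy_split(10, n_splits=1)
print(inspect.isgenerator(result))  # True: nothing has been computed yet

try:
    result[0]  # generators cannot be indexed like lists
except TypeError as exc:
    print(exc)  # 'generator' object is not subscriptable
```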

Let's look at what kind of data your loop consumes:

```python
for train_index, test_index in ...:
    ...
```

The for loop requires an iterable. In addition, train_index, test_index "destructures" each item of that iterable into two values, so each item must itself be an iterable with exactly two elements. Usually, a tuple is used for such cases.
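A small illustration of that destructuring, using a hard-coded list of hypothetical (train, test) pairs in place of real index arrays. It also shows why the question's loop works: Python has no block scope, so names bound inside the loop remain available after it ends.

```python
# Hypothetical (train_index, test_index) pairs; each item is a 2-tuple.
pairs = [([2, 3, 4], [0, 1]), ([0, 1, 4], [2, 3])]

for train_index, test_index in pairs:
    # each 2-tuple is unpacked into exactly two names
    print("train:", train_index, "test:", test_index)

# The loop variables still hold the values from the last iteration,
# just as strat_train_set / strat_test_set survive the question's loop.
print(train_index, test_index)  # [0, 1, 4] [2, 3]
```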

So, iterating over the result of split.split() could produce items like these:

```python
[
    (a1, b1),
    (a2, b2),
    ...
]
```

With n_splits=1 there is exactly one train_index, test_index pair, because n_splits is the number of re-shuffling and splitting iterations. In that case, the result will be this:

```python
[
    (a1, b1),
]
```
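Since the number of yielded pairs follows n_splits, you can check this by materialising the generator with `list()`. The sketch below again uses a hypothetical stand-in, `toy_split`, rather than scikit-learn code:

```python
def toy_split(n_samples, n_splits=1, test_size=0.2):
    """Hypothetical stand-in: yields one (train_index, test_index) pair per split."""
    n_test = int(n_samples * test_size)
    for i in range(n_splits):
        positions = list(range(n_samples))
        positions = positions[i:] + positions[:i]  # stands in for shuffling
        yield positions[n_test:], positions[:n_test]


for n_splits in (1, 3):
    pairs = list(toy_split(10, n_splits=n_splits))
    print(n_splits, "->", len(pairs), "pair(s)")
# 1 -> 1 pair(s)
# 3 -> 3 pair(s)
```

With the real class, `split.get_n_splits()` returns the configured number, and `len(list(split.split(housing, housing["income_cat"])))` should match it.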

So there is only one item, and that item is itself a tuple with two elements. When you try to destructure the result directly using train_index, test_index = ..., Python finds one item where it expects two and raises ValueError: not enough values to unpack (expected 2, got 1). You need to extract the tuple first.
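The failure can be reproduced with a hypothetical stand-in generator, `toy_split`, that yields a single pair, the same shape the real splitter yields for n_splits=1:

```python
def toy_split(n_samples, n_splits=1, test_size=0.2):
    """Hypothetical stand-in: yields one (train_index, test_index) pair per split."""
    n_test = int(n_samples * test_size)
    for i in range(n_splits):
        positions = list(range(n_samples))
        positions = positions[i:] + positions[:i]  # stands in for shuffling
        yield positions[n_test:], positions[:n_test]


message = None
try:
    # one yielded item, but two target names: this cannot work
    train_index, test_index = toy_split(10, n_splits=1)
except ValueError as exc:
    message = str(exc)

print(message)  # not enough values to unpack (expected 2, got 1)
```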

There are two basic ways to get the tuple:

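```python
pair = next(split.split(...))  # take the first item from the generator
pair, = split.split(...)       # unpack exactly one item (note the comma)
```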

I would strongly suggest the second variant, because it raises a ValueError when there is unexpectedly more than one item; the first variant would just silently discard any extra items.
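The difference between the two variants is easy to see with a hypothetical stand-in generator, `toy_split`, that is asked for two splits where one was expected:

```python
def toy_split(n_samples, n_splits=1, test_size=0.2):
    """Hypothetical stand-in: yields one (train_index, test_index) pair per split."""
    n_test = int(n_samples * test_size)
    for i in range(n_splits):
        positions = list(range(n_samples))
        positions = positions[i:] + positions[:i]  # stands in for shuffling
        yield positions[n_test:], positions[:n_test]


# Variant 1: next() takes the first pair and silently ignores the rest.
first = next(toy_split(10, n_splits=2))
print(first)

# Variant 2: "pair, =" insists on exactly one item.
try:
    pair, = toy_split(10, n_splits=2)
except ValueError as exc:
    print("ValueError:", exc)

# With exactly one item, variant 2 succeeds.
pair, = toy_split(10, n_splits=1)
print(pair)
```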

Then, you can destructure the tuple:

```python
train_index, test_index = pair
```

Or, both in one step:

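```python
from sklearn.model_selection import StratifiedShuffleSplit
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
# note the trailing comma: unpack exactly one (train_index, test_index) pair
(train_index, test_index), = split.split(housing, housing["income_cat"])
```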
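Here is the one-step form run end to end against a hypothetical stand-in, `toy_split`, with a plain list playing the role of `housing`:

```python
def toy_split(n_samples, n_splits=1, test_size=0.2):
    """Hypothetical stand-in: yields one (train_index, test_index) pair per split."""
    n_test = int(n_samples * test_size)
    for i in range(n_splits):
        positions = list(range(n_samples))
        positions = positions[i:] + positions[:i]  # stands in for shuffling
        yield positions[n_test:], positions[:n_test]


data = list("abcdefghij")  # hypothetical 10-sample dataset

# One step: unpack the single yielded pair, then the tuple inside it.
(train_index, test_index), = toy_split(len(data), n_splits=1)

train_set = [data[i] for i in train_index]
test_set = [data[i] for i in test_index]
print(test_set)        # ['a', 'b']
print(len(train_set))  # 8
```

The indices yielded by scikit-learn's splitters are positions, so with a pandas DataFrame `housing.iloc[train_index]` is the positional counterpart of the question's `housing.loc[train_index]`. The two select the same rows as long as `housing` keeps its default 0..n-1 index, which is presumably the case here.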
Answered by Silly Freak