
I have a pandas dataframe that looks like this:

| Cliid | Segment | Insert |
|-------|---------|--------|
| 001   | A       | 0      |
| 002   | A       | 0      |
| 003   | C       | 0      |
| 004   | B       | 1      |
| 005   | A       | 0      |
| 006   | B       | 0      |

I want to split it into 2 groups so that each group has the same composition for each variable in [Segment, Insert]. For example, in each group 1/2 of the observations would belong to segment A, 1/6 would have Insert = 1, and so on.
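
To make "same composition" concrete, here is a minimal sketch that reports the share of each value of `Segment` and `Insert` in a frame. A well-stratified split should give (approximately) the same shares in every group. The helper name `composition` is hypothetical, and the frame simply re-creates the table above, with `Cliid` stored as strings to keep the leading zeros.

```python
import pandas as pd


def composition(frame, cols):
    """Return the relative frequency of each value, for each column in `cols`."""
    return {col: frame[col].value_counts(normalize=True).sort_index() for col in cols}


# The example table from the question
df = pd.DataFrame(
    {
        "Cliid": ["001", "002", "003", "004", "005", "006"],
        "Segment": ["A", "A", "C", "B", "A", "B"],
        "Insert": [0, 0, 0, 1, 0, 0],
    }
)

comp = composition(df, ["Segment", "Insert"])
print(comp["Segment"])  # shares: A = 1/2, B = 1/3, C = 1/6
print(comp["Insert"])   # shares: 0 = 5/6, 1 = 1/6
```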

I've checked this answer, but it only stratifies on one variable, so it won't work for more than one.
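
One common way around a single-variable limitation is to collapse the variables into one composite label and stratify on that. The sketch below assumes scikit-learn is available; the `key` label is built by joining `Segment` and `Insert`, and `StratifiedKFold(n_splits=2)` is used here (as one possible choice) to cut the frame into two halves. The 24-row frame is made-up data in which every combination appears four times, because each combination needs at least as many rows as there are groups.

```python
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Made-up data: each of the 6 (Segment, Insert) combinations appears 4 times
df = pd.DataFrame(
    {
        "Segment": ["A", "B", "C"] * 8,
        "Insert": [0] * 12 + [1] * 12,
    }
)

# Collapse both variables into one label per row, e.g. "A_0" or "B_1"
key = df["Segment"].astype(str) + "_" + df["Insert"].astype(str)

# With n_splits=2, the train and test indices of the first fold are two disjoint halves
skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=0)
idx_a, idx_b = next(skf.split(df, key))
group_a, group_b = df.iloc[idx_a], df.iloc[idx_b]

# Each label should now be split evenly between the two groups
print(key.iloc[idx_a].value_counts())
print(key.iloc[idx_b].value_counts())
```

The same `key` Series could also be passed as `stratify=key` to `train_test_split`; building it explicitly mainly makes it easy to inspect the combinations before splitting.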

R has this function that does exactly that, but using R is not an option.

By the way, I'm using Python 3.

arthur

1 Answer


You can use scikit-learn's `train_test_split` function with its `stratify` parameter, which takes the labels to stratify on. If you pass several columns (a DataFrame), the split is stratified on each combination of their values.

For example:

```python
from sklearn.model_selection import train_test_split

# Stratify on every (Segment, Insert) combination at once
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df[["Segment", "Insert"]])
```
Jannik
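
As a rough, self-contained check that passing several columns to `stratify` balances their combinations, here is a sketch on made-up data in which every `Segment`/`Insert` combination appears four times. With `test_size=0.5` it produces the two equal groups the question asks for. In current scikit-learn versions, `train_test_split` delegates to `StratifiedShuffleSplit`, which maps each row of a 2-D `stratify` argument to a single string label, so the stratification is over the joint combinations rather than each column separately.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Made-up data: each of the 6 (Segment, Insert) combinations appears 4 times
df = pd.DataFrame(
    {
        "Cliid": [f"{i:03d}" for i in range(1, 25)],
        "Segment": ["A", "B", "C"] * 8,
        "Insert": [0] * 12 + [1] * 12,
    }
)

cols = ["Segment", "Insert"]
# test_size=0.5 gives two equally sized groups, stratified on both columns jointly
group_1, group_2 = train_test_split(df, test_size=0.5, stratify=df[cols], random_state=42)

# Count rows per (Segment, Insert) combination in each group
counts_1 = group_1.groupby(cols).size()
counts_2 = group_2.groupby(cols).size()
print(counts_1.equals(counts_2))
```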
  • `stratify` doesn't seem to work for multiple columns when there is no target variable. When I run your code, I get `ValueError: Found input variables with inconsistent numbers of samples: [6, 1]`. It does work if I remove `stratify`, however. – arthur Nov 25 '20 at 14:51
  • Sorry, my mistake. I changed `stratify=[["Segment", "Insert"]]` to `stratify=df[["Segment", "Insert"]]`. – Jannik Nov 25 '20 at 15:03
  • Why do I get this ValueError: "The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2"? My data has a binary target. When I pass all the string columns to `stratify`, I get the error above; when I pass only the target column, or remove `stratify` entirely, it works. If I remove `stratify`, can the split still be considered stratified on all columns? – hanzgs Sep 15 '21 at 01:08
  • Is this approach a good one? https://stackoverflow.com/a/51525992/11053801 – hanzgs Sep 15 '21 at 01:19
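
The "least populated class" error above arises because scikit-learn needs at least two rows per stratification label, and with several columns a label is a whole combination of values, so rare combinations hit that limit quickly (the six-row table in the question has three combinations that occur only once). One possible workaround, sketched below, is to pool combinations with too few rows into a single catch-all label before stratifying. The `"rare"` label and the `min_count` threshold are hypothetical choices for illustration, and the pooled rows are only stratified as one block, not per combination.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# The six-row table from the question
df = pd.DataFrame(
    {
        "Cliid": ["001", "002", "003", "004", "005", "006"],
        "Segment": ["A", "A", "C", "B", "A", "B"],
        "Insert": [0, 0, 0, 1, 0, 0],
    }
)

# One label per (Segment, Insert) combination
key = df["Segment"].astype(str) + "_" + df["Insert"].astype(str)

# Hypothetical threshold: a label needs at least 2 rows to be split across 2 groups
min_count = 2
counts = key.map(key.value_counts())
# Keep frequent labels, replace rare ones with a single catch-all label
key = key.where(counts >= min_count, "rare")

df_train, df_test = train_test_split(df, test_size=0.5, stratify=key, random_state=0)
```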