Can someone help me understand what this function does?

I understand everything up to the print lines, but I'm a bit lost after that, starting from train_data.

from sklearn.model_selection import StratifiedShuffleSplit

# df, datafile and save_datarows are defined elsewhere (not shown).
def stratifiedShuffleSplit_data(X, y):
    sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
    for train_index, test_index in sss.split(X, y):
        print("len(TRAIN):", len(train_index), "len(TEST):", len(test_index))
        print("TRAIN:", train_index, "TEST:", test_index)

        train_data = [df.loc[ind] for ind in train_index]
        test_data = [df.loc[ind] for ind in test_index]
        save_datarows(train_data, datafile+".train")
        save_datarows(test_data, datafile+".test")
ndrplz
EMMAKENJI

1 Answer

Assuming you are using the pandas package,

 pd.DataFrame.loc

is a label-based indexer: df.loc[ind] returns the row whose index label is ind. This is an oversimplified description; the reference at the end explains it in more detail.

train_data = [df.loc[ind] for ind in train_index]

Here you iterate over train_index. For each index ind, df.loc[ind] looks up the corresponding row of df, and the resulting rows are collected in the list train_data. test_data is built the same way from test_index.
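To make that concrete, here is a minimal sketch of a single split on a toy DataFrame. The column names and the data are made up for illustration; only StratifiedShuffleSplit, split, and .loc/.iloc come from scikit-learn and pandas.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Toy data (hypothetical): 8 rows, two balanced classes.
df = pd.DataFrame({"feature": np.arange(8) * 10,
                   "label": [0, 0, 0, 0, 1, 1, 1, 1]})
X = df[["feature"]]
y = df["label"]

# One split is enough to show the mechanics; the question uses n_splits=5.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
train_index, test_index = next(sss.split(X, y))

# Each df.loc[ind] returns one row as a pandas Series;
# the comprehension collects those rows in a plain Python list.
train_data = [df.loc[ind] for ind in train_index]
test_data = [df.loc[ind] for ind in test_index]

# Stratification keeps the class proportions in each half.
train_labels = sorted(int(row["label"]) for row in train_data)

# The same rows selected in one step, by position, as a DataFrame.
train_df = df.iloc[train_index]
```

Note that split yields positional indices. df.loc[ind] works here only because the default index labels (0, 1, 2, ...) happen to match the positions; if df had a different index, df.iloc would probably be the safer choice.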

I assume that save_datarows is a custom function that stores train_data in a file with the extension .train, and likewise test_data in a file with the extension .test.
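Since save_datarows is not shown in the question, the following is only a guess at what such a helper might look like. The CSV format, the function body, and the toy data are all assumptions, not the asker's actual code.

```python
import os
import tempfile

import pandas as pd


def save_datarows(rows, path):
    """Hypothetical stand-in: write a list of row Series to a CSV file."""
    # A list of Series becomes a DataFrame with one row per Series.
    pd.DataFrame(rows).to_csv(path, index=False)


# Toy data (hypothetical) and a temporary output location.
df = pd.DataFrame({"feature": [10, 20, 30], "label": [0, 1, 0]})
rows = [df.loc[i] for i in [0, 2]]
datafile = os.path.join(tempfile.mkdtemp(), "data")

save_datarows(rows, datafile + ".train")
saved = pd.read_csv(datafile + ".train")
```

If the real save_datarows overwrites its target file, as this sketch does, then the loop in the question leaves only the last of the five splits on disk, because every iteration writes to the same two file names.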

Hope this helps.

This is a really good reference for more clarification:

Selection with .loc in python

https://www.geeksforgeeks.org/python-pandas-dataframe-loc/

AshlinJP