Suppose I load the MNIST dataset like this:

import pandas as pd

df = pd.read_csv('data/train.csv')
data = df.loc[df['label'].isin([1,6])]

I am trying to select only those rows where column ['label'] is 1 or 6.

However, I want to get only 500 rows for each of those values of ['label'].

How do I do it?

oya163

2 Answers


You can group them and select the number you want for each value:

data = df.loc[df['label'].isin([1,6])].groupby('label').head(500)
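
As a minimal sketch of how this behaves, the following uses a small made-up DataFrame. The column `pixel0` and the limit of 2 rows per label are illustrative assumptions, standing in for the MNIST pixel columns and the 500-row limit:

```python
import pandas as pd

# Toy stand-in for train.csv; 'pixel0' is a hypothetical feature column.
df = pd.DataFrame({
    'label':  [1, 6, 1, 3, 6, 1, 6, 3],
    'pixel0': [10, 11, 12, 13, 14, 15, 16, 17],
})

n = 2  # use 500 for the real dataset

# Keep only labels 1 and 6, then take the first n rows of each label.
data = df.loc[df['label'].isin([1, 6])].groupby('label').head(n)

# Count how many rows were kept per label.
print(data['label'].value_counts().to_dict())
```

Note that `head(n)` takes the first n rows of each group in their original order, and a label with fewer than n rows simply contributes all of its rows.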
Gerges

Use groupby first, then filter, i.e.:

ndf= df.groupby('label').head(500)
data = ndf.loc[ndf['label'].isin([1,6])]
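
Filtering before grouping and grouping before filtering should select the same rows, since `head` works on each label independently. A small sketch to check this, assuming a made-up `label` column and 2 rows per label instead of 500:

```python
import pandas as pd

# Toy data; only the 'label' column matters for this check.
df = pd.DataFrame({'label': [1, 6, 1, 3, 6, 1, 6, 3, 1, 6]})
n = 2  # use 500 for the real dataset

# Order A: filter to labels 1 and 6, then take the first n rows per label.
filter_first = df.loc[df['label'].isin([1, 6])].groupby('label').head(n)

# Order B: take the first n rows of every label, then filter.
ndf = df.groupby('label').head(n)
group_first = ndf.loc[ndf['label'].isin([1, 6])]

# Compare the two results (same rows, same index, same dtypes).
print(filter_first.equals(group_first))
```

Filtering first avoids grouping labels that are discarded afterwards, which may matter on a larger dataset.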
Bharath M Shetty