I have this data frame:

+-------+----------+---------------+-----------+----+
|  age  |  gender  | customer type | purchases | id |
+-------+----------+---------------+-----------+----+
|  38   |  female  |   type 1      |    90     |  1 |
|  35   |  female  |   type 2      |   100     |  2 |
|  71   |  male    |   type 2      |    66     |  3 |
|  68   |  female  |   type 3      |    12     |  4 |
|  26   |  male    |   type 4      |    900    |  5 |
|  55   |  male    |   type 5      |    71     |  6 |
|  27   |  male    |   type 1      |    55     |  7 |
|  ...  |   ...    |    ...        |    ...    | ...|
+-------+----------+---------------+-----------+----+

I would like an 80% train / 20% test split for each customer type, with a similar distribution of age and gender in both parts. For example, a split for type 1 that ends up 80% female would not be a good split.

I tried using the random module with a seed, but I can't get this to work because I don't know how to take age and gender into account when splitting.

Thank you!!

1 Answer

If your database is large enough, I don't see why randomly taking 20% of it for testing and the remaining 80% for training would change the age and gender distributions. Here is a small example of how I would do it:

#!/usr/bin/python
import numpy as np
# Generate database
N = 1000000 # size of the database
age = np.abs(np.random.randn(N)) * 30 # Half-normal distribution (absolute value of a normal), scale 30
gender = np.random.randint(0, 100, N) < 42 # False = male, True = female, with a 42/58 female/male split
customerType = np.random.randint(1, 6, N) # 5 types of customers, labelled 1 to 5 (upper bound is exclusive)
purchases = np.random.randint(0, 1000, N)

# Split the database into test and train: each row goes to test with probability 20%, otherwise to train
# Goal: the test and train sets should have the same gender and age distributions as the full database.
testMask = np.random.randint(0, 100, N) < 20
trainMask = np.logical_not(testMask)


# check gender distribution
print("All database: %0.2f %% female %0.2f %% male" % (gender.mean()*100., (1-gender.mean())*100.))
print("Test database: %0.2f %% female %0.2f %% male" % (gender[testMask].mean()*100., (1-gender[testMask].mean())*100.))
print("Train database: %0.2f %% female %0.2f %% male" % (gender[trainMask].mean()*100., (1-gender[trainMask].mean())*100.))

# Check age distribution
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 1)
ax.hist(age, bins=30, density=True, label="All db", alpha=0.3)
ax.hist(age[testMask], bins=30, density=True, label="Test db", alpha=0.3)
ax.hist(age[trainMask], bins=30, density=True, label="Train db", alpha=0.3)
ax.legend()
plt.show()
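
If the database is small, or one customer type has only a few rows, a purely random mask can drift away from the overall proportions. One alternative to the random mask above is a stratified split: draw 20% from every (customer type, gender, age band) group separately. The following is only a sketch under assumptions: the column names, the age band edges, the synthetic data and the helper name `stratified_split` are all made up for illustration and are not taken from the question's real data.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 10_000

# Synthetic stand-in for the question's data frame (column names are assumptions).
df = pd.DataFrame({
    "id": np.arange(1, n + 1),
    "age": rng.integers(18, 90, n),
    "gender": rng.choice(["female", "male"], n, p=[0.42, 0.58]),
    "customer_type": rng.choice(
        ["type 1", "type 2", "type 3", "type 4", "type 5"], n
    ),
    "purchases": rng.integers(0, 1000, n),
})


def stratified_split(frame, test_frac=0.2, seed=0):
    """Draw test_frac of the rows from every (type, gender, age band) group."""
    # The age band edges are an arbitrary choice made for this sketch.
    keyed = frame.assign(
        age_band=pd.cut(frame["age"], bins=[0, 30, 45, 60, 120], labels=False)
    )
    # Sample the same fraction inside each group, then collect the row labels.
    test_idx = (
        keyed.groupby(["customer_type", "gender", "age_band"])
        .sample(frac=test_frac, random_state=seed)
        .index
    )
    test = frame.loc[test_idx]
    train = frame.drop(test_idx)
    return train, test


train, test = stratified_split(df)

# Compare the female share per customer type in both parts.
share = pd.DataFrame({
    "train": train.groupby("customer_type")["gender"].apply(lambda g: (g == "female").mean()),
    "test": test.groupby("customer_type")["gender"].apply(lambda g: (g == "female").mean()),
})
print(share)
```

Because the same fraction is taken inside every group, the per-type gender shares in train and test can only differ by rounding effects. Very small groups may contribute zero rows to the test set, so the band edges would need adjusting for sparse data.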
Awen