
I have a PySpark dataframe with a column "group". I also have feature columns and a label column. I want to split the dataframe for each group and then train a model and end up with a dictionary where the keys are the "group" names and the values are the trained models.

This question essentially gives an answer to this problem, but the method is inefficient:

> The obvious problem here is that it requires a full data scan for each level, so it is an expensive operation.

The answer is old and I am hoping there have been improvements in PySpark since then. For my use case I have about 10k groups with heavy skew in data size: the largest group can have 1 billion records and the smallest a single record.

Edit: As suggested, here is a small reproducible example.

df = spark.createDataFrame(  # createDataFrame lives on the SparkSession, not the SparkContext
    [
        ('A', 1, 0, True),
        ('A', 3, 0, False),
        ('B', 2, 2, True),
        ('B', 3, 3, True),
        ('B', 5, 2, False)
    ],
    ('group', 'feature_1', 'feature_2', 'label')
)

I can split the data as suggested in the above link:

from itertools import chain
from pyspark.sql.functions import col

# Collect the distinct group values, then filter once per group.
# Each filter triggers a full scan of df, so this is 10k scans.
groups = chain(*df.select("group").distinct().collect())
df_by_group = {
    group: train_model(df.where(col("group").eqNullSafe(group)))
    for group in groups
}

Where `train_model` is a function that takes a DataFrame with columns `feature_1`, `feature_2`, `label` and returns a model trained on that DataFrame.
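
For concreteness, a minimal sketch of what `train_model` could look like; the scikit-learn estimator and the collect-to-pandas step are assumptions, since the question leaves the actual model unspecified:

from sklearn.linear_model import LogisticRegression

def train_model(group_df):
    # Hypothetical implementation, assuming scikit-learn: pull this
    # group's (already filtered) rows to the driver and fit locally.
    # Only viable when a single group fits in driver memory.
    pdf = group_df.toPandas()
    return LogisticRegression().fit(pdf[["feature_1", "feature_2"]], pdf["label"])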

  • Please give us a small [reproducible example](https://stackoverflow.com/questions/48427185/how-to-make-good-reproducible-apache-spark-examples). – cronoik Nov 24 '19 at 20:26
  • Thanks for the suggestion. I added an example, let me know if it is more clear now. – Heraiwa Nov 24 '19 at 21:10
  • use `df.sampleBy('group', fractions={group:xx})` for each group in groups, calculating xx is trivial. http://spark.apache.org/docs/2.4.0/api/python/pyspark.sql.html#pyspark.sql.DataFrame.sampleBy – jxc Nov 24 '19 at 23:09
  • How does this help with applying the training function to each subgroup @jxc? – Heraiwa Nov 25 '19 at 01:14
  • @Heraiwa, is your question to take a fraction of rows with a specific `group` value (instead of all matched rows) from a large dataset? – jxc Nov 25 '19 at 02:48
  • No, my question is regarding applying a generalized function (like a UDAF but more general) to a DataFrame. – Heraiwa Nov 25 '19 at 02:58
  • Maybe you should use a pandas UDF. First repartition the data by the group you want, then define a UDF using a decorator and apply it to each group's data (sketched below). You can find how to use pandas UDFs on the Databricks webpage. – igorkf May 22 '20 at 20:51
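
Following up on that last comment, a minimal sketch of the grouped-map approach with Spark 3.x's `applyInPandas`; the scikit-learn estimator and the pickle/base64 round trip for shipping models back to the driver are assumptions, not part of the question:

import base64
import pickle

import pandas as pd
from sklearn.linear_model import LogisticRegression

def train_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # Each call receives all rows of one group as a pandas DataFrame.
    model = LogisticRegression().fit(pdf[["feature_1", "feature_2"]], pdf["label"])
    # Serialize the fitted model so it can travel back as a string column.
    payload = base64.b64encode(pickle.dumps(model)).decode("ascii")
    return pd.DataFrame({"group": [pdf["group"].iloc[0]], "model": [payload]})

models_df = df.groupBy("group").applyInPandas(
    train_group, schema="group string, model string"
)
models = {
    row["group"]: pickle.loads(base64.b64decode(row["model"]))
    for row in models_df.collect()
}

This trains all groups in a single pass instead of one scan per group, but each group must fit in one executor's memory as a pandas DataFrame, so the billion-record group remains a concern.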
