
I have a dataset with many string columns, each with a different set of levels, and a numeric Target column that I need to average. After getting the averages, I need to select the top N levels of each column with the highest average Target value, in descending order. Finally, for each column, I need to check whether each value is in that column's top-N list; if it is, keep it, otherwise replace it with the string "Other". The data looks something like this, but with more columns and many more rows:

Column A  Column B  Column C  Target
A1        B1        C1        25
A1        B2        C2        50
A2        B3        C2        10
A2        B3        C2        15

Right now, I am running a for loop and grouping by each column and taking the average of the Target column, one column at a time:

import pyspark.sql.functions as F
cols_to_group = ['A','B','C']
top_n = 10

for c in cols_to_group:
    # Average Target per level of column c and keep the top_n levels.
    agg_df = (
              df
             .groupBy(c)
             .agg({"Target": "avg"})
             .orderBy(F.col("avg(Target)").desc())
             .select(c)
             .limit(top_n)
             )

    # Collect the top levels to the driver.
    levels = [lvl[0] for lvl in agg_df.collect()]

    # Keep levels in the top_n list, replace the rest with "Other".
    df = (
          df
          .withColumn(c, F.when(F.col(c).isin(levels), F.col(c)).otherwise(F.lit("Other")))
         )

This works, but it is very slow, and the problem seems like it should parallelize fairly easily. Is there a way to run the groupBy/aggregation for each column in parallel and then, for each column, keep the values that are in its top-N list and fill in the string "Other" for the rest?

BeefDog
  • This setup could be helpful. https://stackoverflow.com/questions/30214474/how-to-run-multiple-jobs-in-one-sparkcontext-from-separate-threads-in-pyspark – user238607 Aug 30 '23 at 16:43
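
For reference, a minimal sketch of what the multi-threaded setup from that link could look like here. This is only an illustration, not the asker's or answerer's code; it reuses df, cols_to_group and top_n from the question, and each driver thread submits one per-column aggregation job to the same SparkContext so Spark can run them concurrently:

from concurrent.futures import ThreadPoolExecutor

import pyspark.sql.functions as F

def top_levels(c):
    # Average Target per level of column c, keep the top_n levels, collect to the driver.
    rows = (
        df.groupBy(c)
          .agg(F.avg("Target").alias("mean"))
          .orderBy(F.col("mean").desc())
          .limit(top_n)
          .collect()
    )
    return c, [r[c] for r in rows]

# Submit one Spark job per column from separate driver threads.
with ThreadPoolExecutor(max_workers=len(cols_to_group)) as pool:
    keep = dict(pool.map(top_levels, cols_to_group))

# Apply all replacements in a single pass over the dataframe.
df = df.select(
    [F.when(F.col(c).isin(keep[c]), F.col(c)).otherwise(F.lit("Other")).alias(c)
     for c in cols_to_group] + ["Target"]
)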

1 Answer

If the number of distinct values per column is not too big, you can explode the dataframe into the form (column, value, Target), compute the mean Target per (column, value), then aggregate by column and collect the sorted list of (mean, value) pairs. Finally, you slice that list to the top n values per column, collect it to the driver and transform the original df.

n = 2

top_n = (
    # One (column, value) struct per grouping column, exploded into rows.
    df.select(
        F.explode(
            F.array([F.struct(
                F.lit(c).alias("column"), F.col(c).alias("value")
            ) for c in cols_to_group])
        ).alias("s"), "Target"
      )
      # Mean Target per (column, value) pair.
      .groupBy("s")
      .agg(F.avg("Target").alias("mean"))
      # For each column, sort its values by descending mean (hence the -mean).
      .groupBy("s.column")
      .agg(F.array_sort(F.collect_list(F.struct(-F.col("mean"), F.col("s.value")))).alias("list"))
      # Keep only the first n entries of each sorted list.
      .select("column", F.explode(F.slice(F.col("list"), 1, n)).alias("s"))
      .select("column", "s.value")
      .groupBy("column")
      .agg(F.collect_list("value").alias("list"))
      .collect()
)
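
At this point top_n is a small list of Rows on the driver, one per grouping column, each holding the column name and the list of its top-n values; on the sample data with n = 2, the row for column B would be roughly Row(column='B', list=['B2', 'B1']).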

result = (
    # Rebuild each grouping column, keeping only its top-n levels and mapping the rest to "Other".
    df.select([F.when(F.col(i.column).isin(i.list), F.col(i.column))
                .otherwise(F.lit("Other"))
                .alias(i.column)
    for i in top_n] + ['Target'])
)
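
On the sample data with n = 2, only B3 falls outside its column's top-2 list (its average Target is 12.5 versus 50 for B2 and 25 for B1), so the result keeps every level of A and C and replaces B3 with "Other". The struct(-F.col("mean"), ...) trick is used because array_sort orders ascending, so negating the mean sorts each column's values by descending average Target.
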
Oli