I have a dataset with many string columns, each with its own set of levels, and a numeric Target column that I need to average. For each string column, I need to group by that column, average Target, and select the top N levels ranked by average Target value in descending order. Finally, for each column, I need to check whether each value is in that column's top N list: keep it if so, and otherwise replace it with the string "Other". The data looks something like this, but with more columns and many more rows:
Column A | Column B | Column C | Target |
---|---|---|---|
A1 | B1 | C1 | 25 |
A1 | B2 | C2 | 50 |
A2 | B3 | C2 | 10 |
A2 | B3 | C2 | 15 |
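For reference, a toy DataFrame matching this example can be built like so (I'm using the short column names A, B, C, and Target here, which is also what the code below assumes):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy DataFrame mirroring the example table above
df = spark.createDataFrame(
    [
        ("A1", "B1", "C1", 25),
        ("A1", "B2", "C2", 50),
        ("A2", "B3", "C2", 10),
        ("A2", "B3", "C2", 15),
    ],
    ["A", "B", "C", "Target"],
)
```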
Right now I am running a for loop, grouping by one column at a time and taking the average of the Target column:
import pyspark.sql.functions as F

cols_to_group = ['A', 'B', 'C']
top_n = 10

for c in cols_to_group:
    # Average the Target column per level of this column and keep the top N levels
    agg_df = (
        df
        .groupBy(c)
        .agg({"Target": "avg"})
        .orderBy(F.col("avg(Target)").desc())
        .select(c)
        .limit(top_n)
    )
    # Pull the top N levels back to the driver as a plain Python list
    levels = [lvl[0] for lvl in agg_df.collect()]
    # Keep values that are in the top N list, replace everything else with "Other"
    df = df.withColumn(
        c, F.when(F.col(c).isin(levels), F.col(c)).otherwise(F.lit("Other"))
    )
This works, but it is very slow, and the problem seems like it should parallelize fairly easily. Is there a way to run the groupBy/aggregation for each column in parallel, and then, for each column, keep the values in its top N list and replace the rest with the string "Other"?
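For example, something along these lines is roughly what I have in mind: submitting the per-column aggregations from a `ThreadPoolExecutor` so the driver has several Spark jobs in flight at once, and then applying all the replacements in a single `select`. This is an untested sketch (it reuses `df`, `cols_to_group`, and `top_n` from above), and I don't know whether this is the right way to get Spark to run these jobs concurrently:

```python
from concurrent.futures import ThreadPoolExecutor
import pyspark.sql.functions as F

def top_levels(col_name, n):
    # Same per-column aggregation as above, returned as a plain Python list
    rows = (
        df
        .groupBy(col_name)
        .agg(F.avg("Target").alias("avg_target"))
        .orderBy(F.col("avg_target").desc())
        .select(col_name)
        .limit(n)
        .collect()
    )
    return col_name, [r[0] for r in rows]

# Submit all the per-column aggregations at once from separate driver threads
with ThreadPoolExecutor(max_workers=len(cols_to_group)) as pool:
    top_lists = dict(pool.map(lambda c: top_levels(c, top_n), cols_to_group))

# Apply every replacement in a single pass over the data, keeping column order
df = df.select(
    *[
        F.when(F.col(c).isin(top_lists[c]), F.col(c)).otherwise(F.lit("Other")).alias(c)
        if c in top_lists
        else F.col(c)
        for c in df.columns
    ]
)
```

Is this a reasonable approach, or is there a better way to do this entirely within Spark?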