
I am doing multiple joins with the same DataFrame; the DataFrames I am joining with are the result of a groupBy on my original DataFrame.

    from pyspark.sql.functions import col, mean

    listOfCols = ["a", "b", "c", ....]
    for c in listOfCols:
        # mean of the target per distinct value of column c ("mean target encoding")
        means = df.groupby(col(c)).agg(mean(target).alias(f"{c}_mean_encoding"))
        df = df.join(means, c, how="left")

This code produces more than 100,000 tasks and takes forever to finish. In the DAG I see a lot of shuffling happening. How can I optimize this code?

user1450410
  • It really depends on how large `listOfCols` is, although you can try to save `df` every N iterations using `persist` – abiratsis Apr 21 '20 at 09:35
  • The list is pretty large, about 10 columns, and some are high-cardinality fields. How will the persist help? Caching the smaller grouped DataFrame didn't help – user1450410 Apr 21 '20 at 09:38
  • It will help by saving the intermediate results to disk and simplifying the execution plan; Spark will not re-evaluate the persisted parts (a sketch of this idea follows these comments). Please check the discussion [here](https://stackoverflow.com/questions/54653298/when-is-it-not-performance-practical-to-use-persist-on-a-spark-dataframe/56093531#56093531) – abiratsis Apr 21 '20 at 10:01
  • Isn't that the same thing `checkpoint()` does? Anyway, it takes forever with these two options too; I guess the shuffling is the real heavy part. Will using SQL joins directly in Spark SQL be more optimized? – user1450410 Apr 21 '20 at 10:17
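
A minimal sketch of the `persist` suggestion from the comments, assuming the same `df`, `listOfCols`, and `target` as in the question (the batch size `N` is an arbitrary illustrative value, not from the discussion):

    from pyspark.sql.functions import col, mean

    N = 3  # hypothetical batch size; tune for your workload
    for i, c in enumerate(listOfCols, start=1):
        means = df.groupby(col(c)).agg(mean(target).alias(f"{c}_mean_encoding"))
        df = df.join(means, c, how="left")
        if i % N == 0:
            df = df.persist()  # keep the intermediate result around
            df.count()         # force evaluation so later joins read from the cache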

1 Answer


Well, after a LOT of tries and failures, I came up with the fastest solution. Instead of 1.5 hours, this job ran for 5 minutes... I will put it here so that if someone stumbles into it, he/she won't suffer as I did... The solution was to use Spark SQL; it must be much more optimized internally than the DataFrame API:

    from pyspark.sql import functions as F

    df.registerTempTable("df")

    left_join_string = ""  # must be initialized before the loop
    for c in listOfCols:
        left_join_string += f" left join means_{c} on df.{c} = means_{c}.{c}"
        means = df.groupby(F.col(c)).agg(F.mean(target).alias(f"{c}_mean_encoding"))
        means.registerTempTable(f"means_{c}")

    df = sqlContext.sql("SELECT * FROM df " + left_join_string)
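
For what it's worth, `registerTempTable` and `sqlContext` belong to the Spark 1.x API; on Spark 2.0+ the same pattern can be written with the `SparkSession` equivalents, roughly like this (a sketch only, assuming a session named `spark`):

    from pyspark.sql import functions as F

    df.createOrReplaceTempView("df")

    left_join_string = ""
    for c in listOfCols:
        left_join_string += f" left join means_{c} on df.{c} = means_{c}.{c}"
        means = df.groupby(F.col(c)).agg(F.mean(target).alias(f"{c}_mean_encoding"))
        means.createOrReplaceTempView(f"means_{c}")

    df = spark.sql("SELECT * FROM df " + left_join_string)
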
user1450410