I want to apply a function to multiple columns of a data frame:
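import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Adds pre_<colName> and pre2_<colName>; expects df to already contain cnt_foo_eq_1
def handleBias(df: DataFrame, colName: String, target: String = target) = {
  val w1 = Window.partitionBy(colName)
  val w2 = Window.partitionBy(colName, target)
  df.withColumn("cnt_group", count("*").over(w2))
    .withColumn("pre2_" + colName, mean(target).over(w1))
    .withColumn("pre_" + colName, coalesce(min(col("cnt_group") / col("cnt_foo_eq_1")).over(w1), lit(0D)))
    .drop("cnt_group")
}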
This can be written nicely, as shown above, with Spark SQL and a loop over the columns. However, it causes a lot of shuffles (see the related question "spark apply function to columns in parallel").
A minimal example:
val df = Seq(
  (0, "A", "B", "C", "D"),
  (1, "A", "B", "C", "D"),
  (0, "d", "a", "jkl", "d"),
  (0, "d", "g", "C", "D"),
  (1, "A", "d", "t", "k"),
  (1, "d", "c", "C", "D"),
  (1, "c", "B", "C", "D")
).toDF("TARGET", "col1", "col2", "col3TooMany", "col4")
val columnsToDrop = Seq("col3TooMany")
val columnsToCode = Seq("col1", "col2")
val target = "TARGET"
val targetCounts = df.filter(df(target) === 1).groupBy(target)
  .agg(count(target).as("cnt_foo_eq_1"))
val newDF = df.join(broadcast(targetCounts), Seq(target), "left")
val result = (columnsToDrop ++ columnsToCode).toSet.foldLeft(newDF) {
  (currentDF, colName) => handleBias(currentDF, colName)
}
result.drop(columnsToDrop: _*).show
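For reference, the values expected for `col1` on this data can be worked out with plain Scala collections, without Spark. This is only a sketch to pin down the arithmetic: `Stats` and `statsFor` are hypothetical names that do not appear in the code above, and the target is assumed to be 0 or 1.

```scala
// Plain Scala collections, no Spark: a reference computation for one column.
// Stats and statsFor are hypothetical names used only for this illustration.
final case class Stats(pre: Double, pre2: Double)

// rows are (target, value) pairs for a single column; target is assumed to be 0 or 1
def statsFor(rows: Seq[(Int, String)]): Map[String, Stats] = {
  val totalOnes = rows.count(_._1 == 1).toDouble // plays the role of cnt_foo_eq_1
  rows.groupBy(_._2).map { case (value, group) =>
    val ones = group.count(_._1 == 1)
    // pre: share of all target == 1 rows that carry this value (0.0 if none do)
    val pre = if (totalOnes == 0) 0.0 else ones / totalOnes
    // pre2: mean of the target within this value
    val pre2 = group.map(_._1).sum.toDouble / group.size
    value -> Stats(pre, pre2)
  }
}

// (TARGET, col1) pairs taken from the minimal example
val col1Rows = Seq((0, "A"), (1, "A"), (0, "d"), (0, "d"), (1, "A"), (1, "d"), (1, "c"))
val col1Stats = statsFor(col1Rows)
```

For `col1` this gives `pre = 0.5` and `pre2 = 2/3` for "A", `pre = 0.25` and `pre2 = 1/3` for "d", and `pre = 0.25` and `pre2 = 1.0` for "c", which should correspond to the `pre_col1` and `pre2_col1` columns.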
How can I formulate this more efficiently using the RDD API? aggregateByKey
should be a good fit, but it is still not clear to me how to apply it here to replace the window functions.
(A bit more context and a bigger example: https://github.com/geoHeil/sparkContrastCoding)
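A minimal sketch of how the aggregateByKey idea might look, under several assumptions: the target is 0 or 1 (so the mean equals the share of ones), the number of distinct values per column is small enough to collect to the driver, and the local `SparkSession` is only there to make the snippet self-contained. The names `keyed`, `counts`, `totalOnes` and `stats` are made up for this illustration. Every cell is keyed by (column name, value), so a single `aggregateByKey` covers all columns at once.

```scala
import org.apache.spark.sql.SparkSession

// Local session only so the sketch is self-contained; in spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("aggregateByKeySketch").getOrCreate()
import spark.implicits._

// Reduced version of the minimal example (TARGET, col1, col2 only)
val df = Seq(
  (0, "A", "B"), (1, "A", "B"), (0, "d", "a"), (0, "d", "g"),
  (1, "A", "d"), (1, "d", "c"), (1, "c", "B")
).toDF("TARGET", "col1", "col2")

val target = "TARGET"
val columnsToCode = Seq("col1", "col2")

// Key every cell by (column name, value) so one aggregateByKey covers all columns.
val keyed = df.rdd.flatMap { row =>
  val t = row.getAs[Int](target)
  columnsToCode.map(c => ((c, row.getAs[String](c)), t))
}

// Accumulator per key: (row count, count of rows with target == 1)
val counts = keyed.aggregateByKey((0L, 0L))(
  (acc, t) => (acc._1 + 1L, acc._2 + (if (t == 1) 1L else 0L)), // fold one row into the partition-local accumulator
  (a, b) => (a._1 + b._1, a._2 + b._2)                          // merge accumulators across partitions
)

// Total number of rows with target == 1 (the role of cnt_foo_eq_1)
val totalOnes = df.filter(df(target) === 1).count().toDouble

// (column, value) -> (pre, pre2); assumes the distinct values fit on the driver
val stats: Map[(String, String), (Double, Double)] = counts.mapValues { case (n, ones) =>
  (if (totalOnes == 0) 0.0 else ones / totalOnes, ones.toDouble / n)
}.collect().toMap
```

Mapping `stats` back onto the rows (for example through a broadcast map) is left out here, and whether this actually shuffles less than the window version would have to be measured.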
Edit
Initially, I started with the approach from "Spark dynamic DAG is a lot slower and different from hard coded DAG", which is shown below. The good thing is that each column seems to run independently and in parallel. The downside is that the joins (even for a small dataset of 300 MB) get "too big" and lead to an unresponsive Spark.
handleBiasOriginal("col1", df)
  .join(handleBiasOriginal("col2", df), df.columns)
  .join(handleBiasOriginal("col3TooMany", df), df.columns)
  .drop(columnsToDrop: _*).show
def handleBiasOriginal(col: String, df: DataFrame, target: String = target): DataFrame = {
  val pre1_1 = df
    .filter(df(target) === 1)
    .groupBy(col, target)
    .agg((count("*") / df.filter(df(target) === 1).count).alias("pre_" + col))
    .drop(target)
  val pre2_1 = df
    .groupBy(col)
    .agg(mean(target).alias("pre2_" + col))
  df
    .join(pre1_1, Seq(col), "left")
    .join(pre2_1, Seq(col), "left")
    .na.fill(0)
}
This image is from Spark 2.1.0; the images in "Spark dynamic DAG is a lot slower and different from hard coded DAG" are from 2.0.2.
The DAG becomes a bit simpler when caching is applied: df.cache handleBiasOriginal("col1", df). ...
What possibilities other than window functions do you see to optimize the SQL? Ideally, the SQL would be generated dynamically.