
A function should be executed for multiple columns in a data frame

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

def handleBias(df: DataFrame, colName: String, target: String = target): DataFrame = {
    val w1 = Window.partitionBy(colName)
    val w2 = Window.partitionBy(colName, target)

    df.withColumn("cnt_group", count("*").over(w2))
      .withColumn("pre2_" + colName, mean(target).over(w1))
      // "cnt_foo_eq_1" (the total number of rows with TARGET == 1) is added by the broadcast join below
      .withColumn("pre_" + colName, coalesce(min(col("cnt_group") / col("cnt_foo_eq_1")).over(w1), lit(0D)))
      .drop("cnt_group")
  }

This can be written nicely in Spark SQL as shown above, applied to each column via a fold / for loop. However, it causes a lot of shuffles (see spark apply function to columns in parallel).

A minimal example:

  // requires a SparkSession `spark`, import spark.implicits._ (for toDF) and import org.apache.spark.sql.functions._
  val df = Seq(
    (0, "A", "B", "C", "D"),
    (1, "A", "B", "C", "D"),
    (0, "d", "a", "jkl", "d"),
    (0, "d", "g", "C", "D"),
    (1, "A", "d", "t", "k"),
    (1, "d", "c", "C", "D"),
    (1, "c", "B", "C", "D")
  ).toDF("TARGET", "col1", "col2", "col3TooMany", "col4")

  val columnsToDrop = Seq("col3TooMany")
  val columnsToCode = Seq("col1", "col2")
  val target = "TARGET"

  val targetCounts = df.filter(df(target) === 1).groupBy(target)
    .agg(count(target).as("cnt_foo_eq_1"))
  val newDF = df.join(broadcast(targetCounts), Seq(target), "left")

  val result = (columnsToDrop ++ columnsToCode).toSet.foldLeft(newDF) {
    (currentDF, colName) => handleBias(currentDF, colName)
  }

  result.drop(columnsToDrop: _*).show

How can I formulate this more efficiently using the RDD API? aggregateByKey should be a good fit, but it is still not clear to me how to apply it here to replace the window functions.

(provides a bit more context / bigger example https://github.com/geoHeil/sparkContrastCoding)

edit

Initially, I started with the approach from Spark dynamic DAG is a lot slower and different from hard coded DAG, which is shown below. The good thing is that each column seems to run independently / in parallel. The downside is that the joins (even for a small dataset of 300 MB) get "too big" and lead to an unresponsive Spark.

handleBiasOriginal("col1", df)
    .join(handleBiasOriginal("col2", df), df.columns)
    .join(handleBiasOriginal("col3TooMany", df), df.columns)
    .drop(columnsToDrop: _*).show

  def handleBiasOriginal(col: String, df: DataFrame, target: String = target): DataFrame = {
    val pre1_1 = df
      .filter(df(target) === 1)
      .groupBy(col, target)
      .agg((count("*") / df.filter(df(target) === 1).count).alias("pre_" + col))
      .drop(target)

    val pre2_1 = df
      .groupBy(col)
      .agg(mean(target).alias("pre2_" + col))

    df
      .join(pre1_1, Seq(col), "left")
      .join(pre2_1, Seq(col), "left")
      .na.fill(0)
  }

This image is with Spark 2.1.0; the images from Spark dynamic DAG is a lot slower and different from hard coded DAG are with 2.0.2. [image: overly complex DAG]

The DAG will be a bit simpler when caching is applied: df.cache; handleBiasOriginal("col1", df); ...

What possibilities other than window functions do you see to optimize the SQL? Ideally the SQL would be generated dynamically.

[image: DAG with caching applied]

Georg Heiler

2 Answers


The main point here is to avoid unnecessary shuffles. Right now your code shuffles twice for each column you want to include and the resulting data layout cannot be reused between columns.

For simplicity I assume that target is always binary ({0, 1}) and all remaining columns you use are of StringType. Furthermore I assume that the cardinality of the columns is low enough for the results to be grouped and handled locally. You can adjust these methods to handle other cases but it requires more work.

RDD API

  • Reshape data from wide to long:

    import org.apache.spark.sql.functions._
    
    val exploded = explode(array(
      (columnsToDrop ++ columnsToCode).map(c => 
        struct(lit(c).alias("k"), col(c).alias("v"))): _*
    )).alias("level")
    
    val long = df.select(exploded, $"TARGET")
    
  • aggregateByKey, reshape and collect:

    import org.apache.spark.util.StatCounter
    
    val lookup = long.as[((String, String), Int)].rdd
      // You can use prefix partitioner (one that depends only on _._1)
      // to avoid reshuffling for groupByKey
      .aggregateByKey(StatCounter())(_ merge _, _ merge _)
      .map { case ((c, v), s) => (c, (v, s)) }
      .groupByKey
      .mapValues(_.toMap)
      .collectAsMap
    
  • You can use lookup to get statistics for individual columns and levels. For example:

    lookup("col1")("A")
    
    org.apache.spark.util.StatCounter = 
      (count: 3, mean: 0.666667, stdev: 0.471405, max: 1.000000, min: 0.000000)
    

    Gives you data for col1, level A. Based on the binary TARGET assumption this information is complete (you get count / fractions for both classes).

    You can use lookup like this to generate SQL expressions or pass it to a udf and apply it on individual columns.
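
    For example, a minimal sketch of the udf route, assuming the lookup map collected above and the df / columnsToCode / pre_ and pre2_ names from the question; totalPositives and encoders are illustrative helpers, not part of the original answer:

    import org.apache.spark.sql.functions.{col, udf}

    // total number of rows with TARGET == 1, computed once on the driver
    val totalPositives = df.filter(df("TARGET") === 1).count.toDouble

    def encoders(colName: String) = {
      val stats = lookup(colName) // Map[level -> StatCounter]

      // mean of TARGET per level, i.e. "pre2_" + colName from the question
      val meanUdf = udf((v: String) => stats.get(v).map(_.mean).getOrElse(0.0))

      // share of all positive rows that carry this level, i.e. "pre_" + colName;
      // count * mean equals the number of positive rows because TARGET is {0, 1}
      val fracUdf = udf((v: String) =>
        stats.get(v).map(s => s.count * s.mean / totalPositives).getOrElse(0.0))

      (meanUdf, fracUdf)
    }

    val encoded = columnsToCode.foldLeft(df) { (current, c) =>
      val (meanUdf, fracUdf) = encoders(c)
      current
        .withColumn("pre2_" + c, meanUdf(col(c)))
        .withColumn("pre_" + c, fracUdf(col(c)))
    }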

DataFrame API

  • Convert data to long as for RDD API.
  • Compute aggregates based on levels:

    val stats = long
      .groupBy($"level.k", $"level.v")
      .agg(mean($"TARGET"), sum($"TARGET"))
    
  • Depending on your preferences you can reshape this to enable efficient joins, or convert it to a local collection and proceed similarly to the RDD solution (one possible reshaping is sketched below).
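
    A sketch of the join variant, assuming the df, columnsToCode and pre_/pre2_ naming from the question; the renamed columns and totalPositives are illustrative helpers, not part of the original answer:

    import org.apache.spark.sql.functions.{broadcast, col}

    // give the aggregate columns stable names (positional rename)
    val namedStats = stats.toDF("k", "v", "target_mean", "target_sum")

    // total number of rows with TARGET == 1, to turn sums into fractions
    val totalPositives = df.filter(df("TARGET") === 1).count.toDouble

    val withEncodings = columnsToCode.foldLeft(df) { (current, c) =>
      val colStats = namedStats
        .filter(col("k") === c)
        .select(
          col("v").alias(c),
          col("target_mean").alias("pre2_" + c),
          (col("target_sum") / totalPositives).alias("pre_" + c))

      // per-column stats are small (low cardinality assumed), so broadcast the join
      current.join(broadcast(colStats), Seq(c), "left")
    }

    withEncodings.na.fill(0).show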

zero323

Using aggregateByKey: a simple explanation of aggregateByKey can be found here. Basically, you use two functions: one which works inside a partition and one which works between partitions.

You would need to do something like aggregating by the first column and building a data structure internally, with a map for every element of the second column, to aggregate and collect data there (of course you could do two aggregateByKey calls if you want). This does not solve the need for multiple runs of the code, one for each column you want to work with (you could use aggregate instead of aggregateByKey to work on all the data and put it in a map, but that would probably give you even worse performance). The result would be one line per key; if you want to get back to the original records (as the window function does), you would actually need to either join this value with the original RDD or save all values internally and flatMap.
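
To make the two functions concrete, here is a rough sketch (not from the original answer) for a single column, col1 from the question, and the binary TARGET; it only performs the per-level aggregation and does not join the result back to the original rows:

    // key by the column's value; the zero value is a map from target value to count
    val perLevelCounts = df
      .select("col1", "TARGET")
      .rdd
      .map(r => (r.getString(0), r.getInt(1)))
      .aggregateByKey(Map.empty[Int, Long])(
        // inside a partition: bump the counter for this target value
        (acc, target) => acc + (target -> (acc.getOrElse(target, 0L) + 1L)),
        // between partitions: merge the two counter maps
        (a, b) => (a.keySet ++ b.keySet)
          .map(k => k -> (a.getOrElse(k, 0L) + b.getOrElse(k, 0L)))
          .toMap)

    // perLevelCounts: RDD[(String, Map[Int, Long])], one element per level of col1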

I do not believe this would provide you with any real performance improvement. You would be doing a lot of work to reimplement things that are done for you in SQL, and while doing so you would lose most of the advantages of Spark SQL (Catalyst optimization, Tungsten memory management, whole-stage code generation, etc.).

Improving the SQL

What I would do instead is attempt to improve the SQL itself. For example, the result of the window function appears to be the same for all rows within a window partition. Do you really need a window function? You can do a groupBy instead of a window function (and if you really need the result per record, you can try to join the results back; this might provide better performance, as it would not necessarily mean shuffling everything twice on every step). A sketch of this alternative follows below.
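
For illustration only, a sketch (not from the original answer) of what the groupBy + broadcast-join variant might look like, reusing the df and column names from the question and assuming low column cardinality; totalPositives and handleBiasGroupBy are hypothetical helpers:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions._

    // total number of rows with TARGET == 1 (same role as "cnt_foo_eq_1" above)
    val totalPositives = df.filter(df("TARGET") === 1).count.toDouble

    def handleBiasGroupBy(df: DataFrame, colName: String, target: String = "TARGET"): DataFrame = {
      val grouped = df
        .groupBy(colName)
        .agg(
          // mean of the binary target per level ("pre2_" + colName in the question)
          mean(target).alias("pre2_" + colName),
          // share of all positive rows that carry this level ("pre_" + colName)
          (sum(target) / totalPositives).alias("pre_" + colName))

      // one aggregation plus a small broadcast join per column, instead of two window shuffles
      df.join(broadcast(grouped), Seq(colName), "left").na.fill(0)
    }

    val encoded = (columnsToDrop ++ columnsToCode).toSet.foldLeft(df) {
      (current, c) => handleBiasGroupBy(current, c)
    }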

Assaf Mendelson
  • Please see http://stackoverflow.com/questions/41169873/spark-dynamic-dag-is-a-lot-slower-and-different-from-hard-coded-dag as well as my edit above. Initially, I started out using group-by with joins. This led to a job not finishing in reasonable time / Spark did not seem to perform any operation. Though the join solution works fine for small data, I could not get it to work with many columns. Looking forward to suggestions on how to improve the SQL. – Georg Heiler Jan 04 '17 at 08:17
  • I am not saying that join is necessarily the solution. What I am saying is that in most cases RDD with aggregateByKey would be slower. You can go ahead and try aggregateByKey using the link I showed and the basic logic of how to implement it. – Assaf Mendelson Jan 04 '17 at 08:28
  • Meanwhile, do you see a way to not use slow window functions but still prevent usage of the join? – Georg Heiler Jan 04 '17 at 08:29
  • Also the link you showed relates to the way lineage is built. I would try to solve that instead of going to aggregateByKey – Assaf Mendelson Jan 04 '17 at 08:29
  • The main question is what are you trying to do. For example, do you need the column value for each record or just the final values? If you need the column values you can use groupby. – Assaf Mendelson Jan 04 '17 at 08:32
  • As outlined in https://github.com/geoHeil/sparkContrastCoding, I want to calculate, for each value in a String column, the percentage of rows with that value that have TARGET == 0 or TARGET == 1. – Georg Heiler Jan 04 '17 at 08:34
  • Let us [continue this discussion in chat](http://chat.stackoverflow.com/rooms/132265/discussion-between-assaf-mendelson-and-georg-heiler). – Assaf Mendelson Jan 04 '17 at 08:52