
I'm trying to find a way to calculate the median for a given DataFrame.

val df = sc.parallelize(Seq(("a",1.0),("a",2.0),("a",3.0),("b",6.0), ("b", 8.0))).toDF("col1", "col2")

+----+----+
|col1|col2|
+----+----+
|   a| 1.0|
|   a| 2.0|
|   a| 3.0|
|   b| 6.0|
|   b| 8.0|
+----+----+

Now I want to do something like this:
df.groupBy("col1").agg(calcmedian("col2"))

the result should look like this:

+----+------+
|col1|median|
+----+------+
|   a|   2.0|
|   b|   7.0|
+----+------+

Therefore `calcmedian()` has to be a UDAF, but the problem is that the `evaluate` method of the UDAF only takes a `Row`, while I need the whole table to sort the values and return the median...

// Once all entries for a group are exhausted, Spark will call evaluate to get the final result
def evaluate(buffer: Row) = {...}

Is this possible somehow, or is there another nice workaround? I want to stress that I know how to calculate the median on a dataset with one group. But I don't want to use that algorithm in a `foreach` loop, as that is inefficient!

Thank you!


Edit:

That's what I tried so far:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object calcMedian extends UserDefinedAggregateFunction {
    // Schema you get as an input
    def inputSchema = new StructType().add("col2", DoubleType)
    // Schema of the row which is used for aggregation
    def bufferSchema = new StructType().add("col2", DoubleType)
    // Returned type
    def dataType = DoubleType
    // Self-explaining
    def deterministic = true
    // initialize - called once for each group
    def initialize(buffer: MutableAggregationBuffer) = {
        buffer(0) = 0.0
    }

    // called for each input record of that group
    def update(buffer: MutableAggregationBuffer, input: Row) = {
        buffer(0) = input.getDouble(0)
    }
    // if the function supports partial aggregates, Spark might (as an optimization) compute partial results and combine them together
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
        buffer1(0) = buffer2.getDouble(0)
    }
    // Once all entries for a group are exhausted, Spark will call evaluate to get the final result
    def evaluate(buffer: Row) = {
        val tile = 50
        var median = 0.0

        //PROBLEM: buffer is a Row --> I need the whole DataFrame/RDD here???
        val rdd_sorted = buffer.sortBy(x => x)   // does not compile: Row has no sortBy
        val c = rdd_sorted.count()
        if (c == 1) {
            median = rdd_sorted.first()
        } else {
            val index = rdd_sorted.zipWithIndex().map(_.swap)
            val last = c
            val n = (tile / 100d) * (c * 1d)
            val k = math.floor(n).toLong
            val d = n - k
            if (k <= 0) {
                median = rdd_sorted.first()
            } else if (k >= c) {
                median = index.lookup(last - 1).head
            } else {
                median = index.lookup(k - 1).head + d * (index.lookup(k).head - index.lookup(k - 1).head)
            }
        }
        median
    }   //end of evaluate
}
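
For illustration, here is a sketch of the buffer-collecting approach suggested in the comments below (the name `collectMedian` is made up): the buffer is an `ArrayType(DoubleType)` that accumulates all values of a group, so `evaluate` can sort them and take the exact median. All values of a group end up in one buffer, so this only suits reasonably small groups.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object collectMedian extends UserDefinedAggregateFunction {
    def inputSchema = new StructType().add("value", DoubleType)
    // the buffer holds every value seen so far for the group
    def bufferSchema = new StructType().add("values", ArrayType(DoubleType))
    def dataType = DoubleType
    def deterministic = true

    def initialize(buffer: MutableAggregationBuffer) = {
        buffer(0) = Seq.empty[Double]
    }

    // append the incoming value to the collected sequence
    def update(buffer: MutableAggregationBuffer, input: Row) = {
        buffer(0) = buffer.getSeq[Double](0) :+ input.getDouble(0)
    }

    // concatenate partial collections
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
        buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)
    }

    def evaluate(buffer: Row) = {
        val sorted = buffer.getSeq[Double](0).sorted
        val n = sorted.length
        // average the two middle values for even-sized groups
        if (n % 2 == 1) sorted(n / 2)
        else (sorted(n / 2 - 1) + sorted(n / 2)) / 2.0
    }
}

It would be used as `df.groupBy("col1").agg(collectMedian(col("col2")).alias("median"))`.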
johntechendso
  • You need to `groupByKey`, transform the aggregated data to a `Buffer` (there are some `UDF`s to achieve this), and then you create a UDF to compute the median. – Alberto Bonsanto Jun 02 '16 at 11:24
  • The `UserDefinedAggregateFunction` base class has many more members than just `evaluate`, which need to be implemented. The `Row` buffer passed to `evaluate` is the very final step. Have you tried any implementation, and if so, can you show your code so far? – mattinbits Jun 02 '16 at 11:26
  • @mattinbits: I added the code that I was thinking of so far... – johntechendso Jun 02 '16 at 11:40
  • a) [There are already built-in functions to compute approximate or exact median](http://stackoverflow.com/q/31432843/1560062) b) it is not possible to access a data frame in a UDAF c) computing the exact median in a distributed environment is extremely inefficient simply by definition. – zero323 Jun 02 '16 at 12:15
  • I use Spark version 1.5.2; the approxQuantile method is not available! The groups are not too big, so once the DFs are groupedBy(...) it shouldn't be that much data to shuffle around. However, if there is a solution I would like to try it. It may still be more efficient than a foreach loop. – johntechendso Jun 02 '16 at 14:10
  • Then `percentile_approx` / `percentile` is. `groupBy` on a data frame doesn't physically move data (not that it makes a difference for shuffles here). It is the `aggregate(ByKey)` equivalent, which is clearly reflected by the API. One way or another, you cannot access a data frame inside the UDAF. – zero323 Jun 02 '16 at 16:23

1 Answer


Try this:

import org.apache.spark.sql.functions._

val result = df.groupBy("col1").agg(callUDF("percentile_approx", col("col2"), lit(0.5)))
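
To get the output column named `median` as in the question, the aggregate can be aliased; note that on Spark 1.x `percentile_approx` resolves to a Hive UDAF, so a `HiveContext` is required. A minimal sketch, assuming the `df` from the question:

val result = df.groupBy("col1")
    .agg(callUDF("percentile_approx", col("col2"), lit(0.5)).alias("median"))
result.show()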
Lei Xia