
I am trying to define a UserDefinedAggregateFunction (UDAF) in Spark, which counts the number of occurrences of each unique value in a column, per group.

Here is an example. Suppose I have a DataFrame df like this:

+----+----+
|col1|col2|
+----+----+
|   a|  a1|
|   a|  a1|
|   a|  a2|
|   b|  b1|
|   b|  b2|
|   b|  b3|
|   b|  b1|
|   b|  b1|
+----+----+
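
In case it helps to reproduce, df can be built like this (assuming a SparkSession with spark.implicits._ imported):

val df = Seq(
  ("a", "a1"), ("a", "a1"), ("a", "a2"),
  ("b", "b1"), ("b", "b2"), ("b", "b3"), ("b", "b1"), ("b", "b1")
).toDF("col1", "col2")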

I will have a UDAF, DistinctValues:

val func = new DistinctValues

Then I apply it to the DataFrame df:

val agg_value = df.groupBy("col1").agg(func(col("col2")).as("DV"))

I am expecting to have something like this:

+----+--------------------------+
|col1|DV                        |
+----+--------------------------+
|   a|  Map(a1->2, a2->1)       |
|   b|  Map(b1->3, b2->1, b3->1)|
+----+--------------------------+

So I came up with a UDAF like this:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class DistinctValues extends UserDefinedAggregateFunction {
  def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil)

  def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil)

  def dataType: DataType =  MapType(StringType, LongType)
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = scala.collection.mutable.Map()
  }

  def update(buffer: MutableAggregationBuffer, input: Row) : Unit = {
    val str = input.getAs[String](0)
    var mp = buffer.getAs[scala.collection.mutable.Map[String, Long]](0)
    var c:Long = mp.getOrElse(str, 0)
    c = c + 1
    mp.put(str, c)
    buffer(0) = mp
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = {
    var mp1 = buffer1.getAs[scala.collection.mutable.Map[String, Long]](0)
    var mp2 = buffer2.getAs[scala.collection.mutable.Map[String, Long]](0)
    mp2 foreach {
        case (k ,v) => {
            var c:Long = mp1.getOrElse(k, 0)
            c = c + v
            mp1.put(k ,c)
        }
    }
    buffer1(0) = mp1
  }

  def evaluate(buffer: Row): Any = {
      buffer.getAs[scala.collection.mutable.Map[String, LongType]](0)
  }
}

Then I run this function on my DataFrame:

val func = new DistinctValues
val agg_values = df.groupBy("col1").agg(func(col("col2")).as("DV"))

It gave this error:

func: DistinctValues = $iwC$$iwC$DistinctValues@17f48a25
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 32.0 failed 4 times, most recent failure: Lost task 1.3 in stage 32.0 (TID 884, ip-172-31-22-166.ec2.internal): java.lang.ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to scala.collection.mutable.Map
at $iwC$$iwC$DistinctValues.update(<console>:39)
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:431)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:187)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:180)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.processCurrentSortedGroup(SortBasedAggregationIterator.scala:116)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:152)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
at org.apache.spark.scheduler.Task.run(Task.scala:89)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
at java.lang.Thread.run(Thread.java:745)

It looks like in the update(buffer: MutableAggregationBuffer, input: Row) method, the value stored in buffer is an immutable.Map, and the program tried to cast it to a mutable.Map.

But I used a mutable.Map to initialize the buffer in the initialize(buffer: MutableAggregationBuffer) method. Is it the same variable that gets passed to the update method? And buffer is a MutableAggregationBuffer, so it should be mutable, right?

Why did my mutable.Map become immutable? Does anyone know what happened?

I really need a mutable Map in this function to complete the task. I know there is a workaround: create a mutable map from the immutable map, then update it. But I really want to know why the mutable one automatically turns into an immutable one in the program; it doesn't make sense to me.
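
For illustration, the workaround I mean would look something like this inside update (a sketch, untested):

def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  val str = input.getAs[String](0)
  // copy the stored (immutable) map into a mutable one, update the copy,
  // then write it back to the buffer
  val mp = scala.collection.mutable.Map[String, Long]() ++ buffer.getAs[Map[String, Long]](0)
  mp(str) = mp.getOrElse(str, 0L) + 1L
  buffer(0) = mp.toMap
}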

Fan L.

2 Answers


I believe it is the MapType in your StructType: buffer therefore holds a Map, which is immutable.

You can convert it, but why don't you just leave it immutable and do this:

mp = mp + (k -> c)

to add an entry to the immutable Map?
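
For example, in plain Scala (nothing Spark-specific here):

var mp = Map[String, Long]("a1" -> 1L)
mp = mp + ("a1" -> (mp.getOrElse("a1", 0L) + 1L))  // rebinds mp to a brand-new Map
// mp is now Map(a1 -> 2)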

Working example below:

class DistinctValues extends UserDefinedAggregateFunction {
  // the column being counted is a string
  def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)

  def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType)) :: Nil)

  def dataType: DataType = MapType(StringType, LongType)
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val str = input.getAs[String](0)
    var mp = buffer.getAs[Map[String, Long]](0)
    var c: Long = mp.getOrElse(str, 0)
    c = c + 1
    mp = mp + (str -> c)  // builds a new immutable Map and rebinds mp
    buffer(0) = mp
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var mp1 = buffer1.getAs[Map[String, Long]](0)
    val mp2 = buffer2.getAs[Map[String, Long]](0)
    mp2 foreach {
      case (k, v) =>
        var c: Long = mp1.getOrElse(k, 0)
        c = c + v
        mp1 = mp1 + (k -> c)
    }
    buffer1(0) = mp1
  }

  def evaluate(buffer: Row): Any = {
    buffer.getAs[Map[String, Long]](0)
  }
}
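
Applied to the df from the question, this should produce the expected result:

val func = new DistinctValues
val agg_values = df.groupBy("col1").agg(func(col("col2")).as("DV"))
// expected, per the question:
// a -> Map(a1 -> 2, a2 -> 1)
// b -> Map(b1 -> 3, b2 -> 1, b3 -> 1)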
David Griffin
  • Nice catch! Hmm, the `MapType` in `StructType` may be the cause. But there is no other mutable map type in `spark.sql.types`, unless I define my own. – Fan L. Apr 14 '16 at 18:18
  • 1
    Like I said, don't -- just use an immutable `Map`. `mp = mp + (k -> c)` on an immutable `Map` gives you the same functionality as `mp.put(k, c)` on an mutable `Map` – David Griffin Apr 14 '16 at 18:20
  • `mp = mp + (k -> c)` works! I am new to scala, didn't know you could manipulate an immutable datatype like this. Thank you very much! – Fan L. Apr 14 '16 at 18:27
  • 1
    You are not manipulating it so much as creating a whole new instance based on the previous instance, and then throwing away the previous instance. But yeah, I pretty much only use immutable collections at this point -- there's really not much reason for mutable. – David Griffin Apr 14 '16 at 18:51
  • Got it! So `mp` needs to be a `var`; you are reassigning a new map to the variable `mp`. Answer accepted – Fan L. Apr 14 '16 at 19:24
  • Exactly. Works basically the same way `var foo: Int = 0; foo = foo + 1` works – David Griffin Apr 14 '16 at 19:29
  • I am getting an output where the key is always null, meaning that the input to the update function is always null. My CSV file that composes the DF is exactly the same as in the example and is loaded in Spark. The output of `agg_values.collect.foreach(println)` is: [a,Map(null -> 3)] [b,Map(null -> 5)]. What is the issue? – Alg_D Feb 14 '17 at 16:52
  • If it's mutable, is there a performance advantage in modifying the map in place? For instance, if you're looking to add a single key to a large map, it seems more performant to add the key in place rather than creating a brand-new map. – nivla12345 Feb 05 '20 at 22:11

Late to the party. I just discovered that one can use

override def bufferSchema: StructType = StructType(List(
    StructField("map", ObjectType(classOf[mutable.Map[String, Long]]))
))

to use mutable.Map in a buffer.
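
For context, a fuller sketch of how that might look (untested; assumes a Spark version whose UDAF machinery round-trips ObjectType buffer values as plain JVM objects; the class name is mine):

import scala.collection.mutable
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class DistinctValuesMutable extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)

  // ObjectType keeps the buffer slot as a plain JVM object rather than a Catalyst MapData
  override def bufferSchema: StructType = StructType(List(
    StructField("map", ObjectType(classOf[mutable.Map[String, Long]]))
  ))

  override def dataType: DataType = MapType(StringType, LongType)
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = mutable.Map[String, Long]()
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val mp = buffer.getAs[mutable.Map[String, Long]](0)
    val str = input.getAs[String](0)
    mp(str) = mp.getOrElse(str, 0L) + 1L  // in-place update, no new Map allocated
    buffer(0) = mp
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val mp1 = buffer1.getAs[mutable.Map[String, Long]](0)
    buffer2.getAs[mutable.Map[String, Long]](0).foreach {
      case (k, v) => mp1(k) = mp1.getOrElse(k, 0L) + v
    }
    buffer1(0) = mp1
  }

  // convert to an immutable Map to match the declared MapType output
  override def evaluate(buffer: Row): Any =
    buffer.getAs[mutable.Map[String, Long]](0).toMap
}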

colinfang