
I have a use case where I need to keep track of the maximum value seen for each key while processing a DataFrame using flatMap. My approach was to create a mutable map and update the value associated with a key whenever a new maximum is found for that key. For example:

df.flatMap {
  ...
  if (!map.contains(key) || map(key) < max) map(key) = max
  ...
}

I did this because it works in plain Scala:

scala> val map1 = collection.mutable.Map[Int, Int]()
map1: scala.collection.mutable.Map[Int,Int] = Map()

scala> Seq(1, 2, 3, 4).flatMap { v => map1(v) = v; Some(v + 1) }
res1: Seq[Int] = List(2, 3, 4, 5)

scala> map1
res2: scala.collection.mutable.Map[Int,Int] = Map(2 -> 2, 4 -> 4, 1 -> 1, 3 -> 3)

However, I found out later that Spark does not behave the same way.

scala> val map2 = collection.mutable.Map[Int, Int]()
map2: scala.collection.mutable.Map[Int,Int] = Map()

scala> sc.parallelize(Array(1, 2, 3, 4)).flatMap { v => map2(v) = v; println(map2); Some(map2(v) + 1) }.collect
...
Map(4 -> 4)
Map(2 -> 2)
Map(1 -> 1)
Map(3 -> 3)
...
res3: Array[Int] = Array(2, 3, 4, 5)
scala> map2
res4: scala.collection.mutable.Map[Int,Int] = Map()

Is there a way this can be replicated in Spark?

aa8y

1 Answer

Use Accumulators. In this case I think it should be something like (adapted from Spark accumulableCollection does not work with mutable.Map):

import scala.collection.mutable
import org.apache.spark.AccumulableParam

implicit val mapAccum =
    new AccumulableParam[mutable.Map[Int, Int], Int] {
  // Merge two partial maps by summing the values for each key, see
  // https://stackoverflow.com/questions/7076128/best-way-to-merge-two-maps-and-sum-the-values-of-same-key
  def addInPlace(map1: mutable.Map[Int, Int],
                 map2: mutable.Map[Int, Int])
      : mutable.Map[Int, Int] = {
    map1 ++ map2.map { case (k, v) => k -> (v + map1.getOrElse(k, 0)) }
  }
  // Add a single element: increment its count. getOrElse avoids a
  // NoSuchElementException for keys that haven't been seen yet.
  def addAccumulator(t1: mutable.Map[Int, Int], x: Int)
      : mutable.Map[Int, Int] = {
    t1(x) = t1.getOrElse(x, 0) + 1
    t1
  }
  def zero(t: mutable.Map[Int, Int])
      : mutable.Map[Int, Int] = {
    mutable.Map[Int, Int]()
  }
}
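Since the question is about tracking per-key maxima rather than counts, the two merge operations could instead combine values with math.max. Here is a minimal sketch of that variant's merge logic as plain Scala functions so it can be checked without a Spark cluster; the names addInPlaceMax and addAccumulatorMax are mine, and the accumulated element becomes a (key, value) pair rather than a bare Int:

```scala
import scala.collection.mutable

// Hypothetical max-tracking variant of the two AccumulableParam methods.
// addInPlace merges two partial maps, keeping the larger value per key.
def addInPlaceMax(m1: mutable.Map[Int, Int],
                  m2: mutable.Map[Int, Int]): mutable.Map[Int, Int] = {
  m2.foreach { case (k, v) => m1(k) = math.max(m1.getOrElse(k, v), v) }
  m1
}

// addAccumulator folds in one (key, value) pair, keeping the max per key.
def addAccumulatorMax(m: mutable.Map[Int, Int],
                      kv: (Int, Int)): mutable.Map[Int, Int] = {
  val (k, v) = kv
  m(k) = math.max(m.getOrElse(k, v), v)
  m
}
```

With this variant the accumulator's element type would be (Int, Int), i.e. AccumulableParam[mutable.Map[Int, Int], (Int, Int)], and tasks would add (key, max) pairs instead of bare keys.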

val map2 = sc.accumulable(mutable.Map[Int,Int]())

To update in the task (note that you can't read the accumulator's value there, only add to it). Also be aware that accumulator updates made inside a transformation such as flatMap may be applied more than once if a task is re-executed; Spark only guarantees exactly-once updates for accumulators used inside actions:

sc.parallelize(Array(1, 2, 3, 4)).flatMap { v => 
  map2 += v // calls addAccumulator and so increments the map
  Some(v + 1)
}.collect

println(map2)
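As an aside, if the per-key maximum is the result you actually need (rather than a debugging side-statistic), it is usually simpler to compute it in the dataflow itself with reduceByKey instead of an accumulator. A sketch using the same sample data, keyed by parity purely for illustration; the local perKeyMax helper mirrors what the Spark job computes, so the logic can be checked without a cluster:

```scala
// Spark version (assumes elements can be mapped to (key, value) pairs):
//   sc.parallelize(Array(1, 2, 3, 4))
//     .map(v => (v % 2, v))            // keying by parity is illustrative only
//     .reduceByKey(math.max(_, _))     // per-key maximum
//     .collectAsMap()

// Local equivalent of the same per-key max, runnable without Spark:
def perKeyMax(xs: Seq[(Int, Int)]): Map[Int, Int] =
  xs.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).max }

perKeyMax(Seq(1, 2, 3, 4).map(v => (v % 2, v)))  // Map(1 -> 3, 0 -> 4)
```

This keeps the maxima inside the normal RDD lineage, so you get fault tolerance for free and avoid the double-counting caveat that applies to accumulators in transformations.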
Alexey Romanov