I have a use case where I need to keep track of a per-key maximum while processing the rows of a DataFrame with flatMap
. I was doing this by creating a mutable map and updating the value stored for a key whenever a new maximum is found for that key, for example:
df.flatMap {
...
if (!map.contains(key) || map(key) < max) map(key) = max
...
}
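To make that concrete, here is a self-contained sketch of the pattern I had in mind, runnable in spark-shell (where spark.implicits._ is already imported); the Record case class and its fields are made-up placeholders, not my real schema.
import scala.collection.mutable

// Placeholder schema standing in for my real DataFrame columns
case class Record(key: Int, value: Int)

// Driver-side map I was trying to update from inside flatMap
val maxPerKey = mutable.Map[Int, Int]()

val ds = Seq(Record(1, 10), Record(2, 3), Record(1, 7)).toDS()

ds.flatMap { r =>
  // remember the largest value seen so far for this key
  if (!maxPerKey.contains(r.key) || maxPerKey(r.key) < r.value) maxPerKey(r.key) = r.value
  Some(r.value + 1)   // stand-in for the rest of my flatMap logic
}.collect()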
I did this because it works in plain Scala:
scala> val map1 = collection.mutable.Map[Int, Int]()
map1: scala.collection.mutable.Map[Int,Int] = Map()
scala> Seq(1, 2, 3, 4).flatMap { v => map1(v) = v; Some(v + 1) }
res1: Seq[Int] = List(2, 3, 4, 5)
scala> map1
res2: scala.collection.mutable.Map[Int,Int] = Map(2 -> 2, 4 -> 4, 1 -> 1, 3 -> 3)
However, I later found out that Spark does not behave the same way:
scala> val map2 = collection.mutable.Map[Int, Int]()
map2: scala.collection.mutable.Map[Int,Int] = Map()
scala> sc.parallelize(Array(1, 2, 3, 4)).flatMap { v => map2(v) = v; println(map2); Some(map2(v) + 1) }.collect
...
Map(4 -> 4)
Map(2 -> 2)
Map(1 -> 1)
Map(3 -> 3)
...
res3: Array[Int] = Array(2, 3, 4, 5)
scala> map2
res4: scala.collection.mutable.Map[Int,Int] = Map()
From the output it looks like the closure (including map2) is serialized and shipped to each task, so each task updates its own copy and the map2 on the driver is never modified. Is there a way this can be replicated in Spark?
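For illustration, here is a minimal sketch of the per-key maximum I'm ultimately after, computed with reduceByKey over made-up (key, value) pairs instead of a shared mutable map (runnable in spark-shell):
// Made-up (key, value) pairs standing in for my real data
val maxByKey = sc
  .parallelize(Seq((1, 10), (2, 3), (1, 7), (2, 8)))
  .reduceByKey((a, b) => math.max(a, b))   // per-key maximum, computed on the executors
  .collectAsMap()                          // brought back to the driver as a Map
// maxByKey contains Map(1 -> 10, 2 -> 8) (entry order may vary)
This would give the same per-key maxima, just assembled with a shuffle instead of driver-side mutation.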