
Referring to this question: Spark / Scala: forward fill with last observation,

I am trying to reproduce the problem and to solve it.

I've created a file mre.csv:

Date,B
2015-06-01,33
2015-06-02,
2015-06-03,
2015-06-04,
2015-06-05,22
2015-06-06,
2015-06-07,

Then I read the file:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

var df = spark.read.format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("D:/playground/mre.csv")

df.show()

val rows: RDD[Row] = df.orderBy($"Date").rdd
val schema = df.schema

Then I managed to solve the problem using this code:

df = df.withColumn("id",lit(1))
var spec = Window.partitionBy("id").orderBy("Date")
val df2 = df.withColumn("B", coalesce((0 to 6).map(i=>lag(df.col("B"),i,0).over(spec)): _*))

df2.show()

Output:

+-------------------+---+---+
|               Date|  B| id|
+-------------------+---+---+
|2015-06-01 00:00:00| 33|  1|
|2015-06-02 00:00:00| 33|  1|
|2015-06-03 00:00:00| 33|  1|
|2015-06-04 00:00:00| 33|  1|
|2015-06-05 00:00:00| 22|  1|
|2015-06-06 00:00:00| 22|  1|
|2015-06-07 00:00:00| 22|  1|
+-------------------+---+---+

The problem however is that it's all calculated in a single partition so I don't really take advantage of Spark here.
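
For reference, the same fill can also be written without enumerating the lags, using last with ignoreNulls = true over a running window (the variable names below are just illustrative); it still computes everything in a single window partition, so it has the same scaling problem:

// Running window from the first row up to the current row, ordered by Date.
val runningSpec = Window.orderBy("Date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

// Carry forward the last non-null B seen so far.
val dfFilled = df.withColumn("B", last($"B", ignoreNulls = true).over(runningSpec))
dfFilled.show()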

So I tried this code instead:

def notMissing(row: Row): Boolean = { !row.isNullAt(1) }

val toCarry: scala.collection.Map[Int,Option[org.apache.spark.sql.Row]] = rows
  .mapPartitionsWithIndex{ case (i, iter) =>
    Iterator((i, iter.filter(notMissing(_)).toSeq.lastOption)) }
  .collectAsMap

val toCarryBd = sc.broadcast(toCarry)

def fill(i: Int, iter: Iterator[Row]): Iterator[Row] = {
  if (iter.contains(null)) iter.map(row => Row(toCarryBd.value(i).get(1))) else iter
}

val imputed: RDD[Row] = rows
  .mapPartitionsWithIndex{ case (i, iter) => fill(i, iter) }

val df2 = spark.createDataFrame(imputed, schema).toDF()

df2.show()

But the output is disappointing:

+----+---+
|Date|  B|
+----+---+
+----+---+
  • Hi Alon, since you're working with Dates, can we consider the dataset is small enough to be broadcasted ? The idea would be to broadcast at least rows having a defined ``B`` so you can work in a distributed way, considering date intervals to determine which value to assign to B. – baitmbarek Dec 05 '19 at 10:38
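
A rough sketch of that suggestion (my own illustration, not code from the post): broadcast only the rows that have a defined B, then resolve each missing B from the latest broadcast (Date, B) pair at or before that row's date. It assumes the non-null rows fit in driver memory and that the first Date has a value.

// Collect and broadcast the rows that have a defined B, sorted by date.
val known: Array[(java.sql.Timestamp, Int)] = rows
  .filter(notMissing)
  .map(r => (r.getTimestamp(0), r.getInt(1)))
  .collect()
  .sortBy(_._1.getTime)

val knownBd = sc.broadcast(known)

// For each missing B, take the last broadcast value whose date is <= the row's date.
val filledRows: RDD[Row] = rows.map { row =>
  if (notMissing(row)) row
  else {
    val date = row.getTimestamp(0)
    val b = knownBd.value.takeWhile(!_._1.after(date)).last._2
    Row(date, b)
  }
}

spark.createDataFrame(filledRows, schema).show()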

1 Answer


The implementation of the fill function is wrong here (among other things, `iter.contains(null)` looks for a null Row rather than a Row with a null field, and it exhausts the iterator while checking, which is why the resulting DataFrame is empty). Take a look at the steps mentioned in the answer to the referenced question.

def fill(i: Int, iter: Iterator[Row]): Iterator[Row] = {
  // If it is the beginning of partition and value is missing
  // extract value to fill from toCarryBd.value
  // Remember to correct for empty / only missing partitions
  // otherwise take last not-null from the current partition
}

I have implemented it below:

def notMissing(row: Row): Boolean = { !row.isNullAt(1) }

val toCarryTemp: scala.collection.Map[Int,Option[org.apache.spark.sql.Row]] = rows
  .mapPartitionsWithIndex{ case (i, iter) =>
    Iterator((i, iter.filter(notMissing(_)).toSeq.lastOption)) }
  .collectAsMap

Extract the column B value from the map and traverse it, filling in the previous partition's value whenever the current partition has no non-null value. If we skip this step we will end up with output like:

+-------------------+---+
|               Date|  B|
+-------------------+---+
|2015-06-01 00:00:00| 33|
|2015-06-02 00:00:00|  0|
|2015-06-03 00:00:00|  0|
|2015-06-04 00:00:00|  0|
|2015-06-05 00:00:00| 22|
|2015-06-06 00:00:00|  0|
|2015-06-07 00:00:00|  0|
+-------------------+---+

var toCarry = scala.collection.mutable.Map[Int, Int]()

for (i <- 0 until rows.getNumPartitions) {
  toCarry(i) = toCarryTemp(i) match {
    case Some(row) => row.getInt(1)
    case None if i > 0 => toCarry(i - 1)
    case None => 0
  }
}

val toCarryBd = sc.broadcast(toCarry)

def fillUtil(row: Row, value: Int) = {
  if (!notMissing(row)) Row(row.getTimestamp(0), value)
  else row
}

def fill(i: Int, iter: Iterator[Row]): Iterator[Row] = {
  val carry = toCarryBd.value(i)
  if (iter.isEmpty) iter
  else {
    // Only the first row of the partition uses the carried value from toCarry;
    // after that we propagate the current value within the partition.
    val myListHead :: myListTail = iter.toList
    val resultHead = fillUtil(myListHead, carry)
    var currVal = resultHead.getInt(1)
    val resultTail = myListTail.map { row =>
      val row1 = fillUtil(row, currVal)
      currVal = row1.getInt(1)
      row1
    }
    (resultHead :: resultTail).iterator
  }
}


val imputed: RDD[Row] = rows.mapPartitionsWithIndex{ case (i, iter) => fill(i, iter) }

val df2 = spark.createDataFrame(imputed, schema).toDF()

df2.show()

Output:

+-------------------+---+
|               Date|  B|
+-------------------+---+
|2015-06-01 00:00:00| 33|
|2015-06-02 00:00:00| 33|
|2015-06-03 00:00:00| 33|
|2015-06-04 00:00:00| 33|
|2015-06-05 00:00:00| 22|
|2015-06-06 00:00:00| 22|
|2015-06-07 00:00:00| 22|
+-------------------+---+
  • Inside the fill() method, inside the map, how is the value of 'row' updated? I don't see you assigning the return value of fillUtil() to a variable, and 'row' isn't changed inside fillUtil(). Either I'm missing something or there is some magic in here. – Alon Dec 08 '19 at 13:57
  • I am going to try this. Did you try at scale? – thebluephantom Dec 08 '19 at 20:09
  • @Alon Thanks for pointing it out. I have edited the answer. The reason it was giving the correct result anyway is that `fillUtil` wasn't getting called inside the map: since the data size was small, after the orderBy all the timestamps ended up in their own partitions (refer: https://stackoverflow.com/questions/53786188/number-of-dataframe-partitions-after-sorting). To test it, just do `val rows: RDD[Row] = df.orderBy($"Date").rdd.repartition(2)` and it will call `fillUtil` in the map. – wypul Dec 09 '19 at 06:20
  • @thebluephantom No, I have not tried it at scale. I will do when I get some time or share your findings if you test it. – wypul Dec 09 '19 at 06:26
  • I will try, but for the bounty you should have as well. – thebluephantom Dec 09 '19 at 07:28
  • When I set the number of partitions to 3 I get an exception. Why? – Alon Dec 09 '19 at 10:17
  • @Alon can you tell me the exception that you got. – wypul Dec 09 '19 at 10:40
  • @Alon I suspect you must have received an empty partition on repartitioning. I have added a check for that in `fill` method. – wypul Dec 09 '19 at 10:51
  • Now there is no exception with 3 partitions, but with 4 partitions only 2015-06-05 is 22 and 2015-06-01 is 33, and all the rest are 0. – Alon Dec 09 '19 at 14:45
  • @Alon repartition doesn't preserve ordering across partitions. It was just to show you how code was behaving. That's why I haven't added it to my answer. This is what I get after applying `repartition(4)` - `res3: Array[String] = Array([2015-06-03 00:00:00.0,null] -> 0, [2015-06-04 00:00:00.0,null] -> 0, [2015-06-07 00:00:00.0,null] -> 0, [2015-06-02 00:00:00.0,null] -> 1, [2015-06-06 00:00:00.0,null] -> 1, [2015-06-05 00:00:00.0,22] -> 2, [2015-06-01 00:00:00.0,33] -> 3)`. – wypul Dec 09 '19 at 14:59
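
For anyone reproducing that per-partition listing, a small sketch (my own, not from the answer) that tags each row with the index of the partition it lands in after repartitioning:

// Tag each row with its partition index after repartition(4).
rows.repartition(4)
  .mapPartitionsWithIndex { case (i, iter) => iter.map(row => s"$row -> $i") }
  .collect()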