Referring to this question: Spark / Scala: forward fill with last observation,
I am trying to reproduce the problem and solve it.
I've created a file mre.csv:
Date,B
2015-06-01,33
2015-06-02,
2015-06-03,
2015-06-04,
2015-06-05,22
2015-06-06,
2015-06-07,
Then I read the file:
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import spark.implicits._

var df = spark.read.format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("D:/playground/mre.csv")
df.show()

val rows: RDD[Row] = df.orderBy($"Date").rdd
val schema = df.schema
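Since inferSchema is on, a minimal sanity check of what Spark inferred (the output further down suggests Date comes back as a timestamp):
// Minimal check of the inferred column types
df.printSchema()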
Then I managed to solve the problem using this code:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, lag, lit}

df = df.withColumn("id", lit(1))
val spec = Window.partitionBy("id").orderBy("Date")
val df2 = df.withColumn("B", coalesce((0 to 6).map(i => lag(df.col("B"), i, 0).over(spec)): _*))
df2.show()
Output:
+-------------------+---+---+
| Date| B| id|
+-------------------+---+---+
|2015-06-01 00:00:00| 33| 1|
|2015-06-02 00:00:00| 33| 1|
|2015-06-03 00:00:00| 33| 1|
|2015-06-04 00:00:00| 33| 1|
|2015-06-05 00:00:00| 22| 1|
|2015-06-06 00:00:00| 22| 1|
|2015-06-07 00:00:00| 22| 1|
+-------------------+---+---+
The problem, however, is that everything is calculated in a single partition, so I don't really take advantage of Spark here.
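The same limitation seems to apply if I write the fill with last(..., ignoreNulls = true) over a running window instead of the fixed list of lags. This is only a sketch (runningSpec and filled are just my illustrative names); it is more concise, but as far as I can tell it still funnels every row through one partition because id is the same constant everywhere:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.last

// Sketch: carry the last non-null B forward with last(..., ignoreNulls = true).
// Because id is constant for every row, this still runs in a single partition.
val runningSpec = Window
  .partitionBy("id")
  .orderBy("Date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val filled = df.withColumn("B", last($"B", ignoreNulls = true).over(runningSpec))
filled.show()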
So instead I tried this code:
// True when the B column (index 1) of a row is present
def notMissing(row: Row): Boolean = !row.isNullAt(1)

// For each partition, remember the last row that has a non-null B value
val toCarry: scala.collection.Map[Int, Option[org.apache.spark.sql.Row]] = rows
  .mapPartitionsWithIndex { case (i, iter) =>
    Iterator((i, iter.filter(notMissing(_)).toSeq.lastOption))
  }
  .collectAsMap

val toCarryBd = sc.broadcast(toCarry)

def fill(i: Int, iter: Iterator[Row]): Iterator[Row] = {
  if (iter.contains(null)) iter.map(row => Row(toCarryBd.value(i).get(1))) else iter
}

val imputed: RDD[Row] = rows
  .mapPartitionsWithIndex { case (i, iter) => fill(i, iter) }

val df2 = spark.createDataFrame(imputed, schema).toDF()
df2.show()
But the output is disappointing:
+----+---+
|Date| B|
+----+---+
+----+---+
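In case it helps to diagnose this, here is how I would peek at what each partition actually holds (and therefore what toCarry ends up storing for it); this is only a debugging sketch:
// Debugging sketch: for every partition, print its size and the last row
// with a non-null B value (i.e. the value toCarry would keep for it).
rows
  .mapPartitionsWithIndex { case (i, iter) =>
    val part = iter.toSeq
    Iterator((i, part.size, part.filter(notMissing).lastOption))
  }
  .collect()
  .foreach { case (i, n, lastRow) => println(s"partition $i: $n rows, last non-missing = $lastRow") }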