Here's one approach using a Window function, with steps as follows:
- Add a row-identifying column (not needed if one already exists) and combine the non-key columns (presumably many of them) into a single struct column
- Create tmp1 with conditional nulls, and tmp2 using the last/rowsBetween Window function to back-fill with the last non-null value
- Create newcols conditionally from cols and tmp2
- Expand newcols back to individual columns using foldLeft
Note that this solution uses a Window function without partitioning, which moves all rows into a single partition, so it may not scale to large datasets.
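To make the back-fill step concrete, here is a minimal plain-Scala sketch of the same logic on an in-memory collection, with no Spark involved. The names `BackFillSketch`, `Rec`, and `backFill` are hypothetical and exist only for this illustration; `scanLeft` stands in for `last(..., ignoreNulls = true)` over a window running from the first row to the current row.

```scala
object BackFillSketch {
  // Hypothetical record type: a key column plus the combined non-key columns
  final case class Rec(key: String, cols: Seq[String])

  // Replace the cols of every targetKey row with the cols of the most recent
  // preceding sourceKey row; rows with no preceding sourceKey row are left as-is.
  def backFill(rows: Seq[Rec], sourceKey: String, targetKey: String): Seq[Rec] = {
    // Running "last non-null" value: the cols of the latest sourceKey row seen so far.
    // scanLeft returns the seed first, so drop it to align the result with rows.
    val lastSeen: Seq[Option[Seq[String]]] =
      rows.scanLeft(Option.empty[Seq[String]]) { (last, r) =>
        if (r.key == sourceKey) Some(r.cols) else last
      }.tail

    rows.zip(lastSeen).map {
      case (r, Some(filled)) if r.key == targetKey => r.copy(cols = filled)
      case (r, _)                                  => r
    }
  }

  // Same sample data as the DataFrame in this answer
  val sample: Seq[Rec] = Seq(
    Rec("X", Seq("x1", "x2")),
    Rec("Z", Seq("z1", "z2")),
    Rec("Y", Seq("", "")),
    Rec("X", Seq("x3", "x4")),
    Rec("P", Seq("p1", "p2")),
    Rec("Q", Seq("q1", "q2")),
    Rec("Y", Seq("", ""))
  )

  def main(args: Array[String]): Unit =
    backFill(sample, "X", "Y").foreach(println)
}
```

Like the unpartitioned window, this sketch walks the rows sequentially in a single pass, which is why the Spark version has to bring all rows into one partition.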
// Outside spark-shell, first `import spark.implicits._` to enable toDF and the $"..." syntax
val df = Seq(
("X", "x1", "x2"),
("Z", "z1", "z2"),
("Y", "", ""),
("X", "x3", "x4"),
("P", "p1", "p2"),
("Q", "q1", "q2"),
("Y", "", "")
).toDF("col1", "col2", "col3")
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
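// All columns other than the key column
val colList = df.columns.filter(_ != "col1")
// Add a row id and pack the non-key columns into a single struct column
val df2 = df.select($"col1", monotonically_increasing_id().as("id"),
  struct(colList.map(col): _*).as("cols")
)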
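// tmp1 holds cols only on "X" rows and is null elsewhere;
// tmp2 carries the last non-null tmp1 forward, from the first row up to the current row
val df3 = df2.
  withColumn("tmp1", when($"col1" === "X", $"cols")).
  withColumn("tmp2", last("tmp1", ignoreNulls = true).over(
    Window.orderBy("id").rowsBetween(Window.unboundedPreceding, 0)
  ))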
df3.show
// +----+---+-------+-------+-------+
// |col1| id| cols| tmp1| tmp2|
// +----+---+-------+-------+-------+
// | X| 0|[x1,x2]|[x1,x2]|[x1,x2]|
// | Z| 1|[z1,z2]| null|[x1,x2]|
// | Y| 2| [,]| null|[x1,x2]|
// | X| 3|[x3,x4]|[x3,x4]|[x3,x4]|
// | P| 4|[p1,p2]| null|[x3,x4]|
// | Q| 5|[q1,q2]| null|[x3,x4]|
// | Y| 6| [,]| null|[x3,x4]|
// +----+---+-------+-------+-------+
// "Y" rows take the back-filled value; all other rows keep their own cols
val df4 = df3.withColumn("newcols",
  when($"col1" === "Y", $"tmp2").otherwise($"cols")
).select($"col1", $"newcols")
df4.show
// +----+-------+
// |col1|newcols|
// +----+-------+
// | X|[x1,x2]|
// | Z|[z1,z2]|
// | Y|[x1,x2]|
// | X|[x3,x4]|
// | P|[p1,p2]|
// | Q|[q1,q2]|
// | Y|[x3,x4]|
// +----+-------+
// Pull each struct field back out into its own top-level column, then drop the struct
val dfResult = colList.foldLeft(df4)(
  (accDF, c) => accDF.withColumn(c, df4(s"newcols.$c"))
).drop($"newcols")
dfResult.show
// +----+----+----+
// |col1|col2|col3|
// +----+----+----+
// | X| x1| x2|
// | Z| z1| z2|
// | Y| x1| x2|
// | X| x3| x4|
// | P| p1| p2|
// | Q| q1| q2|
// | Y| x3| x4|
// +----+----+----+
[UPDATE]
For Spark 1.x, last(colName, ignoreNulls) isn't available in the DataFrame API. A workaround is to fall back to Spark SQL, whose last() function supports ignoring nulls:
df2.
withColumn( "tmp1", when($"col1" === "X", $"cols") ).
createOrReplaceTempView("df2table")
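// On Spark 1.x, use registerTempTable("df2table") instead (createOrReplaceTempView was added in Spark 2.0)
// Likewise, on Spark 1.x call sqlContext.sql(...) directly, since the `spark` session object was introduced in 2.0
val df3 = spark.sqlContext.sql("""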
select col1, id, cols, tmp1, last(tmp1, true) over (
order by id rows between unbounded preceding and current row
) as tmp2
from df2table
""")