
Suppose I have a DataFrame like the one below.

Id  A    B   C   D
1   100  10  20  5
2   0    5   10  5
3   0    7   2   3
4   0    1   3   7

And the above needs to be converted into something like the one below.

Id  A    B   C   D  E
1   100  10  20  5  75
2   75   5   10  5  60
3   60   7   2   3  50
4   50   1   3   7  40

The calculation works as detailed below:

  1. The data frame now has a new column E, which for row 1 is calculated as col(A) - (max(col(B), col(C)) + col(D)) => 100 - (max(10, 20) + 5) = 75
  2. In the row with Id 2, the value of col E from row 1 is brought forward as the value for col A
  3. So, for row 2, column E is determined as 75 - (max(5, 10) + 5) = 60
  4. Similarly, in the row with Id 3, the value of A becomes 60 and the new value for col E is determined based on it

The problem is that the value of col A depends on the previous row's values, except for the first row.

Is it possible to solve this using windowing and lag?
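
For reference, here is a minimal sketch (Scala API; integer column types are an assumption) of how the sample dataframe above could be built:

import spark.implicits._

// Sample input from the question (named df here for illustration)
val df = Seq(
  (1, 100, 10, 20, 5),
  (2, 0, 5, 10, 5),
  (3, 0, 7, 2, 3),
  (4, 0, 1, 3, 7)
).toDF("Id", "A", "B", "C", "D")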


2 Answers


You can use the collect_list function over a Window ordered by the Id column to get a cumulative array of structs holding the values of A and of max(B, C) + D (as field T). Then, apply aggregate to calculate column E.

Note that in this particular case you can't use the lag window function, because you need the calculated values recursively.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val df2 = df.withColumn(
  "tmp",
  // cumulative array of structs (A, T), where T = max(B, C) + D, ordered by Id
  collect_list(
    struct(col("A"), (greatest(col("B"), col("C")) + col("D")).as("T"))
  ).over(Window.orderBy("Id"))
).withColumn(
  "E",
  // E = A of the first row minus the sum of all T values seen so far
  expr("aggregate(transform(tmp, (x, i) -> IF(i=0, x.A - x.T, -x.T)), 0, (acc, x) -> acc + x)")
).withColumn(
  "A",
  // recompute A from E so that rows 2+ carry the previous row's E forward
  col("E") + greatest(col("B"), col("C")) + col("D")
).drop("tmp")

df2.show(false)

//+---+---+---+---+---+---+
//|Id |A  |B  |C  |D  |E  |
//+---+---+---+---+---+---+
//|1  |100|10 |20 |5  |75 |
//|2  |75 |5  |10 |5  |60 |
//|3  |60 |7  |2  |3  |50 |
//|4  |50 |1  |3  |7  |40 |
//+---+---+---+---+---+---+

You can show the intermediate column tmp to understand the logic behind the calculation.
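
For example, re-computing tmp on the input dataframe and selecting it alongside Id would show something like this (a sketch reusing df and the imports from above; the exact struct rendering depends on the Spark version):

// Inspect the cumulative array of structs that E is derived from
df.withColumn(
  "tmp",
  collect_list(
    struct(col("A"), (greatest(col("B"), col("C")) + col("D")).as("T"))
  ).over(Window.orderBy("Id"))
).select("Id", "tmp").show(false)

// +---+--------------------------------------+
// |Id |tmp                                   |
// +---+--------------------------------------+
// |1  |[{100, 25}]                           |
// |2  |[{100, 25}, {0, 15}]                  |
// |3  |[{100, 25}, {0, 15}, {0, 10}]         |
// |4  |[{100, 25}, {0, 15}, {0, 10}, {0, 10}]|
// +---+--------------------------------------+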

blackbishop
  • Hi blackbishop, really appreciate your reply. The problem is that "max(B, C) + D" is a much simpler version of the actual calculation. Actually, the calc involves multiple columns that have to be brought forward from the previous row to the current row, and the custom aggregation would become too complex to handle. It's my bad, as I was thinking it would be a matter of somehow getting the previous values using a lag and then using normal dataframe calculations on them. But this seems to be much more complicated than what I had thought – Soumya Jan 25 '22 at 16:53
  • Hi @Soumya! This is not possible to do using simple Window functions, as your calculations need to be recursive. Maybe you could ask a new question explaining in detail the problem you're trying to solve. We try to answer questions according to the elements you provide; unfortunately we can't guess whether your actual task is much more complex. – blackbishop Jan 25 '22 at 17:44

As blackbishop said, you can't use the lag function to retrieve the changing value of a column. Since you're using the Scala API, you can develop your own user-defined aggregate function.

You create the following case classes, representing the row you're currently reading and your aggregator's buffer:

case class InputRow(A: Integer, B: Integer, C: Integer, D: Integer)

case class Buffer(var E: Integer, var A: Integer)

Then you use them to define your RecursiveAggregator custom aggregator:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder

object RecursiveAggregator extends Aggregator[InputRow, Buffer, Buffer] {
  // Empty buffer: nothing computed yet
  override def zero: Buffer = Buffer(null, null)

  // Carry the previous E forward as A (except on the first row), then compute the new E
  override def reduce(buffer: Buffer, currentRow: InputRow): Buffer = {
    buffer.A = if (buffer.E == null) currentRow.A else buffer.E
    buffer.E = buffer.A - (math.max(currentRow.B, currentRow.C) + currentRow.D)
    buffer
  }

  // Merging partial buffers makes no sense for an order-dependent, recursive computation
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    throw new NotImplementedError("should be used only over ordered window")
  }

  override def finish(reduction: Buffer): Buffer = reduction

  override def bufferEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]

  override def outputEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]
}

Finally, you turn RecursiveAggregator into a user-defined aggregate function that you apply to your input dataframe:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, udaf}

val recursiveAggregator = udaf(RecursiveAggregator)

val window = Window.orderBy("Id")

val result = input
  .withColumn("computed", recursiveAggregator(col("A"), col("B"), col("C"), col("D")).over(window))
  .select("Id", "computed.A", "B", "C", "D", "computed.E")

If you take your question's dataframe as the input dataframe, you get the following result:

+---+---+---+---+---+---+
|Id |A  |B  |C  |D  |E  |
+---+---+---+---+---+---+
|1  |100|10 |20 |5  |75 |
|2  |75 |5  |10 |5  |60 |
|3  |60 |7  |2  |3  |50 |
|4  |50 |1  |3  |7  |40 |
+---+---+---+---+---+---+
Vincent Doba
  • Thanks a lot for the help. While trying to replicate it: will it be possible to do the same in a Spark 2 version? I think `udaf` is available only in Spark 3+, but unfortunately I am still stuck with an older version of Spark :( – Soumya Jan 29 '22 at 05:54
  • You're right, `udaf` function doesn't exist in Spark 2. You can look at [this answer](https://stackoverflow.com/a/32101530/6807769) to use user-defined aggregate function with Spark 2. – Vincent Doba Jan 31 '22 at 08:34
  • Can anyone share any insights on how exactly this UDAF can be wrapped to work with PySpark? Hitting brick walls when trying to build a jar out of this and pushing it/registering it with PySpark :( – user9429934 Aug 05 '22 at 16:53
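
Following up on the Spark 2 limitation discussed in the comments above: a rough, untested sketch of the same recursive logic using the legacy UserDefinedAggregateFunction API (which does exist in Spark 2) might look like the following. The object name, schemas and null handling here are illustrative assumptions, not part of the original answer.

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

// Hypothetical Spark 2 counterpart of RecursiveAggregator (untested sketch)
object RecursiveUdaf extends UserDefinedAggregateFunction {

  // Input columns: A, B, C, D
  override def inputSchema: StructType = StructType(Seq(
    StructField("A", IntegerType),
    StructField("B", IntegerType),
    StructField("C", IntegerType),
    StructField("D", IntegerType)
  ))

  // Buffer keeps the running E and the carried-forward A
  override def bufferSchema: StructType = StructType(Seq(
    StructField("E", IntegerType),
    StructField("A", IntegerType)
  ))

  // The aggregate returns a struct with the recomputed A and E
  override def dataType: DataType = StructType(Seq(
    StructField("A", IntegerType),
    StructField("E", IntegerType)
  ))

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, null) // E not computed yet
    buffer.update(1, null) // A not carried yet
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // First row: take A from the input; later rows: carry the previous E forward as A
    val a = if (buffer.isNullAt(0)) input.getInt(0) else buffer.getInt(0)
    val e = a - (math.max(input.getInt(1), input.getInt(2)) + input.getInt(3))
    buffer.update(0, e)
    buffer.update(1, a)
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    throw new NotImplementedError("should be used only over an ordered window")

  override def evaluate(buffer: Row): Any = Row(buffer.getInt(1), buffer.getInt(0))
}

val result = input
  .withColumn("computed", RecursiveUdaf(col("A"), col("B"), col("C"), col("D")).over(Window.orderBy("Id")))
  .select("Id", "computed.A", "B", "C", "D", "computed.E")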