
I want to do a scanLeft-type operation on one column of a dataframe. scanLeft is not parallelizable, but in my case I only want to apply the function to elements that are already in the same partition, so the operation can be executed in parallel within each partition (no data shuffling).

Consider the following example:

| partitionKey | orderColumn | value | scanLeft(0)(_+_) |
|--------------|-------------|-------|------------------|
| 1            | 1           | 1     | 1                |
| 1            | 2           | 2     | 3                |
| 2            | 1           | 3     | 3                |
| 2            | 2           | 4     | 7                |
| 1            | 3           | 5     | 8                |
| 2            | 3           | 6     | 13               |

I want to scanLeft the values within the same partition, and create a new column to store the result.
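
To illustrate the per-partition behaviour: a plain Scala `scanLeft` over the values of partition key 1 (which holds the values 1, 2, 5 in `orderColumn` order) produces exactly the column above.

    // plain Scala illustration of the per-partition behaviour I am after
    val values = List(1.0, 2.0, 5.0)                  // values of partitionKey 1, in order
    val cumulated = values.scanLeft(0.0)(_ + _).tail  // drop the initial 0.0
    // cumulated == List(1.0, 3.0, 8.0), matching the scanLeft(0)(_+_) column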

My code for now looks something like this:

    inDataframe
      .repartition(col("partitionKey"))
      .foreachPartition { partition =>
        partition
          .map(row => row(1).asInstanceOf[Double])
          .scanLeft(0.0)(_ + _)
          .foreach(println)
      }

This aggregates the values as I want and prints out the result, but I want to add these values as a new column of the dataframe.

Any idea of how to do it?

----edit---- The real use case is to calculate the time-weighted rate of return (https://www.investopedia.com/terms/t/time-weightedror.asp). The expected input looks something like this:

| product | valuation date | daily return |
|---------|----------------|--------------|
| 1       | 2019-01-01     | 0.1          |
| 1       | 2019-01-02     | 0.2          |
| 1       | 2019-01-03     | 0.3          |
| 2       | 2019-01-01     | 0.4          |
| 2       | 2019-01-02     | 0.5          |
| 2       | 2019-01-03     | 0.6          |

I want to calculate the cumulated return per product for all dates up to the current one. The dataframe is partitioned by product, and each partition is ordered by valuation date. I already wrote the aggregation function to pass into scanLeft:

    def chain_ret(x: Double, y: Double): Double = {
      (1 + x) * (1 + y) - 1
    }
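
Applied with `scanLeft` to the daily returns of, say, product 1, this gives the cumulated returns listed below:

    // daily returns of product 1, in valuation-date order
    val dailyReturns = List(0.1, 0.2, 0.3)
    val cumulated = dailyReturns.scanLeft(0.0)(chain_ret).tail  // drop the initial 0.0
    // cumulated == List(0.1, 0.32, 0.716), up to floating-point rounding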

Expected return data:

| product | valuation date | daily return | cumulated return |
|---------|----------------|--------------|------------------|
| 1       | 2019-01-01     | 0.1          | 0.1              |
| 1       | 2019-01-02     | 0.2          | 0.32             |
| 1       | 2019-01-03     | 0.3          | 0.716            |
| 2       | 2019-01-01     | 0.4          | 0.4              |
| 2       | 2019-01-02     | 0.5          | 1.1              |
| 2       | 2019-01-03     | 0.6          | 2.36             |

I already solved this issue by filtering the dataframe for the given range of dates and applying a UDAF to it (see below). It takes very long, and I think it would be much faster with scanLeft!

    while(endPeriod.isBefore(end)) {
      val filtered = inDataframe
        .where("VALUATION_DATE >= '" + start + "' AND VALUATION_DATE <= '" + endPeriod + "'")
      val aggregated = aggregate_returns(filtered)
        .withColumn("VALUATION_DATE", lit(Timestamp.from(endPeriod)).cast(TimestampType))
      df_ret = df_ret.union(aggregated)
      endPeriod = endPeriod.plus(1, ChronoUnit.DAYS)
    }

    def aggregate_returns(inDataframe: DataFrame): DataFrame = {
      val groupedByKey = inDataframe
        .groupBy("product")
      groupedByKey
        .agg(
          returnChain(col("RETURN_LOCAL")).as("RETURN_LOCAL_CUMUL"),
          returnChain(col("RETURN_FX")).as("RETURN_FX_CUMUL"),
          returnChain(col("RETURN_CROSS")).as("RETURN_CROSS_CUMUL"),
          returnChain(col("RETURN")).as("RETURN_CUMUL")
        )
    }

    class ReturnChain extends UserDefinedAggregateFunction {

      // Define the schema of the input data
      override def inputSchema: StructType =
        StructType(StructField("return", DoubleType) :: Nil)

      // Define the schema of the aggregation buffer
      override def bufferSchema: StructType = StructType(
        StructField("product", DoubleType) :: Nil
      )

      // Define the return type
      override def dataType: DataType = DoubleType

      // Does the function return the same value for the same input?
      override def deterministic: Boolean = true

      // Initial values
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0.toDouble
      }

      // Update the buffer based on the input row
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = (1.toDouble + buffer.getAs[Double](0)) * (1.toDouble + input.getAs[Double](0))
      }

      // Merge two aggregation buffers
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
      }

      // Output
      override def evaluate(buffer: Row): Any = {
        buffer.getDouble(0) - 1.toDouble
      }
    }
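
For completeness: `aggregate_returns` above assumes an instance of this UDAF, which is not shown in my snippet, e.g.:

    val returnChain = new ReturnChain
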
  • Can you add some input and expected data? There could be other ways to do it. – koiralo Sep 24 '19 at 12:55
  • @ShankarKoirala I added the real use case – Aneta Sep 24 '19 at 13:21
  • you should pay attention: if the dataframe is partitioned by `partitionKey`, this does not mean that one partition only contains one `partitionKey`, but that one `partitionKey` is only present in one partition. So you would also need a `groupBy` inside `mapPartitions`... In general, I would try to solve it with window functions somehow – Raphael Roth Sep 24 '19 at 19:38
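
A minimal sketch of the window-function approach mentioned in the last comment, assuming the column names from the example tables (`partitionKey`/`orderColumn`/`value` for the additive case, `product`/`valuation_date`/`daily_return` for the returns; the dataframe names are placeholders). The multiplicative chain can be expressed through logs, since prod(1 + r) - 1 = exp(sum(log(1 + r))) - 1:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, exp, log, sum}

    // additive example: running sum of `value` per partitionKey, ordered by orderColumn
    val w = Window.partitionBy("partitionKey").orderBy("orderColumn")
    val withRunningSum = inDataframe.withColumn("scan_sum", sum(col("value")).over(w))

    // time-weighted return: cumulated = prod(1 + r) - 1 = exp(sum(log(1 + r))) - 1
    val wRet = Window.partitionBy("product").orderBy("valuation_date")
    val withCumulatedReturn = returnsDf.withColumn(
      "cumulated_return",
      exp(sum(log(col("daily_return") + 1)).over(wRet)) - 1
    )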

1 Answer


`foreachPartition` doesn't return anything; you need to use `.mapPartitions()` instead.

The difference between `foreachPartition` and `mapPartitions` is the same as that between `foreach` and `map`. See Foreach vs Map in Scala for a good explanation.
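
A minimal sketch of what that could look like, going through the underlying RDD. It assumes the column names from the question's first example, uses the additive `scanLeft(0.0)(_ + _)` from the question, names the new column `cumulated`, and assumes rows inside each partition are already in the desired order (e.g. via `sortWithinPartitions`):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

    val repartitioned = inDataframe.repartition(col("partitionKey"))

    // schema of the input plus the new result column
    val newSchema = StructType(repartitioned.schema.fields :+ StructField("cumulated", DoubleType))

    val withCumulRdd = repartitioned.rdd.mapPartitions { partition =>
      val rows = partition.toList                              // materialize one partition
      val values = rows.map(_.getAs[Double]("value"))
      val cumulated = values.scanLeft(0.0)(_ + _).tail         // drop the initial 0.0
      rows.zip(cumulated)
        .map { case (row, c) => Row.fromSeq(row.toSeq :+ c) }
        .iterator
    }

    val result = spark.createDataFrame(withCumulRdd, newSchema)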

  • ok, I used foreach because I was printing out the result of scanLeft. Do you know exactly how to create a new column with mapPartitions? – Aneta Sep 24 '19 at 12:13
  • First, one question: do you really need to use repartition? – Pablo López Gallego Sep 24 '19 at 13:17
  • thanks for the help! I added a detailed explanation of my use case. The data actually comes from a Cassandra db and is stored partitioned by product. – Aneta Sep 24 '19 at 13:24