I have a use case where I need to compute a running sum over a partition such that the running sum never exceeds a given threshold.
For example:
// Input dataset
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 0.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 0.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 0.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 0.0 | 10.0 |
// Output requirement
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 1.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 3.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 3.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 8.0 | 10.0 |
Here, the threshold for any id is the same for all rows with that id.
Please note that the 3rd row was skipped because adding it would have pushed running_sum past the threshold value (3.0 + 8.0 = 11.0 > 10.0). The 4th row was added because running_sum then stayed within the threshold value (3.0 + 5.0 = 8.0).
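To make the rule concrete, here is how I would write it as a plain sequential loop, outside Spark. This is only a sketch of the intended logic, not a Spark solution: the class and method names are hypothetical, the values are assumed to be already sorted by created_on within a single id, and I am assuming a sum exactly equal to the threshold is still allowed.

```java
import java.util.Arrays;

class CappedRunningSum {
    // Hypothetical helper: `values` holds one id's rows, already ordered by created_on.
    public static double[] cappedRunningSum(double[] values, double threshold) {
        double[] result = new double[values.length];
        double running = 0.0;
        for (int i = 0; i < values.length; i++) {
            // Add the row only if the total stays at or below the threshold
            // (assumption: equality counts as "not exceeding").
            if (running + values[i] <= threshold) {
                running += values[i];
            }
            // A skipped row simply repeats the previous running sum.
            result[i] = running;
        }
        return result;
    }

    public static void main(String[] args) {
        double[] out = cappedRunningSum(new double[] {1.0, 2.0, 8.0, 5.0}, 10.0);
        System.out.println(Arrays.toString(out)); // prints [1.0, 3.0, 3.0, 8.0]
    }
}
```

The point is that each row's result depends on the previous row's result, not just on the previous rows' values.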
I was able to calculate running sum without considering the threshold using window functions as follows:
final WindowSpec window = Window.partitionBy(col("id"))
        .orderBy(col("created_on").asc())
        .rowsBetween(Window.unboundedPreceding(), Window.currentRow());
dataset.withColumn("running_sum", sum(col("value")).over(window)).show();
// Output
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 1.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 3.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 11.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 16.0 | 10.0 |
I tried using when() with the window and also tried lag(), but both gave me unexpected results.
// With when() and sum over window
final WindowSpec window = Window.partitionBy(col("id"))
        .orderBy(col("created_on").asc())
        .rowsBetween(Window.unboundedPreceding(), Window.currentRow());
dataset.withColumn("running_sum",
        when(sum(col("value")).over(window).leq(col("threshold")), sum(col("value")).over(window))
                .otherwise(sum(col("value")).over(window).minus(col("value")))
).show();
// Output
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 1.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 3.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 3.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 11.0 | 10.0 |
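If I trace this attempt by hand, my best guess at what happens is that sum(value) over the window is an unconditional prefix sum, so the skipped 8.0 is still counted for every later row. The plain Java sketch below mimics that expression without Spark; the class and method names are hypothetical and it only models my reading of the expression, not Spark's actual execution.

```java
import java.util.Arrays;

class WhenAttemptTrace {
    // Hypothetical trace of: when(cumSum <= threshold, cumSum).otherwise(cumSum - value)
    public static double[] whenAttempt(double[] values, double threshold) {
        double[] result = new double[values.length];
        double cumulative = 0.0; // models sum(value).over(window): every row is counted
        for (int i = 0; i < values.length; i++) {
            cumulative += values[i];
            // Only the current row's value is removed; earlier skipped rows stay in the total.
            result[i] = cumulative <= threshold ? cumulative : cumulative - values[i];
        }
        return result;
    }

    public static void main(String[] args) {
        double[] out = whenAttempt(new double[] {1.0, 2.0, 8.0, 5.0}, 10.0);
        System.out.println(Arrays.toString(out)); // prints [1.0, 3.0, 3.0, 11.0]
    }
}
```

For the 4th row the prefix sum is 16.0, so subtracting the current 5.0 leaves 11.0, which matches the table above.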
// With combination of sum and lag
final WindowSpec lagWindow = Window.partitionBy(col("id")).orderBy(col("created_on").asc());
final WindowSpec window = Window.partitionBy(col("id"))
        .orderBy(col("created_on").asc())
        .rowsBetween(Window.unboundedPreceding(), Window.currentRow());
dataset.withColumn("running_sum",
        when(sum(col("value")).over(window).leq(col("threshold")), sum(col("value")).over(window))
                .otherwise(lag(col("running_sum"), 1, 0).over(lagWindow))
).show();
// Output
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 1.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 3.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 0.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 0.0 | 10.0 |
After reading some resources online, I came across User Defined Aggregate Functions (UDAFs), which I believe would solve my problem.
However, I would prefer to implement this without UDAFs. Is there another way to do this, or am I missing something in the code I have tried?
Thanks!