
I’m trying to calculate a rolling weighted average over a window (PARTITION BY id1, id2 ORDER BY unixTime) in PySpark and wanted to know if anyone had ideas on how to do this.

The rolling average takes the current row’s value for a column, the 9 previous row values for that column, and the 9 following row values for that column, and weights each value based on how far it is from the current row. So the current row is weighted 10x, the lag 1/lead 1 values are weighted 9x, and so on down to 1x for lag 9/lead 9.

If none of the values are null, the denominator for the weighted average is 100 (10 + 2*(9 + 8 + ... + 1)). The one caveat is that if there are null values, we still want to calculate a moving average (unless a little over half of the values are null).

So, for example, if the 9 values before the current value are null, the denominator would be 55. If more than half of the values are null, we would output NULL for the weighted average. We could also use the rule that if the denominator is less than 40 or so, we output NULL.
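To make the weighting and denominator arithmetic concrete, here is a plain-Python sketch of the calculation for a single 19-value window. It assumes the window is passed as a list with `None` standing in for nulls; the function names and the `min_denominator=40` cutoff (taken from the "less than 40" idea above) are illustrative assumptions, not part of any library.

```python
from typing import List, Optional, Tuple

HALF_WIDTH = 9      # 9 rows before and 9 rows after the current row
CENTER_WEIGHT = 10  # current row is weighted 10x, lag/lead 1 is 9x, ..., lag/lead 9 is 1x


def weight(offset: int) -> int:
    """Weight of a value `offset` rows away from the current row."""
    return CENTER_WEIGHT - abs(offset)


def numerator_and_denominator(window: List[Optional[float]]) -> Tuple[float, int]:
    """`window` holds the 9 previous values, the current value, then the 9 following values."""
    if len(window) != 2 * HALF_WIDTH + 1:
        raise ValueError("expected 19 values")
    num, den = 0.0, 0
    for i, v in enumerate(window):
        if v is None:
            continue  # a null is removed from the numerator AND the denominator
        w = weight(i - HALF_WIDTH)
        num += w * v
        den += w
    return num, den


def weighted_avg(window: List[Optional[float]], min_denominator: int = 40) -> Optional[float]:
    num, den = numerator_and_denominator(window)
    # Hypothetical cutoff: return None (NULL) when too much of the window is missing
    return num / den if den >= min_denominator else None


print(numerator_and_denominator([1.0] * 19)[1])               # 100
print(numerator_and_denominator([None] * 9 + [1.0] * 10)[1])  # 55
```

With no nulls the weights sum to 100; dropping the 9 previous values leaves 10 + 9 + ... + 1 = 55, matching the numbers in the question.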

I've attached a screenshot to explain what I mean in case it is confusing; hopefully this clears things up. [Screenshot not included]

I know I could do this in SQL (and I could save the DataFrame as a temp view), but because I have to do this rolling average for multiple columns (with exactly the same logic), it would be ideal to do it in PySpark so that I can write a for loop over the columns. I would also like this to be efficient. I’ve read many threads about rolling averages but think this situation is slightly different.

Sorry if I am overcomplicating this; hopefully it makes sense. If this isn't easy to do efficiently, I do know how to calculate it in SQL by listing lag(val, 9) over window, ..., lag(val, 1) over window, lead(val, 1) over window, ..., lead(val, 9) over window, and can just go with that.

Douglas M
WIT
  • Does this answer your question? https://stackoverflow.com/questions/47622447/weighted-moving-average-in-pyspark – pissall Jul 30 '20 at 04:36
  • @pissall no, that is the post I read, but in that solution a null value acts as a 0, so the average is skewed instead of the null being removed from the denominator. I might be able to find a way to amend that solution to work for me, but it would be really inefficient – WIT Jul 30 '20 at 20:12
  • My suggestion is that you will have to tweak that answer according to your needs. – pissall Jul 31 '20 at 03:07

1 Answer


IIUC, one way you can try is to use the window function collect_list, sort the list, find the position idx of the current row using array_position (requires Spark 2.4+), and then calculate the weights based on that position. As an example, let's use a window of size 7 (N=3 in the code below):

from pyspark.sql.functions import expr, sort_array, collect_list, struct
from pyspark.sql import Window

df = spark.createDataFrame([
    (0, 0.5), (1, 0.6), (2, 0.65), (3, 0.7), (4, 0.77),
    (5, 0.8), (6, 0.7), (7, 0.9), (8, 0.99), (9, 0.95)
], ["time", "val"])

N = 3

w1 = Window.partitionBy().orderBy('time').rowsBetween(-N,N)

# note that the index for array_position is 1-based, `i` in transform function is 0-based
df1 = df.withColumn('data', sort_array(collect_list(struct('time','val')).over(w1))) \
    .withColumn('idx', expr("array_position(data, (time,val))-1")) \
    .withColumn('weights', expr("transform(data, (x,i) ->  10 - abs(i-idx))"))

df1.show(truncate=False)
+----+----+-------------------------------------------------------------------------+---+----------------------+
|time|val |data                                                                     |idx|weights               |
+----+----+-------------------------------------------------------------------------+---+----------------------+
|0   |0.5 |[[0, 0.5], [1, 0.6], [2, 0.65], [3, 0.7]]                                |0  |[10, 9, 8, 7]         |
|1   |0.6 |[[0, 0.5], [1, 0.6], [2, 0.65], [3, 0.7], [4, 0.77]]                     |1  |[9, 10, 9, 8, 7]      |
|2   |0.65|[[0, 0.5], [1, 0.6], [2, 0.65], [3, 0.7], [4, 0.77], [5, 0.8]]           |2  |[8, 9, 10, 9, 8, 7]   |
|3   |0.7 |[[0, 0.5], [1, 0.6], [2, 0.65], [3, 0.7], [4, 0.77], [5, 0.8], [6, 0.7]] |3  |[7, 8, 9, 10, 9, 8, 7]|
|4   |0.77|[[1, 0.6], [2, 0.65], [3, 0.7], [4, 0.77], [5, 0.8], [6, 0.7], [7, 0.9]] |3  |[7, 8, 9, 10, 9, 8, 7]|
|5   |0.8 |[[2, 0.65], [3, 0.7], [4, 0.77], [5, 0.8], [6, 0.7], [7, 0.9], [8, 0.99]]|3  |[7, 8, 9, 10, 9, 8, 7]|
|6   |0.7 |[[3, 0.7], [4, 0.77], [5, 0.8], [6, 0.7], [7, 0.9], [8, 0.99], [9, 0.95]]|3  |[7, 8, 9, 10, 9, 8, 7]|
|7   |0.9 |[[4, 0.77], [5, 0.8], [6, 0.7], [7, 0.9], [8, 0.99], [9, 0.95]]          |3  |[7, 8, 9, 10, 9, 8]   |
|8   |0.99|[[5, 0.8], [6, 0.7], [7, 0.9], [8, 0.99], [9, 0.95]]                     |3  |[7, 8, 9, 10, 9]      |
|9   |0.95|[[6, 0.7], [7, 0.9], [8, 0.99], [9, 0.95]]                               |3  |[7, 8, 9, 10]         |
+----+----+-------------------------------------------------------------------------+---+----------------------+
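The `idx` and `weights` columns in the table can be reproduced in plain Python, which makes it easier to see how `rowsBetween(-N, N)` is clipped at the partition edges. A sketch only; the helper name `window_weights` is hypothetical, and it assumes each `time` is unique within the partition, which is also what `array_position` relies on to find the current row.

```python
from typing import List, Tuple


def window_weights(n_rows: int, n: int, center_weight: int = 10) -> List[Tuple[int, List[int]]]:
    """(idx, weights) for every row of a partition with `n_rows` rows and rowsBetween(-n, n)."""
    result = []
    for pos in range(n_rows):
        start = max(0, pos - n)         # the frame is clipped at the first row of the partition
        end = min(n_rows - 1, pos + n)  # ... and at the last row
        idx = pos - start               # 0-based position of the current row in `data`
        weights = [center_weight - abs(i - idx) for i in range(end - start + 1)]
        result.append((idx, weights))
    return result


for time, (idx, weights) in enumerate(window_weights(10, 3)):
    print(time, idx, weights)
```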

Then we can use the Spark SQL built-in function aggregate to calculate the sum of the weights and the weighted average (now with N=9, i.e. a window of 19 rows):

N = 9

w1 = Window.partitionBy().orderBy('time').rowsBetween(-N,N)

df_new = df.withColumn('data', sort_array(collect_list(struct('time','val')).over(w1))) \
    .withColumn('idx', expr("array_position(data, (time,val))-1")) \
    .withColumn('weights', expr("transform(data, (x,i) ->  10 - abs(i-idx))"))\
    .withColumn('sum_weights', expr("aggregate(weights, 0D, (acc,x) -> acc+x)")) \
    .withColumn('weighted_val', expr("""
      aggregate(
        zip_with(data,weights, (x,y) -> x.val*y),
        0D, 
        (acc,x) -> acc+x,
        acc -> acc/sum_weights
      )""")) \
    .drop("data", "idx", "sum_weights", "weights")

df_new.show()
+----+----+------------------+
|time| val|      weighted_val|
+----+----+------------------+
|   0| 0.5|0.6827272727272726|
|   1| 0.6|0.7001587301587302|
|   2|0.65|0.7169565217391304|
|   3| 0.7|0.7332876712328767|
|   4|0.77|            0.7492|
|   5| 0.8|0.7641333333333333|
|   6| 0.7|0.7784931506849315|
|   7| 0.9|0.7963768115942028|
|   8|0.99|0.8138095238095238|
|   9|0.95|0.8292727272727273|
+----+----+------------------+
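Each `aggregate(array, zero, merge, finish)` call behaves like a fold followed by a final function. A rough plain-Python analogue (hypothetical helper names, no SQL null semantics) reproduces `sum_weights` and `weighted_val` for the sample data; rows near the partition edges end up with a smaller `sum_weights` (55 for time 0, 75 for time 4) rather than 100.

```python
from functools import reduce
from typing import Any, Callable, List, Tuple


def sql_aggregate(arr: List[Any], zero: Any, merge: Callable[[Any, Any], Any],
                  finish: Callable[[Any], Any] = lambda acc: acc) -> Any:
    """Rough Python analogue of Spark SQL aggregate(array, zero, merge[, finish])."""
    return finish(reduce(merge, arr, zero))


def weighted_val_at(vals: List[float], pos: int, n: int = 9) -> Tuple[float, float]:
    """Return (sum_weights, weighted_val) for row `pos`, mirroring the df_new pipeline."""
    start, end = max(0, pos - n), min(len(vals) - 1, pos + n)
    data = vals[start:end + 1]                               # collect_list over rowsBetween(-n, n)
    idx = pos - start                                        # array_position(...) - 1
    weights = [10 - abs(i - idx) for i in range(len(data))]  # transform(data, (x,i) -> 10 - abs(i-idx))
    sum_weights = sql_aggregate(weights, 0.0, lambda acc, x: acc + x)
    weighted_val = sql_aggregate(
        [x * y for x, y in zip(data, weights)],              # zip_with(data, weights, (x,y) -> x.val*y)
        0.0,
        lambda acc, x: acc + x,
        lambda acc: acc / sum_weights,                       # finish: acc -> acc/sum_weights
    )
    return sum_weights, weighted_val


vals = [0.5, 0.6, 0.65, 0.7, 0.77, 0.8, 0.7, 0.9, 0.99, 0.95]
for pos in range(len(vals)):
    print(pos, weighted_val_at(vals, pos))
```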

Notes:

  • you can calculate multiple columns by using struct('time', 'val1', 'val2') in the first line that builds df_new, and then adjusting the corresponding calculation of idx and x.val*y in weighted_val, etc.

  • to return NULL when fewer than half of the window's values can be collected, wrap the expression in IF(size(data) <= 9, NULL, ...) or IF(sum_weights < 40, NULL, ...), as follows:

      df_new = df.withColumn(...) \
      ...
          .withColumn('weighted_val', expr(""" IF(size(data) <= 9, NULL, 
            aggregate( 
              zip_with(data,weights, (x,y) -> x.val*y), 
              0D,  
              (acc,x) -> acc+x, 
              acc -> acc/sum_weights 
           ))""")) \
          .drop("data", "idx", "sum_weights", "weights")
    

EDIT: for multiple columns, you can try:

cols = ['val1', 'val2', 'val3']

# function to set SQL expression to calculate weighted values for the field `val`
weighted_vals = lambda val: """
    aggregate(
      zip_with(data,weights, (x,y) -> x.{0}*y),
      0D,
      (acc,x) -> acc+x,
      acc -> acc/sum_weights
    ) as weighted_{0}
""".format(val)

df_new = df.withColumn('data', sort_array(collect_list(struct('time',*cols)).over(w1))) \
  .withColumn('idx', expr("array_position(data, (time,{}))-1".format(','.join(cols)))) \
  .withColumn('weights', expr("transform(data, (x,i) ->  10 - abs(i-idx))")) \
  .withColumn('sum_weights', expr("aggregate(weights, 0D, (acc,x) -> acc+x)")) \
  .selectExpr(df.columns + [ weighted_vals(c) for c in cols ])

If the number of columns is small, we can write the SQL expression by hand and calculate all the weighted values with a single aggregate function:

df_new = df.withColumn('data', sort_array(collect_list(struct('time',*cols)).over(w1))) \
  .withColumn('idx', expr("array_position(data, (time,{}))-1".format(','.join(cols)))) \
  .withColumn('weights', expr("transform(data, (x,i) ->  10 - abs(i-idx))")) \
  .withColumn('sum_weights', expr("aggregate(weights, 0D, (acc,x) -> acc+x)")) \
  .withColumn("vals", expr(""" 
   aggregate( 
     zip_with(data, weights, (x,y) -> (x.val1*y as val1, x.val2*y as val2)),
     (0D as val1, 0D as val2), 
     (acc,x) -> (acc.val1 + x.val1, acc.val2 + x.val2),
     acc -> (acc.val1/sum_weights as weighted_val1, acc.val2/sum_weights as weighted_val2)
   )     
   """)).select(*df.columns, "vals.*")
jxc
  • this is so helpful and works perfectly, thanks so much! One quick question: if I wanted to calculate multiple columns with the struct approach, how would I adjust idx and x.val*y, since it would be x.val1 or x.val2 (for example)? It would almost be like a for-each statement – WIT Aug 04 '20 at 22:28
  • your method seems a lot better for calculating multiple columns. What I was planning to do was list out the columns and do a `for col in col_list: df = df.withColumn(...)`, but I assume that is much less efficient – WIT Aug 04 '20 at 22:29
  • @WIT, just updated the post, please check the EDIT section. – jxc Aug 04 '20 at 23:20
  • awesome - this worked. thanks for your assistance in finding a clean solution here! – WIT Aug 04 '20 at 23:59
  • This is a great example. If you also posted the code that creates `df`, such as `df = spark.createDataFrame([...]).toDF(...)`, it would be reproducible and of greater use. – BhishanPoudel Aug 05 '20 at 00:19
  • @jxc one thing I am realizing with this is that if there is a null value, sum_weights still includes its weight in the denominator. Is there a way to compute sum_weights in the aggregate function so that it is specific to each val? – WIT Sep 14 '21 at 19:17
  • Hi @WIT, it looks like you can try an SQL `IFNULL` or `IF` statement in the calculations: (1) change sum_weights to `aggregate(weights, 0D, (acc,x) -> acc+ifnull(x,0))`, and then (2) when calculating weighted_val, change the first argument of the aggregate function to `zip_with(data,weights, (x,y) -> if(x.val is null, 0, x.val*y))`. – jxc Sep 14 '21 at 20:36