
I need to compute an exponential moving average (EMA) over a Spark DataFrame. Table:

ab = spark.createDataFrame(
[(1,"1/1/2020", 41.0,0.5,   0.5 ,1,     '10.22'),
 (1,"10/3/2020",24.0,0.3,   0.7 ,2,     ''     ),
 (1,"21/5/2020",32.0,0.4,   0.6 ,3,     ''     ),
 (2,"3/1/2020", 51.0,0.22,  0.78,1,     '34.78'),
 (2,"10/5/2020",14.56,0.333,0.66,2,     ''     ),
 (2,"30/9/2020",17.0,0.66,  0.34,3,     ''     )],["CID","date","A","B","C","Row","SMA"] )
ab.show()

+---+---------+-----+-----+----+---+-----+
|CID|     date|    A|    B|   C|Row|  SMA|
+---+---------+-----+-----+----+---+-----+
|  1| 1/1/2020| 41.0|  0.5| 0.5|  1|10.22|
|  1|10/3/2020| 24.0|  0.3| 0.7|  2|     |
|  1|21/5/2020| 32.0|  0.4| 0.6|  3|     |
|  2| 3/1/2020| 51.0| 0.22|0.78|  1|34.78|
|  2|10/5/2020|14.56|0.333|0.66|  2|     |
|  2|30/9/2020| 17.0| 0.66|0.34|  3|     |
+---+---------+-----+-----+----+---+-----+

Expected output:

+---+---------+-----+-----+----+---+-----+----------+
|CID|     date|    A|    B|   C|Row|  SMA|       EMA|
+---+---------+-----+-----+----+---+-----+----------+
|  1| 1/1/2020| 41.0|  0.5| 0.5|  1|10.22|     10.22|
|  1|10/3/2020| 24.0|  0.3| 0.7|  2|     |    14.354|
|  1|21/5/2020| 32.0|  0.4| 0.6|  3|     |   21.4124|
|  2| 3/1/2020| 51.0| 0.22|0.78|  1|34.78|     34.78|
|  2|10/5/2020|14.56|0.333|0.66|  2|     |  28.04674|
|  2|30/9/2020| 17.0| 0.66|0.34|  3|     |20.7558916|
+---+---------+-----+-----+----+---+-----+----------+

Logic: for every customer, if Row == 1 then take SMA as EMA, otherwise EMA = C * LAG(EMA) + A * B.
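
For CID 1 the recurrence unfolds like this (values taken from the table and the expected output above):

Row 1: EMA = SMA                       = 10.22
Row 2: EMA = 0.7 * 10.22  + 24.0 * 0.3 = 14.354
Row 3: EMA = 0.6 * 14.354 + 32.0 * 0.4 = 21.4124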

  • Hi, interesting, not sure if this might be of interest https://stackoverflow.com/questions/52240650/pyspark-weighted-average-by-a-column – IronMan Oct 07 '20 at 16:23
  • @murtihash can you help with the above question? – nimit kothari Oct 08 '20 at 08:54
  • Would it be possible to rewrite the formula for EMA that it does not reference previous values of EMA but only the other columns? Then a [window](http://spark.apache.org/docs/3.0.0/sql-ref-syntax-qry-select-window.html) would work – werner Oct 08 '20 at 18:01
  • @werner I need to use previous row value to get a new value of ema of the current row – nimit kothari Oct 09 '20 at 04:21

1 Answer


The problem here is that the freshly calculated EMA of the previous row is used as input for the current row. That means the calculation for a single customer cannot be parallelized.

For Spark 3.0+, it is possible to get the required result with a Pandas UDF using grouped map (applyInPandas).

from pyspark.sql import functions as F
from pyspark.sql import types as T

ab = spark.createDataFrame(
    [(1,"1/1/2020", 41.0,0.5,   0.5 ,1,     '10.22'),
     (1,"10/3/2020",24.0,0.3,   0.7 ,2,     ''     ),
     (1,"21/5/2020",32.0,0.4,   0.6 ,3,     ''     ),
     (2,"3/1/2020", 51.0,0.22,  0.78,1,     '34.78'),
     (2,"10/5/2020",14.56,0.333,0.66,2,     ''     ),
     (2,"30/9/2020",17.0,0.66,  0.34,3,     ''     )],\
          ["CID","date","A","B","C","Row","SMA"] ) \
    .withColumn("SMA", F.col("SMA").cast(T.DoubleType())) \
    .withColumn("date", F.to_date(F.col("date"), "d/M/yyyy"))

import pandas as pd

def calc(df: pd.DataFrame) -> pd.DataFrame:
    # df contains all rows of one CID group as a Pandas DataFrame
    df = df.sort_values('date').reset_index(drop=True)
    # first row: seed EMA with SMA
    df.loc[0, 'EMA'] = df.loc[0, 'SMA']
    # remaining rows: EMA = C * previous EMA + A * B
    for i in range(1, len(df)):
        df.loc[i, 'EMA'] = df.loc[i, 'C'] * df.loc[i - 1, 'EMA'] \
            + df.loc[i, 'A'] * df.loc[i, 'B']
    return df

ab.groupBy("CID").applyInPandas(calc, 
    schema = "CID long, date date, A double, B double, C double, Row long, SMA double, EMA double")\
    .show()

Output:

+---+----------+-----+-----+----+---+-----+------------------+
|CID|      date|    A|    B|   C|Row|  SMA|               EMA|
+---+----------+-----+-----+----+---+-----+------------------+
|  1|2020-01-01| 41.0|  0.5| 0.5|  1|10.22|             10.22|
|  1|2020-03-10| 24.0|  0.3| 0.7|  2| null|            14.354|
|  1|2020-05-21| 32.0|  0.4| 0.6|  3| null|21.412399999999998|
|  2|2020-01-03| 51.0| 0.22|0.78|  1|34.78|             34.78|
|  2|2020-05-10|14.56|0.333|0.66|  2| null|          27.80328|
|  2|2020-09-30| 17.0| 0.66|0.34|  3| null|        20.6731152|
+---+----------+-----+-----+----+---+-----+------------------+

The idea is to process each group as a Pandas dataframe. This Pandas dataframe contains all rows of the current group and is sorted by date inside calc. While iterating over it we can access the EMA of the previous row, which is not possible with a Spark dataframe.
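
As a side note, not asked in the question: on Spark 2.3/2.4, where applyInPandas is not yet available, the same grouped-map logic can be expressed with the older pandas_udf decorator (deprecated in Spark 3.0 in favour of applyInPandas). A sketch reusing the body of calc:

from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("CID long, date date, A double, B double, C double, Row long, SMA double, EMA double",
            PandasUDFType.GROUPED_MAP)
def calc_udf(df):
    # same body as calc above
    df = df.sort_values('date').reset_index(drop=True)
    df.loc[0, 'EMA'] = df.loc[0, 'SMA']
    for i in range(1, len(df)):
        df.loc[i, 'EMA'] = df.loc[i, 'C'] * df.loc[i - 1, 'EMA'] \
            + df.loc[i, 'A'] * df.loc[i, 'B']
    return df

ab.groupBy("CID").apply(calc_udf).show()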

There are some caveats:

  • all rows of one group must fit into the memory of a single executor; partial aggregation is not possible here
  • iterating over a Pandas dataframe row by row is discouraged for performance reasons (a vectorized variant is sketched below)
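
If the Python loop ever becomes a bottleneck, the linear recurrence EMA_i = C_i * EMA_{i-1} + A_i * B_i can be unrolled into cumulative products and sums. The following vectorized variant of calc is my own sketch, not part of the original answer; it assumes all C values after the first row of a group are non-zero, and the division by a shrinking cumulative product can lose precision for long groups:

import numpy as np

def calc_vectorized(df: pd.DataFrame) -> pd.DataFrame:
    df = df.sort_values('date').reset_index(drop=True)
    c = df['C'].to_numpy()
    ab_term = (df['A'] * df['B']).to_numpy()
    # p[i] = C_1 * ... * C_i, with p[0] = 1 (C of the first row is never used)
    p = np.concatenate(([1.0], np.cumprod(c[1:])))
    # dividing the recurrence by p[i] turns it into a cumulative sum:
    # EMA_i / p_i = SMA_0 + sum_{j=1..i} A_j * B_j / p_j
    s = np.concatenate(([0.0], np.cumsum(ab_term[1:] / p[1:])))
    df['EMA'] = p * (df.loc[0, 'SMA'] + s)
    return df

It can be passed to applyInPandas with the same schema as calc and reproduces the output shown above.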
– werner