
I have a DataFrame that looks like below

 ID      Date       Amount
10001   2019-07-01    50
10001   2019-05-01    15
10001   2019-06-25    10
10001   2019-05-27    20
10002   2019-06-29    25
10002   2019-07-18    35
10002   2019-07-15    40

From the amount column, I'm trying to get a 4-week rolling sum based on the date column. What I mean by that is, basically, I need one more column (say `amount_4wk_rolling`) that will have the sum of the amount column for all the rows that go back 4 weeks. So if the date in the row is 2019-07-01, then the `amount_4wk_rolling` value should be the sum of `Amount` for all the rows whose date is between 2019-06-04 (2019-07-01 minus 28 days) and 2019-07-01. So the new DataFrame would look something like this.

 ID        Date      Amount  amount_4wk_rolling
10001   2019-07-01    50       60
10001   2019-05-01    15       15
10001   2019-06-25    10       30
10001   2019-05-27    20       35
10002   2019-06-29    25       25
10002   2019-07-18    35       100
10002   2019-07-15    40       65

I have tried using window functions, except they don't let me choose a window based on the value of a particular column.

Edit:
 My data is huge, about a TB in size. Ideally, I would like to do this in Spark rather than in pandas.
– Riyan Mohammed

3 Answers


As suggested, you can use `.rolling` on `Date` with a `"28d"` window.

It seems (from your example values) that you also want the rolling window grouped by `ID`.

Try this:

import pandas as pd
from io import StringIO

s = """
 ID      Date      Amount   

10001   2019-07-01   50     
10001   2019-05-01   15
10001   2019-06-25   10   
10001   2019-05-27   20
10002   2019-06-29   25
10002   2019-07-18   35
10002   2019-07-15   40
"""

df = pd.read_csv(StringIO(s), sep="\s+")
df['Date'] = pd.to_datetime(df['Date'])
# 28-day rolling sum, computed within each ID group after sorting by Date
amounts = df.groupby(["ID"]).apply(lambda g: g.sort_values('Date').rolling('28d', on='Date').sum())
# map the rolled sums back onto the original rows by Date
# (note: this assumes each Date appears only once in the whole frame)
df['amount_4wk_rolling'] = df["Date"].map(amounts.set_index('Date')['Amount'])
print(df)

Output:

      ID       Date  Amount  amount_4wk_rolling
0  10001 2019-07-01      50                60.0
1  10001 2019-05-01      15                15.0
2  10001 2019-06-25      10                10.0
3  10001 2019-05-27      20                35.0
4  10002 2019-06-29      25                25.0
5  10002 2019-07-18      35               100.0
6  10002 2019-07-15      40                65.0
– Adam.Er8
  • @RiyanMohammed oh I see. sorry, I've never really worked with spark dataframes, or a TB of data in one go for that matter :( – Adam.Er8 Jul 23 '19 at 15:30
  • Very useful - could you explain how the line `df['amount_4wk_rolling'] = df["Date"].map(amounts.set_index('Date')['Amount'])` works? – Jossy Dec 17 '21 at 18:29
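Regarding the comment about the `.map` line: `amounts.set_index('Date')['Amount']` turns the rolled result into a Series keyed by `Date`, and `Series.map` then looks each original row's `Date` up in that Series (which is why this only works if dates are unique across the whole frame). A toy sketch with made-up rolled values:

```python
import pandas as pd

# hypothetical rolled sums keyed by Date, standing in for
# amounts.set_index('Date')['Amount']
lookup = pd.Series([60.0, 15.0], index=pd.to_datetime(['2019-07-01', '2019-05-01']))

# Series.map looks each Date up in the lookup Series, by index
dates = pd.Series(pd.to_datetime(['2019-05-01', '2019-07-01']))
print(dates.map(lookup).tolist())  # [15.0, 60.0]
```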

This can be done with a `pandas_udf`. It looks like you want to group by `ID`, so I used it as the group key.

import numpy as np
from pyspark.sql import SparkSession, Row
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.appName('test').getOrCreate()
df = spark.createDataFrame([Row(ID=10001, d='2019-07-01', Amount=50),
                            Row(ID=10001, d='2019-05-01', Amount=15),
                            Row(ID=10001, d='2019-06-25', Amount=10),
                            Row(ID=10001, d='2019-05-27', Amount=20),
                            Row(ID=10002, d='2019-06-29', Amount=25),
                            Row(ID=10002, d='2019-07-18', Amount=35),
                            Row(ID=10002, d='2019-07-15', Amount=40)
                           ])
df = df.withColumn('date', F.to_date('d', 'yyyy-MM-dd'))
df = df.withColumn('prev_date', F.date_sub(df['date'], 28))
df.select(["ID", "prev_date", "date", "Amount"]).orderBy('date').show()
df = df.withColumn('amount_4wk_rolling', F.lit(0.0))
# a GROUPED_MAP pandas UDF receives and returns a pandas DataFrame per group
# (in Spark 3+ the same pattern is spelled df.groupby(...).applyInPandas(...))
@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def roll_udf(pdf):
    for index, row in pdf.iterrows():
        d, prev = row['date'], row['prev_date']
        # sum all Amounts in this group whose date falls in [prev, d]
        pdf.loc[pdf['date'] == d, 'amount_4wk_rolling'] = np.sum(
            pdf.loc[(pdf['date'] <= d) & (pdf['date'] >= prev), 'Amount'])
    return pdf

df = df.groupby('ID').apply(roll_udf)
df.select(['ID', 'date', 'prev_date', 'Amount', 'amount_4wk_rolling']).orderBy(['ID', 'date']).show()

The output:

+-----+----------+----------+------+
|   ID| prev_date|      date|Amount|
+-----+----------+----------+------+
|10001|2019-04-03|2019-05-01|    15|
|10001|2019-04-29|2019-05-27|    20|
|10001|2019-05-28|2019-06-25|    10|
|10002|2019-06-01|2019-06-29|    25|
|10001|2019-06-03|2019-07-01|    50|
|10002|2019-06-17|2019-07-15|    40|
|10002|2019-06-20|2019-07-18|    35|
+-----+----------+----------+------+

+-----+----------+----------+------+------------------+
|   ID|      date| prev_date|Amount|amount_4wk_rolling|
+-----+----------+----------+------+------------------+
|10001|2019-05-01|2019-04-03|    15|              15.0|
|10001|2019-05-27|2019-04-29|    20|              35.0|
|10001|2019-06-25|2019-05-28|    10|              10.0|
|10001|2019-07-01|2019-06-03|    50|              60.0|
|10002|2019-06-29|2019-06-01|    25|              25.0|
|10002|2019-07-15|2019-06-17|    40|              65.0|
|10002|2019-07-18|2019-06-20|    35|             100.0|
+-----+----------+----------+------+------------------+
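Since the `pandas_udf` body is plain pandas, its per-group logic can be sanity-checked locally without a Spark session. A sketch using the `ID=10001` group from the example:

```python
import numpy as np
import pandas as pd

def roll(pdf):
    # same per-group logic as roll_udf above, runnable without Spark
    pdf = pdf.copy()
    pdf['amount_4wk_rolling'] = 0.0
    for _, row in pdf.iterrows():
        d, prev = row['date'], row['prev_date']
        in_window = (pdf['date'] <= d) & (pdf['date'] >= prev)
        pdf.loc[pdf['date'] == d, 'amount_4wk_rolling'] = np.sum(pdf.loc[in_window, 'Amount'])
    return pdf

g = pd.DataFrame({
    'date': pd.to_datetime(['2019-07-01', '2019-05-01', '2019-06-25', '2019-05-27']),
    'Amount': [50, 15, 10, 20],
})
g['prev_date'] = g['date'] - pd.Timedelta(days=28)
print(roll(g).sort_values('date')[['date', 'Amount', 'amount_4wk_rolling']])
```

Sorted by date, this reproduces the sums 15.0, 35.0, 10.0, 60.0 shown for `ID=10001` in the output above.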

– niuer

For PySpark, you can just use a window function: `sum` + `rangeBetween`.

from pyspark.sql import functions as F, Window

# skip code to initialize Spark session and dataframe

>>> df.show()
+-----+----------+------+
|   ID|      Date|Amount|
+-----+----------+------+
|10001|2019-07-01|    50|
|10001|2019-05-01|    15|
|10001|2019-06-25|    10|
|10001|2019-05-27|    20|
|10002|2019-06-29|    25|
|10002|2019-07-18|    35|
|10002|2019-07-15|    40|
+-----+----------+------+

>>> df.printSchema()
root
 |-- ID: long (nullable = true)
 |-- Date: string (nullable = true)
 |-- Amount: long (nullable = true)

# order by epoch seconds, so rangeBetween can express "28 days back" as -28*86400
win = (Window.partitionBy('ID')
             .orderBy(F.to_timestamp('Date').astype('long'))
             .rangeBetween(-28 * 86400, 0))

df_new = df.withColumn('amount_4wk_rolling', F.sum('Amount').over(win))

>>> df_new.show()
+------+-----+----------+------------------+
|Amount|   ID|      Date|amount_4wk_rolling|
+------+-----+----------+------------------+
|    25|10002|2019-06-29|                25|
|    40|10002|2019-07-15|                65|
|    35|10002|2019-07-18|               100|
|    15|10001|2019-05-01|                15|
|    20|10001|2019-05-27|                35|
|    10|10001|2019-06-25|                10|
|    50|10001|2019-07-01|                60|
+------+-----+----------+------------------+
– jxc