
I am trying to get the sum of Revenue over the previous 3 Month rows (excluding the current row) for each Client. Here is a minimal example of my current attempt in Databricks:

```python
import numpy as np
import pandas as pd

# `spark` and `sqlContext` are provided by the Databricks notebook environment
cols = ['Client','Month','Revenue']
df_pd = pd.DataFrame([['A',201701,100],
                   ['A',201702,101],
                   ['A',201703,102],
                   ['A',201704,103],
                   ['A',201705,104],
                   ['B',201701,201],
                   ['B',201702,np.nan],
                   ['B',201703,203],
                   ['B',201704,204],
                   ['B',201705,205],
                   ['B',201706,206],
                   ['B',201707,207]
                  ])
df_pd.columns = cols

spark_df = spark.createDataFrame(df_pd)
spark_df.createOrReplaceTempView('df_sql')

df_out = sqlContext.sql("""
select *, (sum(ifnull(Revenue,0)) over (partition by Client
  order by Client,Month
  rows between 3 preceding and 1 preceding)) as Total_Sum3
  from df_sql
  """)
df_out.show()
```

```
+------+------+-------+----------+
|Client| Month|Revenue|Total_Sum3|
+------+------+-------+----------+
|     A|201701|  100.0|      null|
|     A|201702|  101.0|     100.0|
|     A|201703|  102.0|     201.0|
|     A|201704|  103.0|     303.0|
|     A|201705|  104.0|     306.0|
|     B|201701|  201.0|      null|
|     B|201702|    NaN|     201.0|
|     B|201703|  203.0|       NaN|
|     B|201704|  204.0|       NaN|
|     B|201705|  205.0|       NaN|
|     B|201706|  206.0|     612.0|
|     B|201707|  207.0|     615.0|
+------+------+-------+----------+
```

As you can see, if a missing value (shown as NaN) falls anywhere in the 3-month window, Total_Sum3 comes back as NaN. I would like to treat missing values as 0, hence the ifnull attempt, but this does not seem to work. I have also tried a CASE statement to change NULL to 0, with no luck.
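The behaviour can be modelled without Spark. SQL's `SUM` skips NULLs, but a NaN coming from pandas is an ordinary double value, so it poisons every window frame it falls in, and `ifnull` only replaces NULL, not NaN. Below is a minimal plain-Python model, assuming `None` stands for SQL NULL and `float("nan")` for NaN; the helpers `ifnull`, `frame_sum` and `window_sums` are hypothetical illustrations, not Spark APIs.

```python
import math

def ifnull(value, default):
    """Model of SQL IFNULL: replaces only NULL (None); NaN passes through unchanged."""
    return default if value is None else value

def frame_sum(values, i, start=3, end=1):
    """Model of SUM(x) OVER (ROWS BETWEEN start PRECEDING AND end PRECEDING) at row i.

    As in SQL, NULLs (None) are skipped, and an empty or all-NULL frame gives NULL.
    """
    frame = values[max(0, i - start): max(0, i - end + 1)]
    present = [v for v in frame if v is not None]
    return sum(present) if present else None

def window_sums(values):
    return [frame_sum(values, i) for i in range(len(values))]

# Client B from the question: missing month stored as NaN (what pandas produces) ...
raw_nan = [201.0, math.nan, 203.0, 204.0, 205.0, 206.0, 207.0]
# ... versus the same data with a real NULL.
raw_null = [201.0, None, 203.0, 204.0, 205.0, 206.0, 207.0]

nan_result = window_sums([ifnull(v, 0) for v in raw_nan])    # ifnull leaves NaN in place
null_result = window_sums([ifnull(v, 0) for v in raw_null])  # ifnull turns None into 0

print(nan_result)   # three frames contain the NaN, so three sums are NaN
print(null_result)  # matches the desired Total_Sum3 for client B
```

The point of the model is that the fix has to happen before the sum sees the value: either turn NaN into a real NULL, or map NaN to 0 explicitly.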


2 Answers


Just apply coalesce outside the sum:

```python
df_out = sqlContext.sql("""
  select *, coalesce(sum(Revenue) over (partition by Client
  order by Client,Month
  rows between 3 preceding and 1 preceding), 0) as Total_Sum3
  from df_sql
 """)
```

It is Apache Spark, my bad! (I am working in Databricks and thought it was MySQL under the hood.) Is it too late to change the title?

@Barmar, you are right that IFNULL() doesn't treat NaN as null. I managed to figure out the fix thanks to @user6910411 here: SO link. I had to convert the NumPy NaNs to Spark nulls. Here is the corrected code, starting after the sample df_pd is created:

```python
spark_df = spark.createDataFrame(df_pd)

from pyspark.sql.functions import isnan, col, when

# Convert NaN to null in floating-point (double/float) columns; other columns pass through unchanged.
# when() without otherwise() yields null wherever the condition is false, i.e. where the value is NaN.
spark_df = spark_df.select([
    when(~isnan(c), col(c)).alias(c) if t in ("double", "float") else c
    for c, t in spark_df.dtypes])

spark_df.createOrReplaceTempView('df_sql')

df_out = sqlContext.sql("""
select *, (sum(ifnull(Revenue,0)) over (partition by Client
  order by Client,Month
  rows between 3 preceding and 1 preceding)) as Total_Sum3
  from df_sql order by Client,Month
  """)
df_out.show()
```

which then gives the desired output:

```
+------+------+-------+----------+
|Client| Month|Revenue|Total_Sum3|
+------+------+-------+----------+
|     A|201701|  100.0|      null|
|     A|201702|  101.0|     100.0|
|     A|201703|  102.0|     201.0|
|     A|201704|  103.0|     303.0|
|     A|201705|  104.0|     306.0|
|     B|201701|  201.0|      null|
|     B|201702|   null|     201.0|
|     B|201703|  203.0|     201.0|
|     B|201704|  204.0|     404.0|
|     B|201705|  205.0|     407.0|
|     B|201706|  206.0|     612.0|
|     B|201707|  207.0|     615.0|
+------+------+-------+----------+
```

Is sqlContext the best way to approach this, or would it be better / more elegant to achieve the same result via pyspark.sql.window?
