I have an input dataframe as below:

partner_id|month_id|value1 |value2
1001      |  01    |10     |20    
1002      |  01    |20     |30    
1003      |  01    |30     |40
1001      |  02    |40     |50    
1002      |  02    |50     |60    
1003      |  02    |60     |70
1001      |  03    |70     |80    
1002      |  03    |80     |90    
1003      |  03    |90     |100

Using the code below, I created two new columns that compute the cumulative average with a window function:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

rnum = (Window.partitionBy("partner_id").orderBy("month_id").rangeBetween(Window.unboundedPreceding, 0))
df = df.withColumn("value1_1", F.avg("value1").over(rnum))
df = df.withColumn("value1_2", F.avg("value2").over(rnum))

Output:

partner_id|month_id|value1 |value2|value1_1|value1_2
1001      |  01    |10     |20    |10      |20
1002      |  01    |20     |30    |20      |30
1003      |  01    |30     |40    |30      |40
1001      |  02    |40     |50    |25      |35
1002      |  02    |50     |60    |35      |45
1003      |  02    |60     |70    |45      |55
1001      |  03    |70     |80    |40      |50
1002      |  03    |80     |90    |50      |60
1003      |  03    |90     |100   |60      |70

The cumulative average works correctly on the value1 and value2 columns with the PySpark window function. However, if one month of data is missing from the input, the average for the following months should be based on the month number rather than on the number of rows present. For example, if the input is as below (month 02 data is missing):

partner_id|month_id|value1 |value2
1001      |  01    |10     |20    
1002      |  01    |20     |30    
1003      |  01    |30     |40
1001      |  03    |70     |80    
1002      |  03    |80     |90    
1003      |  03    |90     |100

then the average for the month 03 records is currently computed over the rows that exist, e.g. (70 + 10)/2 for partner 1001's value1, whereas it should be divided by the month number, i.e. (70 + 10)/3. What is the correct way to compute the average when some months are missing?
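To make the two formulas concrete, here is the plain-Python arithmetic for partner 1001's value1 in the month-02-missing input (no Spark involved; the dictionary just restates the table above, and the variable names are illustrative):

```python
# value1 for partner 1001 in the input where month 02 is missing
observed = {1: 10, 3: 70}
current_month = 3

# What the window average computes: the mean over the rows that exist
row_average = sum(observed.values()) / len(observed)

# What the question asks for: missing months count as 0 and the divisor is the month number
month_average = sum(observed.get(m, 0) for m in range(1, current_month + 1)) / current_month

print(row_average)    # 40.0
print(month_average)  # 26.666666666666668
```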

Rocky1989

2 Answers

If you are using Spark 2.4+, you can use the sequence function together with the array functions. This solution is inspired by this link.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

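# Assumes month_id is an integer column (as in the result below); cast it first
# if it is stored as a string such as "01"
w = Window.partitionBy("partner_id")

df1 = (
    # For each partner, every month from its first to its last observed month
    df.withColumn(
        "month_seq",
        F.sequence(F.min("month_id").over(w), F.max("month_id").over(w), F.lit(1)),
    )
    .groupBy("partner_id")
    .agg(
        F.collect_list("month_id").alias("month_id"),
        F.collect_list("value1").alias("value1"),
        F.collect_list("value2").alias("value2"),
        F.first("month_seq").alias("month_seq"),
    )
    # Months in the sequence that have no row, appended after the observed months
    .withColumn("month_seq", F.array_except("month_seq", "month_id"))
    .withColumn("month_id", F.flatten(F.array("month_id", "month_seq")))
    .drop("month_seq")
    # arrays_zip pads the shorter value arrays with nulls for the appended months
    .withColumn("zip", F.explode(F.arrays_zip("month_id", "value1", "value2")))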
    .select(
        "partner_id",
        "zip.month_id",
        F.when(F.col("zip.value1").isNull(), F.lit(0))
        .otherwise(F.col("zip.value1"))
        .alias("value1"),
        F.when(F.col("zip.value2").isNull(), F.lit(0))
        .otherwise(F.col("zip.value2"))
        .alias("value2"),
    )
    .orderBy("month_id")
)

rnum = (
    Window.partitionBy("partner_id")
    .orderBy("month_id")
    .rangeBetween(Window.unboundedPreceding, 0)
)

df2 = df1.withColumn("value1_1", F.avg("value1").over(rnum)).withColumn(
    "value1_2", F.avg("value2").over(rnum)
)

Result:

df2.show()

# +----------+--------+------+------+------------------+------------------+
# |partner_id|month_id|value1|value2|          value1_1|          value1_2|
# +----------+--------+------+------+------------------+------------------+
# |      1002|       1|    20|    30|              20.0|              30.0|
# |      1002|       2|     0|     0|              10.0|              15.0|
# |      1002|       3|    80|    90|33.333333333333336|              40.0|
# |      1001|       1|    10|    20|              10.0|              20.0|
# |      1001|       2|     0|     0|               5.0|              10.0|
# |      1001|       3|    70|    80|26.666666666666668|33.333333333333336|
# |      1003|       1|    30|    40|              30.0|              40.0|
# |      1003|       2|     0|     0|              15.0|              20.0|
# |      1003|       3|    90|   100|              40.0|46.666666666666664|
# +----------+--------+------+------+------------------+------------------+
Steven
kites
  • If you don't want the second month in the result set, you can drop the rows where value1 or value2 is 0. – kites Jul 22 '20 at 21:12

Spark is not smart enough to understand that a month is missing; it probably doesn't even know what a month is.

If you want the "missing" month to be included in the average computation, you need to generate the missing data.

Just perform a full outer join with a dataframe ["month_id", "defaultValue"], where month_id holds the values 1 to 12 and defaultValue = 0.


Another solution: instead of computing an average, compute a cumulative sum of the values and divide it by the month number.

Steven