I'm not sure what your input data looks like, but let's say we have a DataFrame like this:
+---------+-----+-----+
|condition|model|price|
+---------+-----+-----+
|A        |A    |1    |
|A        |B    |2    |
|A        |B    |2    |
|A        |A    |1    |
|A        |A    |null |
|B        |A    |3    |
|B        |A    |null |
|B        |B    |4    |
+---------+-----+-----+
We want to fill each null with the average price, computed per condition and model group. For this we can define a Window partitioned by those two columns, calculate avg over it, and then replace the null values.
Example:
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
spark = SparkSession.builder.appName("test").getOrCreate()
data = [
    {"condition": "A", "model": "A", "price": 1},
    {"condition": "A", "model": "B", "price": 2},
    {"condition": "A", "model": "B", "price": 2},
    {"condition": "A", "model": "A", "price": 1},
    {"condition": "A", "model": "A", "price": None},
    {"condition": "B", "model": "A", "price": 3},
    {"condition": "B", "model": "A", "price": None},
    {"condition": "B", "model": "B", "price": 4},
]
# Partition by both grouping columns; no orderBy, so avg is
# computed over the whole partition, not a running frame.
window = Window.partitionBy("condition", "model")
df = spark.createDataFrame(data=data)
df = (
    df.withColumn("avg", F.avg("price").over(window))  # avg ignores nulls
    .withColumn(
        "price",
        F.when(F.col("price").isNull(), F.col("avg")).otherwise(F.col("price")),
    )
    .drop("avg")
)
df.show()
Which gives us:
+---------+-----+-----+
|condition|model|price|
+---------+-----+-----+
|A        |A    |1.0  |
|A        |A    |1.0  |
|A        |A    |1.0  |
|B        |B    |4.0  |
|B        |A    |3.0  |
|B        |A    |3.0  |
|A        |B    |2.0  |
|A        |B    |2.0  |
+---------+-----+-----+
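As an aside, the same fill can be written more compactly with F.coalesce, which returns the first non-null value among its arguments. A minimal sketch, reusing the window and imports from above on the original df:

# coalesce keeps price where it is non-null and otherwise
# falls back to the per-group average.
df = df.withColumn(
    "price", F.coalesce(F.col("price"), F.avg("price").over(window))
)

One caveat for either version: since F.avg skips nulls, a (condition, model) group where every price is null has a null average, so those rows stay null.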