Recently I encountered an issue while running one of our PySpark jobs. While analyzing its stages in the Spark UI, I noticed that the longest-running stage took 1.2 hours of the 2.5 hours the entire job takes to run.
Once I looked at the stage details, it was clear that I was facing severe data skew: a single task ran for the entire 1.2 hours, while all the other tasks finished within 23 seconds.
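Why does a single task matter so much? The tasks in a stage run in parallel across executor cores, so the stage cannot finish before its slowest task does. The toy scheduler below is plain Python, not Spark's real scheduler. It assumes a hypothetical stage of 200 tasks: one lasting 1.2 hours and the rest 23 seconds each. Under that model, adding slots does not shorten the stage at all, while the same total work spread evenly over 50 slots would take about three minutes.

```python
import heapq


def stage_wall_time(task_durations, slots):
    """Greedy list scheduling: each task goes to the slot that frees up first.

    A simplified model of tasks sharing executor cores; it ignores scheduling
    overhead, data locality and retries.
    """
    free_at = [0.0] * slots  # time (seconds) at which each slot becomes free
    heapq.heapify(free_at)
    for duration in task_durations:
        start = heapq.heappop(free_at)
        heapq.heappush(free_at, start + duration)
    return max(free_at)


# Hypothetical stage: one skewed task of 1.2 h (72 min), 199 tasks of 23 s.
skewed_task = 72 * 60.0
tasks = [skewed_task] + [23.0] * 199

for slots in (50, 200):
    print(f"{slots} slots -> stage takes {stage_wall_time(tasks, slots):.0f} s")

# The same total work, perfectly balanced over 50 slots.
print(f"balanced over 50 slots -> {sum(tasks) / 50:.2f} s")
```

More cores only help the 23-second tasks. The fix has to break up the oversized partition itself.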
The DAG showed that this stage involves window functions. That helped me quickly narrow the problem down to a few queries and find the root cause: the column account, which was used in Window.partitionBy("account"), had 25% null values.
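Nulls in the partitioning column cause this for a specific reason. A window partitioned by account needs every row with the same account in the same task, so Spark shuffles the data by a hash of account, and every null key produces the same hash. The sketch below is a simplified model of that routing. It uses crc32 in place of the Murmur3 hash Spark actually uses, and the data is made up: 25% nulls, with the remaining rows spread over 7,500 distinct accounts and 200 shuffle partitions (Spark's default spark.sql.shuffle.partitions).

```python
import zlib
from collections import Counter


def partition_for(key, num_partitions):
    # Simplified stand-in for Spark's hash partitioning (Spark uses Murmur3,
    # not crc32). The property that matters: every null key gets the same
    # hash, so all null rows are routed to one and the same partition.
    h = 0 if key is None else zlib.crc32(key.encode("utf-8"))
    return h % num_partitions


# Hypothetical data: 100,000 rows; every 4th row (25%) has a null account.
rows = [None if i % 4 == 0 else f"acc-{i % 10_000}" for i in range(100_000)]

num_partitions = 200
sizes = Counter(partition_for(key, num_partitions) for key in rows)

largest = max(sizes.values())
print(f"largest partition: {largest / len(rows):.1%} of all rows")
print(f"an even split would be {1 / num_partitions:.1%} of rows per partition")
```

One partition, and therefore one task, receives at least a quarter of the data, which matches the single straggler seen in the Spark UI.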
I'm not interested in the sum for the null accounts, but I do need those rows for further calculations, so I can't filter them out before the window function.
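This requirement can be pinned down with a small plain-Python reference implementation. Such a reference is handy as a test oracle for any optimized Spark version that replaces the window. The sketch assumes that a null sum is an acceptable value for null-account rows, since those sums are not needed. The function name and the sample rows are hypothetical.

```python
from collections import defaultdict


def sum_per_account_keep_nulls(rows):
    """Reference semantics (plain Python, not Spark): keep every input row;
    rows with an account get that account's total price, rows with a null
    account get None instead of a sum."""
    totals = defaultdict(float)
    for row in rows:
        if row["account"] is not None:
            totals[row["account"]] += row["price"]
    return [
        {
            **row,
            "sum_sales_per_account": (
                totals[row["account"]] if row["account"] is not None else None
            ),
        }
        for row in rows
    ]


# Hypothetical sample rows.
sales = [
    {"account": "a", "price": 10.0},
    {"account": "a", "price": 5.0},
    {"account": None, "price": 7.0},
    {"account": "b", "price": 1.0},
    {"account": None, "price": 3.0},
]
result = sum_per_account_keep_nulls(sales)
for row in result:
    print(row)
```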
Here is my window function query:
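from pyspark.sql import Window
from pyspark.sql import functions as F
problematic_account_window = Window.partitionBy("account")  # 25% of account values are null
sales_with_account_total_df = sales_df.withColumn("sum_sales_per_account", F.sum(F.col("price")).over(problematic_account_window))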
So we've found the culprit. What can we do now? How can we resolve the skew and the performance issue?