My Question: Why does Spark calculate sum and count for each partition, do an (IMHO) unnecessary shuffle (Exchange hashpartitioning), and only then calculate the mean in the final HashAggregate?
What could've been done instead: calculate the mean within each partition and then combine (union) the per-partition results.
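To illustrate what I mean, here is a rough, untested sketch (my_df is defined in the Details section below; the date values are only placeholders, in reality there is one Hive partition per day of 2014):

from functools import reduce
from pyspark.sql import DataFrame
import pyspark.sql.functions as F

# Placeholder partition values; the real table has one per day.
dates = ["2014-01-01", "2014-01-02"]

# Aggregate each date partition on its own ...
per_partition_means = [
    my_df
    .where(F.col("date_local") == d)
    .groupBy("date_local", "state_name")
    .agg(F.avg("arithmetic_mean").alias("mean_temp"))
    for d in dates
]

# ... and combine (union) the per-partition results.
daily_state_mean_by_partition = reduce(DataFrame.union, per_partition_means)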
Details:
I am reading data from the Hive table defined below, which is partitioned by date:
spark.sql("""Create External Table If Not Exists daily_temp.daily_temp_2014
(
state_name string,
...
) Partitioned By (
date_local string
)
Location "./daily_temp/"
Stored As ORC""")
It consists of daily temperature measurements for various sites across the US, downloaded from the EPA website.
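As a sanity check, the date partitions of the table can be listed from Spark SQL (each row of the output is one date_local=<value> partition):

spark.sql("show partitions daily_temp.daily_temp_2014").show(5, truncate=False)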
Using the code below, the data is loaded from the Hive table into a PySpark DataFrame:
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .master("local")
    .appName("Hive Partition Test")
    .enableHiveSupport()
    .config("hive.exec.dynamic.partition", "true")
    .config("hive.exec.dynamic.partition.mode", "nonstrict")
    .getOrCreate()
)
my_df = spark.sql("select * from daily_temp.daily_temp_2014")
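A quick look at the loaded DataFrame; the columns referenced below are date_local, state_name and arithmetic_mean:

my_df.printSchema()
my_df.select("date_local", "state_name", "arithmetic_mean").show(5)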
I would like to calculate the daily mean temperature per state.
daily_state_mean = (
    my_df
    .groupBy(
        my_df.date_local,
        my_df.state_name
    )
    .agg({"arithmetic_mean": "mean"})
)
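The physical plan can be inspected with explain():

daily_state_mean.explain()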
This is the relevant part of the physical (execution) plan:
+- *(2) HashAggregate(keys=[date_local#3003, state_name#2998], functions=[avg(cast(arithmetic_mean#2990 as double))], output=[date_local#3003, state_name#2998, avg(CAST(arithmetic_mean AS DOUBLE))#3014])
+- Exchange hashpartitioning(date_local#3003, state_name#2998, 365)
+- *(1) HashAggregate(keys=[date_local#3003, state_name#2998], functions=[partial_avg(cast(arithmetic_mean#2990 as double))], output=[date_local#3003, state_name#2998, sum#3021, count#3022L])
+- HiveTableScan [arithmetic_mean#2990, state_name#2998, date_local#3003], HiveTableRelation `daily_temp`.`daily_temp_2014`, org.apache.hadoop.hive.ql.io.orc.OrcSerde, [...], [date_local#3003]
Your advice and insights are highly appreciated.