Initial Theory
You are hitting a known limitation of Spark, similar to the findings discussed over here.
However, there are ways to work around this by rethinking your implementation as a series of dispatched instructions that describe the batches of data you want to operate on, similar to how you create your tmp DataFrame.
Unfortunately, this may take quite a bit more work, because you'll need to express your manipulations purely as a series of column operations handed to PySpark rather than as row-by-row manipulations. Some operations cannot be expressed purely through PySpark calls, so this isn't always possible, and it's worth thinking through carefully.
Concretely
As an example, your date range calculations can be done purely in PySpark, and doing so will be substantially faster if you run these operations over many years or at otherwise larger scale. Instead of using Python list comprehensions or other driver-side logic, we use column manipulations on a small set of initial data to build up our ranges.
I've written some example code below that creates your date batches. It should let you perform a join to create your tmp DataFrame, after which you can describe the operations you want to apply to it.
Code to create date ranges (start and end dates of each week of the year):
from pyspark.sql import types as T, functions as F, SparkSession, Window
spark = SparkSession.builder.getOrCreate()
year_marker_schema = T.StructType([
T.StructField("max_year", T.IntegerType(), False),
])
year_marker_data = [
{"max_year": 2022}
]
year_marker_df = spark.createDataFrame(year_marker_data, year_marker_schema)
year_marker_df.show()
"""
+--------+
|max_year|
+--------+
| 2022|
+--------+
"""
previous_week_window = Window.partitionBy(F.col("start_year")).orderBy("start_week_index")
year_marker_df = year_marker_df.select(
(F.col("max_year") - 1).alias("start_year"),
"*"
).select(
F.to_date(F.col("max_year").cast(T.StringType()), "yyyy").alias("max_year_date"),
F.to_date(F.col("start_year").cast(T.StringType()), "yyyy").alias("start_year_date"),
"*"
).select(
F.datediff(F.col("max_year_date"), F.col("start_year_date")).alias("days_between"),
"*"
).select(
F.floor(F.col("days_between") / 7).alias("weeks_between"),
"*"
).select(
F.sequence(F.lit(0), F.col("weeks_between")).alias("week_indices"),
"*"
).select(
F.explode(F.col("week_indices")).alias("start_week_index"),
"*"
).select(
F.lead(F.col("start_week_index"), 1).over(previous_week_window).alias("end_week_index"),
"*"
).select(
((F.col("start_week_index") * 7) + 1).alias("start_day"),
((F.col("end_week_index") * 7) + 1).alias("end_day"),
"*"
).select(
F.concat_ws(
"-",
F.col("start_year"),
F.col("start_day").cast(T.StringType())
).alias("start_day_string"),
F.concat_ws(
"-",
F.col("start_year"),
F.col("end_day").cast(T.StringType())
).alias("end_day_string"),
"*"
).select(
F.to_date(
F.col("start_day_string"),
"yyyy-D"
).alias("start_date"),
F.to_date(
F.col("end_day_string"),
"yyyy-D"
).alias("end_date"),
"*"
)
year_marker_df.drop(
"max_year",
"start_year",
"weeks_between",
"days_between",
"week_indices",
"max_year_date",
"start_day_string",
"end_day_string",
"start_day",
"end_day",
"start_week_index",
"end_week_index",
"start_year_date"
).show()
"""
+----------+----------+
|start_date| end_date|
+----------+----------+
|2021-01-01|2021-01-08|
|2021-01-08|2021-01-15|
|2021-01-15|2021-01-22|
|2021-01-22|2021-01-29|
|2021-01-29|2021-02-05|
|2021-02-05|2021-02-12|
|2021-02-12|2021-02-19|
|2021-02-19|2021-02-26|
|2021-02-26|2021-03-05|
|2021-03-05|2021-03-12|
|2021-03-12|2021-03-19|
|2021-03-19|2021-03-26|
|2021-03-26|2021-04-02|
|2021-04-02|2021-04-09|
|2021-04-09|2021-04-16|
|2021-04-16|2021-04-23|
|2021-04-23|2021-04-30|
|2021-04-30|2021-05-07|
|2021-05-07|2021-05-14|
|2021-05-14|2021-05-21|
+----------+----------+
only showing top 20 rows
"""
Potential Optimizations
If, once you have this code, you still can't express your work through joins and column derivations alone and are forced to combine results with union_many, you may consider using Spark's localCheckpoint on your df2 result. This lets Spark compute the resulting DataFrame once and not add its query plan onto the result you push into your df_total. This could be paired with cache to also keep the resulting DataFrame in memory, but that will depend on your data scale.
localCheckpoint and cache are useful for avoiding re-computing the same DataFrames many times over and for truncating the amount of query planning done on top of your intermediate DataFrames.
You'll likely find localCheckpoint and cache useful on your df DataFrame as well, since it is used many times over in your loop (assuming you can't rework your logic into SQL-based operations and are still forced to use the loop).
As a quick and dirty summary of when to use each:
Use localCheckpoint on a DataFrame that was complex to compute and is going to be used in later operations. Often these are the nodes feeding into unions.
Use cache on a DataFrame that is going to be used many times later. Often this is a DataFrame defined outside a for/while loop and referenced inside it.
All Together
Your initial code
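from datetime import date, timedelta
from pyspark.sql import functions as F

# Weekly dates from 2021-01-01 through 2021-12-31 (~50 elements)
date_list = [date(2021, 1, 1) + timedelta(weeks=i) for i in range(53)]
df_total = spark.createDataFrame([], schema)
df_date = []
for d in date_list:
    tmp = df.filter(F.col("my_date_column").between(d - timedelta(days=7), d)).withColumn("example", F.lit(d))
    # ... (rest of your per-week logic)
    df2 = df.join(tmp, "column", "inner")  # ... (further transformations)
    df_date.append(df2)
df_total = df_total.unionByName(union_many(*df_date))  # union_many is your own helper
return df_total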
Should now look like:
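# year_marker_df as derived in my code above; keep only the range columns
year_marker_df = year_marker_df.select("start_date", "end_date").cache()
# Half-open range, so a date on a week boundary matches only one week
# (Column.between is inclusive on both ends)
df = df.join(
    year_marker_df,
    (df.my_date_column >= year_marker_df.start_date) & (df.my_date_column < year_marker_df.end_date),
)
# Other work previously in your for loop, resulting in df_total
return df_total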
Or, if you are unable to re-work your inner loop operations, you can do some optimizations like:
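from datetime import date, timedelta
from pyspark.sql import functions as F

# Weekly dates from 2021-01-01 through 2021-12-31 (~50 elements)
date_list = [date(2021, 1, 1) + timedelta(weeks=i) for i in range(53)]
df_total = spark.createDataFrame([], schema)
df_date = []
df = df.cache()  # df is reused in every iteration, so keep it in memory
for d in date_list:
    tmp = df.filter(F.col("my_date_column").between(d - timedelta(days=7), d)).withColumn("example", F.lit(d))
    # ... (rest of your per-week logic)
    df2 = df.join(tmp, "column", "inner")  # ... (further transformations)
    df2 = df2.localCheckpoint()  # materialize now and truncate the plan before the union
    df_date.append(df2)
df_total = df_total.unionByName(union_many(*df_date))  # union_many is your own helper
return df_total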