I am writing a PySpark implementation of an algorithm that is iterative in nature. Part of the algorithm involves iterating a strategy until no more improvements can be made (i.e., a local maximum has been greedily reached).
The function `optimizeAll` returns a three-column dataframe that looks as follows:
| id | current_value | best_value |
|----|---------------|------------|
| 0  | 1             | 1          |
| 1  | 0             | 1          |
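For reference, a toy DataFrame with this exact shape can be built as follows (this is just illustrative data, not my real input, which comes out of `optimizeAll`):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy data in the same three-column shape shown above
df = spark.createDataFrame(
    [(0, 1, 1), (1, 0, 1)],
    schema=['id', 'current_value', 'best_value'],
)
df.show()
```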
This function is used in a while loop until `current_value` and `best_value` are identical (meaning that no more optimizations can be made).
```python
# Init while loop
iterate = True

# Start iterating until optimization yields same result as before
while iterate:
    # Create (or overwrite) `df`
    df = optimizeAll(df2)  # Uses `df2` as input
    df.persist().count()

    # Check stopping condition
    iterate = df.where('current_value != best_value').count() > 0

    # Update `df2` with latest results
    if iterate:
        df2 = df2.join(other=df, on='id', how='left')  # <- Should I persist this?
```
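For completeness, this is the variant I have been considering, where the joined `df2` is also persisted and the previous iteration's DataFrames are explicitly released. I am not certain it is correct, which is part of what I am asking; everything else is the same as above:

```python
while iterate:
    df = optimizeAll(df2).persist()

    # count() is the action that actually materializes `df`
    iterate = df.where('current_value != best_value').count() > 0

    if iterate:
        new_df2 = df2.join(other=df, on='id', how='left').persist()
        new_df2.count()      # force the join to materialize while it is cached
        df2.unpersist()      # release the previous iteration's data
        df.unpersist()
        df2 = new_df2
```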
This function runs very quickly when I pass it the inputs manually. However, I have noticed that the time each pass takes grows exponentially as the loop iterates: the first iteration runs in milliseconds, the second one in seconds, and eventually each pass takes up to 10 minutes.
This question suggests that if `df` isn't cached, the while loop will start running from scratch on every iteration. Is this true?
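One way I can think of to check this is to print the plan and storage level between iterations (a small diagnostic sketch, assuming the loop above has just run):

```python
# If the lineage keeps growing, the printed plan gets longer each iteration
df.explain()

# Confirms whether `df` is actually held in memory/on disk after the count()
print(df.storageLevel)
```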
If so, which objects should I persist? I know that persisting `df` will be triggered by the `count` when defining `iterate`. However, `df2` has no action called on it, so even if I persist it, will the while loop still start from scratch every time? Likewise, should I unpersist either table at some point in the loop?