
I understand that 'for' and 'while' loops are generally to be avoided when using Spark. My question is about optimizing a 'while' loop, though if I'm missing a solution that makes it unnecessary, I am all ears.

I'm not sure I can demonstrate the issue (very long processing times, compounding as the loop goes on) with toy data, but here is some pseudocode:

from pyspark.sql.functions import col

# I have a function - called 'enumerator' - which involves several joins and window functions.
# I run this function on my base dataset, df0, and return df1
df1 = enumerator(df0, param1 = apple, param2 = banana)

# Check for some condition in df1, then count number of rows in the result
counter = df1 \
.filter(col('X') == some_condition) \
.count()

# If there are rows meeting this condition, start a while loop
while counter > 0:
  print('Starting with counter: ', str(counter))
  
  # Run the enumerator function on df1 again
  df2 = enumerator(df1, param1 = apple, param2 = banana)
  
  # Check for the condition again, then continue the while loop if necessary
  counter = df2 \
  .filter(col('X') == some_condition) \
  .count()
  
  df1 = df2

# After the while loop finishes, I take the last resulting dataframe and do several more operations and analyses downstream
final_df = df2

An essential aspect of the enumerator function is to 'look back' on a sequence in a window, and so it may take several runs before all the necessary corrections are made.
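
For illustration only - the real enumerator is more involved and the column names below are made up - the 'look back' is of this general shape: a lag over a window followed by a conditional correction.

from pyspark.sql import Window
from pyspark.sql import functions as F

# Hypothetical sketch, not the actual function: 'group_id', 'seq' and
# 'corrected_value' are placeholder names. Each row looks back at the
# previous row in its sequence and corrects X; one pass can leave new rows
# needing correction, which is why the outer while loop exists.
w = Window.partitionBy('group_id').orderBy('seq')
df_corrected = (df
    .withColumn('prev_X', F.lag('X').over(w))
    .withColumn('X', F.when(F.col('prev_X') == some_condition, corrected_value)
                      .otherwise(F.col('X')))
    .drop('prev_X'))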

In my heart, I know this is ugly, but the windowing/ranking/sequential analysis within the function is critical. My understanding is that the underlying Spark query plan gets more and more convoluted as the loop continues. Are there any best practices I should adopt in this situation? Should I be caching at any point - either before the while loop starts, or within the loop itself?


1 Answer


You should definitely cache/persist the dataframes; otherwise, every iteration of the while loop will start from scratch and recompute everything from df0. You may also want to unpersist dataframes once they are no longer needed, to free up memory/disk space.

Another optimization is to skip the full count and use a cheaper operation such as df.take(1) instead: if it returns an empty list, then counter == 0.

from pyspark.sql.functions import col

df1 = enumerator(df0, param1 = apple, param2 = banana)
df1.cache()

# Check whether any row in df1 meets the condition; take(1) stops after the first match
counter = len(df1.filter(col('X') == some_condition).take(1))

while counter > 0:
  print('Starting with counter: ', str(counter))
  
  df2 = enumerator(df1, param1 = apple, param2 = banana)
  df2.cache()

  counter = len(df2.filter(col('X') == some_condition).take(1))
  df1.unpersist()    # unpersist df1 as it will be overwritten
  
  df1 = df2

final_df = df2
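
If the accumulated query plan itself becomes the bottleneck even with caching, one further option (a sketch, assuming Spark 2.3+ where DataFrame.localCheckpoint is available) is to truncate the lineage on each iteration:

# Sketch of an alternative loop body: localCheckpoint() eagerly materializes
# df2 and drops its plan history, so later iterations do not carry the whole
# accumulated plan. checkpoint() does the same but writes to the directory set
# via spark.sparkContext.setCheckpointDir(...), which survives executor loss.
df2 = enumerator(df1, param1 = apple, param2 = banana).localCheckpoint()
counter = len(df2.filter(col('X') == some_condition).take(1))
df1 = df2

The trade-off is the extra materialization per iteration, so this only pays off when plan analysis time, rather than recomputation, is what grows as the loop runs.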
  • thank you! Using '.take(1)' for these scenarios is something I hadn't considered before, but it seems so obvious. I thought maybe the Spark query planner would end up doing the same thing under the hood if I was using a condition like 'counter > 0' (stopping as soon as it hit 1) – mcharl02 Dec 15 '20 at 21:38