I have read a lot of articles, blog posts and Stack Overflow posts, but I still can't wrap my head around how Spark caches the datasets in my specific use case, which involves lots of transformations but only a few read and save actions. Here's my use case in pseudo-code:
val ds1 = spark.loadFromDatabase("table_1") // Action (1)
val ds2 = spark.loadFromDatabase("table_2") // Action (2)
val ds3 = spark.loadFromDatabase("table_3") // Action (3)
val intermediateDs1 = transform(ds3)
val intermediateDs2 = transform(ds1, intermediateDs1)
val intermediateDs3 = transform(ds2, intermediateDs1, intermediateDs2)
val intermediateResultDs1 = transform(intermediateDs2)
val intermediateResultDs2 = transform(intermediateDs3)
val finalResult1 = transform(intermediateResultDs1)
val finalResult2 = transform(intermediateResultDs2)
spark.writeToDatabase(finalResult1, "table_1") // Action (4)
spark.writeToDatabase(finalResult2, "table_2") // Action (5)
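Assume `loadFromDatabase` maps to a lazy reader such as `spark.read.jdbc`. In that case, no table data is read when the loads above are defined. Each write is the action that runs its whole lineage, and without caching that includes the database reads. The plain-Scala toy model below shows the effect. It does not use Spark: `LazyDs`, `LineageDemo` and the in-memory `db` map are hypothetical stand-ins. The pipeline is stripped down to `ds1` and the two writes, which is enough to show how the second action can end up reading a table_1 that the first action has already overwritten.

```scala
import scala.collection.mutable

// Toy model, not Spark's API: a "dataset" is only a recipe that is re-run on every action.
final class LazyDs[A](compute: () => Seq[A]) {
  def map[B](f: A => B): LazyDs[B] = new LazyDs(() => compute().map(f)) // lazy: only extends the recipe
  def collect(): Seq[A] = compute()                                      // action: runs the whole lineage
}

object LineageDemo {
  /** Returns (number of table reads, rows produced by the second action). */
  def run(): (Int, Seq[Int]) = {
    val db = mutable.Map("table_1" -> Seq(1, 2, 3)) // hypothetical in-memory "database"
    var loads = 0
    def load(table: String): LazyDs[Int] = new LazyDs(() => { loads += 1; db(table) })

    val ds1 = load("table_1")              // nothing is read yet
    val finalResult1 = ds1.map(_ * 10)
    val finalResult2 = ds1.map(_ + 1)
    db("table_1") = finalResult1.collect() // "Action (4)": reads table_1, then overwrites it
    val second = finalResult2.collect()    // "Action (5)": reads table_1 again, already overwritten
    (loads, second)
  }

  def main(args: Array[String]): Unit = println(run())
}
```

In this model the table is read twice, and the second action computes from the overwritten rows rather than the original ones. In the real pipeline, table_2 is in a similar position, because finalResult2 is derived from ds2 and then written back to table_2.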
I want to achieve two things:
- Prevent Spark from loading the data from the tables more than once, partly for performance reasons, but also because the write actions replace the table contents, so re-reading a table would lead to unexpected behavior while executing Action (5).
- Prevent Spark from executing some of the transformations multiple times, for performance reasons (e.g. intermediateDs2 and intermediateDs3 both depend on intermediateDs1).
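For the first goal, `cache()` on its own may not be a guarantee. Cached blocks are not guaranteed to survive (for example, when an executor is lost), and Spark then recomputes the missing partitions from the lineage, which means reading the source table again. One commonly used alternative is an eager `Dataset.checkpoint()`. It materializes the rows into the checkpoint directory immediately and cuts the lineage, so later actions read the checkpoint files rather than the source. The sketch below is built on a few assumptions:

- Spark 2.1+ with spark-sql on the classpath.
- A Parquet directory in a temp folder stands in for table_1, because a runnable example can't assume a database.
- `SnapshotSources` and `snapshot` are hypothetical names.

```scala
import java.nio.file.Files
import org.apache.spark.sql.{DataFrame, SparkSession}

object SnapshotSources {
  // Eager checkpoint: materializes df now into the checkpoint directory and returns a
  // DataFrame whose lineage starts at those files instead of the original scan.
  def snapshot(df: DataFrame): DataFrame = df.checkpoint(eager = true)

  /** Returns (rows in the snapshot, rows in the table) after the table was overwritten. */
  def run(spark: SparkSession): (Long, Long) = {
    import spark.implicits._
    val tmp = Files.createTempDirectory("snapshot-sketch")
    spark.sparkContext.setCheckpointDir(tmp.resolve("checkpoints").toString)

    // Hypothetical stand-in for table_1: a Parquet directory instead of a database table.
    val table1 = tmp.resolve("table_1").toString
    Seq(1, 2, 3).toDF("value").write.parquet(table1)

    val ds1 = snapshot(spark.read.parquet(table1)) // read once, materialized right here

    // Stand-in for Action (4): replace the contents of table_1.
    Seq(10, 20).toDF("value").write.mode("overwrite").parquet(table1)

    // ds1 still holds the original three rows; a fresh read sees the new contents.
    (ds1.count(), spark.read.parquet(table1).count())
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("snapshot-sketch").getOrCreate()
    try println(run(spark)) finally spark.stop()
  }
}
```

In the real job, the same kind of call could wrap each loaded dataset (ds1, ds2, ds3) before Action (4). On a cluster, the checkpoint directory should be on reliable storage such as HDFS. `localCheckpoint()` (Spark 2.3+) avoids the directory but keeps the data on the executors, so it is less robust.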
So I experimented with cache() and unpersist(), but I'm quite unsure about how to optimize the execution. My first thought was to cache the datasets that are used multiple times and to unpersist them once they are no longer needed, to free up memory.
val ds1 = spark.loadFromDatabase("table_1")
val ds2 = spark.loadFromDatabase("table_2")
val ds3 = spark.loadFromDatabase("table_3")
val intermediateDs1 = transform(ds3).cache()
val intermediateDs2 = transform(ds1, intermediateDs1).cache()
val intermediateDs3 = transform(ds2, intermediateDs1, intermediateDs2)
val intermediateResultDs1 = transform(intermediateDs2)
val intermediateResultDs2 = transform(intermediateDs3)
intermediateDs2.unpersist() // not needed anymore
intermediateDs1.unpersist() // not needed anymore
val finalResult1 = transform(intermediateResultDs1)
val finalResult2 = transform(intermediateResultDs2)
spark.writeToDatabase(finalResult1, "table_1")
spark.writeToDatabase(finalResult2, "table_2")
But I get the feeling that my assumptions regarding unpersist() are wrong; see Understanding Spark's caching.
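The ordering matters because neither `cache()` nor `unpersist()` is an action. `cache()` only registers the plan, and Spark stores the data the first time an action computes it. In the code above, both `unpersist()` calls run before either write. That means the cache entries are removed before anything was stored, and both writes recompute `intermediateDs1` and `intermediateDs2` from scratch. The sketch below uses `Dataset.storageLevel` to contrast releasing a dataset too early with releasing it only after the last action that reuses it. It assumes spark-sql on the classpath; `CacheOrdering`, `Report` and the small integer datasets are hypothetical stand-ins.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

object CacheOrdering {
  final case class Report(
      levelIfReleasedEarly: StorageLevel, // unpersist() before any action ran
      levelWhileWriting: StorageLevel,    // cache() kept until both actions have run
      levelAfterRelease: StorageLevel,
      result1: Seq[Int],
      result2: Seq[Int])

  def run(spark: SparkSession): Report = {
    import spark.implicits._
    val source = Seq(1, 2, 3).toDS()

    // Released too early: the entry is gone before any job ran, so later
    // actions on this plan would recompute it from the source.
    val releasedEarly = source.map(_ * 3).cache()
    releasedEarly.unpersist()
    val levelIfReleasedEarly = releasedEarly.storageLevel

    // Kept until the last consumer has run: stand-in for intermediateDs1.
    val intermediateDs1 = source.map(_ * 2).cache()          // registers the plan; nothing computed yet
    val levelWhileWriting = intermediateDs1.storageLevel
    val result1 = intermediateDs1.map(_ + 1).collect().toSeq // first action computes and stores it
    val result2 = intermediateDs1.map(_ - 1).collect().toSeq // second action reads the cached data
    intermediateDs1.unpersist()                              // release only after both actions
    Report(levelIfReleasedEarly, levelWhileWriting, intermediateDs1.storageLevel, result1, result2)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("cache-ordering").getOrCreate()
    try println(run(spark)) finally spark.stop()
  }
}
```

Applied to the pseudo-code, this suggests moving the `unpersist()` calls below the two `writeToDatabase` lines, since both writes still depend on `intermediateDs1` and `intermediateDs2`.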
Which datasets should be cached AND unpersisted, and in which order, to achieve these goals in this specific scenario?
Thanks!