I have a Java program using Apache Spark. The most interesting part of the program looks like this:
long seed = System.nanoTime();

// annotate each partition once up front
JavaRDD<AnnotatedDocument> annotated = documents
    .mapPartitionsWithIndex(new InitialAnnotater(seed), true);
annotated.cache();

for (int iter = 0; iter < 2000; iter++) {
    GlobalCounts counts = annotated
        .mapPartitions(new GlobalCounter())
        .reduce((a, b) -> a.sum(b)); // update overall counts (*)

    seed = System.nanoTime();

    // CountChanger copies the overall counts and uses them for a stochastic update (**)
    annotated = annotated
        .mapPartitionsWithIndex(new CountChanger(counts, seed), true);
    annotated.cache();

    // adding these two lines gives the constant per-iteration time I want
    //List<AnnotatedDocument> ll = annotated.collect();
    //annotated = sc.parallelize(ll, 8);
}
So in effect, line (**) produces an RDD of the form
documents
.mapPartitionsWithIndex(initial)
.mapPartitionsWithIndex(nextIter)
.mapPartitionsWithIndex(nextIter)
.mapPartitionsWithIndex(nextIter)
... 2000 more
a very long chain of maps indeed. In addition, line (*) forces computation (it is an action, so nothing stays lazy) at each iteration, because the counts have to be updated before the next pass.
The problem is that the time per iteration grows linearly with the iteration number, so the total running time is quadratic.
My guess is that Spark keeps the lineage of every RDD in the chain for fault tolerance, and that this ever-growing lineage is what slows things down, but I'm really not sure.
What I'd really like is a way to tell Spark, at each iteration, to "collapse" the RDD so that only the most recent one is kept in memory and worked on. That should give constant time per iteration, I think. Is this possible? Are there any other solutions?
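For what it's worth, this is roughly what I have in mind, using checkpoint() to break the chain. The checkpoint directory and the every-50-iterations interval are made-up placeholders, and I don't know whether checkpoint() actually discards the earlier RDDs the way I'm hoping:

// rough sketch only; the directory and interval below are placeholders
sc.setCheckpointDir("/tmp/spark-checkpoints");

for (int iter = 0; iter < 2000; iter++) {
    GlobalCounts counts = annotated
        .mapPartitions(new GlobalCounter())
        .reduce((a, b) -> a.sum(b)); // (*) as before

    seed = System.nanoTime();
    annotated = annotated
        .mapPartitionsWithIndex(new CountChanger(counts, seed), true); // (**) as before
    annotated.cache();

    // every so often, write the RDD out so its lineage (hopefully) resets
    if (iter % 50 == 0) {
        annotated.checkpoint();
        annotated.count(); // force the checkpoint to be materialized
    }
}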
Thanks!