
I have a Java program using Apache Spark. The most interesting part of the program looks like this:

long seed = System.nanoTime();

JavaRDD<AnnotatedDocument> annotated = documents
    .mapPartitionsWithIndex(new InitialAnnotater(seed), true);
annotated.cache();

for (int iter = 0; iter < 2000; iter++) {
    GlobalCounts counts = annotated
        .mapPartitions(new GlobalCounter())
        .reduce((a, b) -> a.sum(b)); // update overall counts (*)

    seed = System.nanoTime();

    // pass the overall counts to CountChanger, which uses them to compute something stochastic (**)
    annotated = annotated
        .mapPartitionsWithIndex(new CountChanger(counts, seed), true);
    annotated.cache();

    // Adding these lines gives the constant time per iteration I want (sc is the JavaSparkContext):
    //List<AnnotatedDocument> ll = annotated.collect();
    //annotated = sc.parallelize(ll, 8);
}

So, in effect, line (**) produces an RDD with the following lineage:

documents
    .mapPartitionsWithIndex(initial)
    .mapPartitionsWithIndex(nextIter)
    .mapPartitionsWithIndex(nextIter)
    .mapPartitionsWithIndex(nextIter)
    ... (2000 nextIter steps in total)

That is a very long chain of maps indeed. In addition, line (*) forces evaluation at every iteration (reduce is an action, not a lazy transformation), because the counts need to be updated.

The problem I have is that the time per iteration increases linearly with the iteration number, so the total running time is quadratic.
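The step from "linear per iteration" to "quadratic overall" is just an arithmetic series. As a rough model (an assumption about the shape of the timing curve, not measured data), suppose iteration $i$ costs $a + b\,i$ seconds for some constants $a$ and $b$:

$$
T(n) = \sum_{i=1}^{n} (a + b\,i) = a\,n + b\,\frac{n(n+1)}{2} = O(n^2)
$$

For $n = 2000$, the second term is $b \cdot 2{,}001{,}000$, so even a tiny per-step overhead $b$ dominates. Making the per-iteration cost constant ($b = 0$) leaves $T(n) = a\,n$.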

I suspect this is because Spark "remembers" every RDD in the chain, and its fault-tolerance bookkeeping (or something similar) grows along with the chain. However, I really have no idea.

What I'd really like to do is tell Spark, at each iteration, to "collapse" the RDD so that only the latest one is kept in memory and worked on. I think this should result in constant time per iteration. Is this possible? Are there any other solutions?

Thanks!

bombax
  • Is there any reason you're caching the RDD for every iteration? Instead of caching the last accumulated RDD at the end of the loop? – Yuval Itzchakov Mar 21 '16 at 10:02
  • I'm still experimenting with the effects of caching so my answer would have to be "not really." – bombax Mar 21 '16 at 10:04
  • Are you actually reusing the RDD each computation? Or is it a fresh RDD each time you want to calculate the counters? – Yuval Itzchakov Mar 21 '16 at 10:06
  • It's more like a single initial RDD that goes through 2000 changes. I think Spark is trying to remember the whole chain at each iteration, whereas I'd like it to treat the RDD as "fresh" in every iteration. I edited my code to clarify. – bombax Mar 21 '16 at 10:08
  • Try not caching the RDD, as you're effectively iterating over a different RDD each time. Also, I'd suggest you look at the Spark UI to see what is taking so long. Perhaps your job is causing GC pressure, but the UI should definitely give you more insight into what's going on. – Yuval Itzchakov Mar 21 '16 at 10:22
  • Thanks for the suggestion Yuval. Unfortunately I get the same quadratic effect with caching removed completely. I have yet to find a method that accomplishes the same as the inactive lines above. – bombax Mar 21 '16 at 10:27

3 Answers


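Try using rdd.checkpoint(). This saves the RDD to reliable storage (typically HDFS) and truncates its lineage. Note that you must first set a checkpoint directory with sc.setCheckpointDir(...), and the checkpoint is only written when an action runs on the RDD.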

Each time you transform an RDD, you grow its lineage, and Spark has to track what is available and what has to be recomputed. Processing the DAG is expensive, and large DAGs tend to kill performance quite quickly. By checkpointing, you instruct Spark to compute and save the resulting RDD and to discard the information about how it was created. This is similar to simply saving an RDD and reading it back, which keeps the DAG small.
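To make the lineage argument concrete, here is a deliberately simplified, Spark-free toy model. `ToyRdd` and `LineageDemo` are hypothetical classes, not Spark's API: each lazy `map` only records its parent, so the chain that has to be walked grows by one step per iteration, while `checkpoint()` computes the data once and drops the parent pointer. Unlike Spark's real `checkpoint()`, which returns nothing and marks the RDD in place, this toy returns a new node to keep the code short.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;

/** Hypothetical stand-in for an RDD: either materialized data or a (parent, function) pair. */
final class ToyRdd {
    private final List<Integer> data;        // non-null only for materialized (checkpointed) nodes
    private final ToyRdd parent;             // lineage pointer; null once the lineage is truncated
    private final UnaryOperator<Integer> fn; // transformation applied to each parent element

    private ToyRdd(List<Integer> data, ToyRdd parent, UnaryOperator<Integer> fn) {
        this.data = data;
        this.parent = parent;
        this.fn = fn;
    }

    static ToyRdd of(List<Integer> data) {
        return new ToyRdd(new ArrayList<>(data), null, null);
    }

    /** Lazy transformation: only records lineage, computes nothing yet. */
    ToyRdd map(UnaryOperator<Integer> f) {
        return new ToyRdd(null, this, f);
    }

    /** Number of lineage steps between this node and the nearest materialized ancestor. */
    int lineageDepth() {
        return parent == null ? 0 : 1 + parent.lineageDepth();
    }

    /** Produces the data by walking the whole lineage back to a materialized node. */
    List<Integer> compute() {
        if (data != null) return data;
        List<Integer> out = new ArrayList<>();
        for (Integer x : parent.compute()) out.add(fn.apply(x));
        return out;
    }

    /** Toy "checkpoint": materialize once and forget how the data was derived. */
    ToyRdd checkpoint() {
        return new ToyRdd(compute(), null, null);
    }
}

final class LineageDemo {
    /** Applies `iterations` lazy maps, checkpointing every `interval` steps (0 = never). */
    static ToyRdd run(int iterations, int interval) {
        ToyRdd rdd = ToyRdd.of(Arrays.asList(1, 2, 3));
        for (int i = 1; i <= iterations; i++) {
            rdd = rdd.map(x -> x + 1);
            if (interval > 0 && i % interval == 0) rdd = rdd.checkpoint();
        }
        return rdd;
    }
}
```

For example, `LineageDemo.run(2000, 0).lineageDepth()` is 2000, while `LineageDemo.run(2000, 20).lineageDepth()` is 0; both compute the same data, but only the first has to walk a 2000-step chain every time it is used.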

On a side note, since you hit this issue, it is good to know that union also impacts RDD performance by adding steps, and it can also throw a StackOverflowError because of the way lineage information is handled. See this post.

This link has more details, with nice diagrams, and the subject is also covered in this SO post.

Ioannis Deligiannis

That's a really interesting question, and there are a few things to consider.

Fundamentally, this is an iterative algorithm. If you look at some of the iterative machine learning algorithms in Spark, you can see some approaches to working with this kind of problem.

The first thing is that most of them don't cache on every iteration; instead, they have a configurable caching interval. I'd probably start by caching every 10 iterations and seeing how that goes.

The other issue is the lineage graph: each mapPartitions you do grows the graph a little more, and at some point keeping track of it becomes more and more expensive. checkpoint allows Spark to write the current RDD to persistent storage and discard the lineage information. You could try doing this at some interval, such as every 20 iterations, and see how that goes.

The values 10 and 20 are just basic starting points. The right intervals depend on how slow it is to compute the data for each individual iteration, so you can play with them to find the right tuning for your job.
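A minimal sketch of the interval idea in Spark's Java API, assuming Spark 2.x or later with Java 8 lambdas. The class `IntervalCheckpointing`, its constants, and the `map(x -> x + 1)` step are hypothetical placeholders (the real loop would use the question's `mapPartitionsWithIndex(new CountChanger(...))`); `persist`, `checkpoint`, `count`, `setCheckpointDir` and `isCheckpointed` are real Spark methods. The RDD is persisted before it is checkpointed so that writing the checkpoint reuses the cached data instead of recomputing it.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;

final class IntervalCheckpointing {
    // Hypothetical tuning knobs; 10 and 20 are only starting points.
    static final int CACHE_INTERVAL = 10;
    static final int CHECKPOINT_INTERVAL = 20;

    /** Local context for experimenting; on a cluster, use an HDFS checkpoint directory instead. */
    static JavaSparkContext localContext() {
        SparkConf conf = new SparkConf().setAppName("interval-checkpointing").setMaster("local[2]");
        JavaSparkContext sc = new JavaSparkContext(conf);
        try {
            sc.setCheckpointDir(Files.createTempDirectory("spark-checkpoints").toString());
        } catch (IOException e) {
            sc.stop();
            throw new UncheckedIOException(e);
        }
        return sc;
    }

    static JavaRDD<Integer> iterate(JavaRDD<Integer> start, int iterations) {
        JavaRDD<Integer> current = start;
        for (int iter = 1; iter <= iterations; iter++) {
            // Hypothetical stand-in for the real per-iteration transformation.
            current = current.map(x -> x + 1);

            boolean checkpointNow = iter % CHECKPOINT_INTERVAL == 0;
            if (iter % CACHE_INTERVAL == 0 || checkpointNow) {
                current.persist(StorageLevel.MEMORY_AND_DISK());
                // checkpoint() only marks the RDD; it must be called before any action runs on it.
                if (checkpointNow) current.checkpoint();
                // count() is an action: it fills the cache and, if marked, writes the checkpoint,
                // after which this RDD no longer depends on its parents.
                current.count();
            }
        }
        return current;
    }
}
```

Iterations that fall between two cached points are recomputed from the most recent cached RDD when the next action runs, which is why the caching interval trades memory against recomputation.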

Holden
  • Try materializing your cached RDD with annotated.count() every few iterations (the interval needs tuning).
  • It's better to control where the RDD is stored with persist(...) instead of cache(). cache() keeps the RDD in memory only, whereas persist(...) lets you choose the storage level (for example, memory and disk), depending on how much memory you have available.
  • It's better to keep a reference to the cached/persisted RDD and unpersist it after the next iteration's RDD has been cached/persisted. Spark does this by itself, but if you control it, Spark won't need to choose which RDD to evict from the cache.
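The three suggestions above combine into a "persist the new RDD, materialize it, then unpersist the old one" rotation. A minimal sketch, assuming Spark 2.x or later; `PersistRotation`, `materializeEvery` and the `map(x -> x + 1)` step are hypothetical, while `persist`, `count`, `unpersist` and `StorageLevel.MEMORY_AND_DISK()` are real Spark API. The order matters: the old RDD is released only after the new one is in the cache, otherwise the new one would be recomputed from further back in the lineage. Note that this manages cache memory but does not truncate the lineage, so it complements checkpointing rather than replacing it.

```java
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;

final class PersistRotation {
    static JavaSparkContext localContext() {
        return new JavaSparkContext(new SparkConf().setAppName("persist-rotation").setMaster("local[2]"));
    }

    /** Runs `iterations` steps, keeping at most one materialized RDD persisted at a time. */
    static JavaRDD<Integer> iterate(JavaRDD<Integer> start, int iterations, int materializeEvery) {
        JavaRDD<Integer> current = start;
        JavaRDD<Integer> previous = null; // last RDD that was persisted and materialized
        for (int iter = 1; iter <= iterations; iter++) {
            current = current.map(x -> x + 1); // hypothetical placeholder for the real work
            if (iter % materializeEvery == 0) {
                current.persist(StorageLevel.MEMORY_AND_DISK()); // choose the storage level explicitly
                current.count();                                 // action: actually fills the cache
                if (previous != null) {
                    previous.unpersist();                        // safe now: `current` is already cached
                }
                previous = current;
            }
        }
        return current;
    }
}
```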
Igor Berman