20

I need to split an RDD into two parts: one that satisfies a condition and one that does not. I could call filter twice on the original RDD, but that seems inefficient. Is there a way to do what I'm after? I can't find anything in the API or in the literature.

monster

5 Answers

21

Spark doesn't support this by default. Filtering on the same data twice isn't that bad if you cache it beforehand, and the filtering itself is quick.

If you really have just two kinds of data, you can use a helper method:

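import org.apache.spark.rdd.RDD

implicit class RDDOps[T](rdd: RDD[T]) {
  // Returns (elements for which f holds, elements for which it does not).
  def partitionBy(f: T => Boolean): (RDD[T], RDD[T]) = {
    val passes = rdd.filter(f)
    val fails = rdd.filter(e => !f(e)) // Spark doesn't have filterNot
    (passes, fails)
  }
}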

val (matches, matchesNot) = sc.parallelize(1 to 100).cache().partitionBy(_ % 2 == 0)

But as soon as you have more than two kinds of data, just assign each filtered RDD to its own val.
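As a minimal sketch of that last point: cache the source once, then give each category its own filter and its own val. The local SparkContext, the app name, and the three size categories are assumptions made for illustration; in spark-shell, `sc` already exists and the first statement should be skipped.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Assumption: a throwaway local context for this sketch only.
val sc = new SparkContext(
  new SparkConf().setMaster("local[2]").setAppName("multi-filter-sketch"))

// Cache once so every filter below reads the cached data
// instead of recomputing the source.
val numbers = sc.parallelize(1 to 100).cache()

// One filter per category, each assigned to its own val.
val small  = numbers.filter(_ <= 10)
val medium = numbers.filter(n => n > 10 && n <= 50)
val large  = numbers.filter(_ > 50)

// Actions trigger the work; the first one also fills the cache.
val sizes = (small.count(), medium.count(), large.count())

numbers.unpersist()
sc.stop()
```

Each filter is still a separate pass over the cached data, which is the trade-off this answer accepts.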

Alberto Bonsanto
Marius Soutier
5

The Spark RDD API does not have such a method.

Here is a version based on a pull request for rdd.span that should work:

import scala.reflect.ClassTag
import org.apache.spark.rdd._

def split[T:ClassTag](rdd: RDD[T], p: T => Boolean): (RDD[T], RDD[T]) = {

    val splits = rdd.mapPartitions { iter =>
        val (left, right) = iter.partition(p)
        val iterSeq = Seq(left, right)
        iterSeq.iterator
    }

    val left = splits.mapPartitions { iter => iter.next().toIterator}

    val right = splits.mapPartitions { iter => 
        iter.next()
        iter.next().toIterator
    }
    (left, right)
}

val rdd = sc.parallelize(0 to 10, 2)

val (first, second) = split[Int](rdd, _ % 2 == 0 )

first.collect
// Array[Int] = Array(0, 2, 4, 6, 8, 10)
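The per-partition step above relies on `Iterator.partition` from the Scala standard library. The following sketch shows that call on its own, in plain Scala without Spark; the variable names are chosen for illustration only.

```scala
// Plain Scala, no Spark needed: the same call that runs inside mapPartitions.
val source = (0 to 10).iterator

// partition returns two lazy iterators backed by the same source:
// one for elements that satisfy the predicate, one for the rest.
val (left, right) = source.partition(_ % 2 == 0)

val evens = left.toList
val odds  = right.toList
```

Because both iterators draw from one underlying source, the split itself is lazy; nothing is read until one of them is consumed.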
Shyamendra Solanki
  • I would wager this is more complex and less efficient than two filters – Justin Pihony Apr 10 '15 at 02:00
  • @JustinPihony yes, filters are much more efficient. – Shyamendra Solanki Apr 10 '15 at 06:00
  • This approach will cause the rdd to be evaluated twice (unless it was cached beforehand). Since this gives no benefit over the standard "two filters", I think it's not worth using it. Even the built-in `randomSplit(...)` results in multiple evaluations of the given rdd. There doesn't seem to be a way (at least none that I found yet) to create a 1-pass split method that returns two RDDs. – borice Aug 07 '16 at 23:38
4

The point is that you want a map, not a filter.

(T) -> (Boolean, T)

Sorry, I am not fluent in Scala syntax. But the idea is that you split your answer set by mapping it to key/value pairs. The key can be a Boolean indicating whether or not the element passed the 'filter' predicate.

You can control output to different targets by doing partition wise processing. Just make sure that you don’t restrict parallel processing to just two partitions downstream.

See also How do I split an RDD into two or more RDDs?
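The `(T) -> (Boolean, T)` mapping described above could be sketched as follows. This assumes a throwaway local SparkContext and uses `keyBy` and `countByKey` as one possible way to work with the tagged pairs; the answer itself does not prescribe these calls.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Assumption: a throwaway local context; in spark-shell, `sc` already exists.
val sc = new SparkContext(
  new SparkConf().setMaster("local[2]").setAppName("keyed-split-sketch"))

val rdd = sc.parallelize(0 to 10, 2)

// (T) -> (Boolean, T): tag every element with the predicate's result.
val keyed = rdd.keyBy(_ % 2 == 0)

// One job can now handle both groups, for example by counting per key.
val counts = keyed.countByKey()

sc.stop()
```

The data stays in a single RDD here; the Boolean key only tells downstream steps which group an element belongs to.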

YoYo
1

If you are ok with a T instead of an RDD[T], then you can do this. Otherwise, you could maybe do something like this:

val data = sc.parallelize(1 to 100)
// Each partition emits a single (evens, odds) pair of lists,
// so the result is an RDD[(List[Int], List[Int])].
val splitData = data.mapPartitions { iter =>
  Iterator(iter.toList.partition(_ % 2 == 0))
}

You will then probably need to reduce this down to merge the per-partition lists when you perform an action.
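A rough sketch of that merge step, assuming a throwaway local SparkContext and four partitions (both chosen for illustration). Note that the result is a pair of local lists on the driver, not a pair of RDDs.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Assumption: a throwaway local context; in spark-shell, `sc` already exists.
val sc = new SparkContext(
  new SparkConf().setMaster("local[2]").setAppName("merge-lists-sketch"))

val data = sc.parallelize(1 to 100, 4)

// One (evens, odds) pair of lists per partition.
val perPartition = data.mapPartitions { iter =>
  Iterator(iter.toList.partition(_ % 2 == 0))
}

// reduce is the action: it concatenates the per-partition pairs on the driver.
// The order in which partitions are merged is not guaranteed.
val (evens, odds) = perPartition.reduce {
  case ((e1, o1), (e2, o2)) => (e1 ++ e2, o1 ++ o2)
}

sc.stop()
```

This only suits data small enough to fit in driver memory, since both lists are fully materialized there.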

Community
Justin Pihony
  • I would love to find out why this was down voted as it is the only answer that actually answers the OPs question – Justin Pihony Apr 10 '15 at 02:02
  • (note: I did not down vote) your method is interesting but it doesn't answer the question. The OP asked for a `partition(RDD[A], A => Boolean): (RDD[A], RDD[A])`, yours would be `partition(RDD[A], A => Boolean): RDD[(List[A], List[A])]` – Juh_ Jun 10 '16 at 11:45
0

You can use the subtract function (if the filter operation is too expensive).

PySpark code:

rdd1 = data.filter(filterFunction)

rdd2 = data.subtract(rdd1)
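For readers following the Scala answers, the same idea might look like this with the Scala API. It is a sketch only: the local SparkContext and the even-number predicate are assumptions standing in for `data` and `filterFunction`.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Assumption: a throwaway local context; in spark-shell, `sc` already exists.
val sc = new SparkContext(
  new SparkConf().setMaster("local[2]").setAppName("subtract-sketch"))

val data = sc.parallelize(1 to 10)

// First part: the elements that satisfy the predicate.
val evens = data.filter(_ % 2 == 0)

// Second part: everything in data that is not in the first part.
val odds = data.subtract(evens)

// Sorted only to make the result easy to inspect; subtract does not keep order.
val oddsSorted = odds.collect().sorted.toList

sc.stop()
```

As far as I know, `subtract` compares elements by value and usually involves a shuffle, so it is worth measuring against a second filter before relying on it.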
Phani Kumar M