
Say I have a PairRDD like this (obviously there is much more data in real life; assume millions of records):

val scores = sc.parallelize(Array(
  ("a", 1),
  ("a", 2),
  ("a", 3),
  ("b", 3),
  ("b", 1),
  ("a", 4),
  ("b", 4),
  ("b", 2)
))

What is the most efficient way to generate an RDD with the top 2 scores per key?

val top2ByKey = ...
res3: Array[(String, Int)] = Array((a,4), (a,3), (b,4), (b,3))
michael_erasmus

4 Answers


I think this should be quite efficient:

Edited according to the OP's comments:

scores.mapValues(p => (p, p)).reduceByKey((u, v) => {
  // merge two candidate pairs: sort all four values in descending order and drop duplicates
  val values = List(u._1, u._2, v._1, v._2).sorted(Ordering[Int].reverse).distinct
  if (values.size > 1) (values(0), values(1)) // keep the two largest distinct values
  else (values(0), values(0))                 // all values were equal: repeat the single maximum
}).collect().foreach(println)
abalcerek
  • This doesn't seem to work? This is the output: Array[(String, (Int, Int))] = Array((a,(4,4)), (b,(4,4))) – michael_erasmus May 11 '15 at 16:16
  • I got this to work by adapting user52045's answer: val scores = sc.parallelize(Array( ("a", 1), ("a", 2), ("a", 3), ("b", 3), ("b", 1), ("a", 4), ("b", 4), ("b", 2) )) scores.mapValues(p => (p, p)).reduceByKey((u, v) => { val values = List(u._1, u._2, v._1, v._2).sorted(Ordering[Int].reverse).distinct (values(0), values(1)) }).collect() – michael_erasmus May 11 '15 at 16:30
  • @michael_erasmus You are correct, there is a bug in my code. Thanks for fixing it. One thing to be careful about: if all the elements of the list are the same, you will get an IndexOutOfBoundsException. – abalcerek May 11 '15 at 16:52
  • Will this solution be efficient for large datasets? I mean: sorting everything to get only a few top elements seems too expensive. – akademi4eg Nov 13 '15 at 13:55
  • @akademi4eg Notice that you are sorting `inside` the reduce function, and only a 4-element list. This is just a concise way of picking the 2 maximal elements from at most 4 elements (two pairs). – abalcerek Nov 13 '15 at 14:32
  • This only gets the top 2, but how can I get the top n when n is an input? – Xiaoyu Chen Mar 01 '17 at 04:00
  • @XiaoyuChen You can get an arbitrary number of elements with just a slight modification: replace the first map with `mapValues(p => Range(0, n).map(_ => p))` and use an n-element sequence instead of a two-element tuple. – abalcerek Mar 01 '17 at 08:05
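
To make that concrete, here is a minimal sketch (my addition, not from the original answer) of one way to generalize the reduceByKey approach to an arbitrary n: keep each value as a list of at most n scores instead of a pair. Unlike the answer above, this keeps duplicate scores; add a .distinct before take(n) if you want the top n distinct values.

val n = 2 // desired number of top scores per key

val topNByKey = scores
  .mapValues(v => List(v))                                                // start each score as a one-element list
  .reduceByKey((u, v) => (u ++ v).sorted(Ordering[Int].reverse).take(n))  // merge candidates, keep the n largest
  .flatMap { case (k, vs) => vs.map(k -> _) }                             // flatten back to (key, score) pairs

topNByKey.collect() // with n = 2 and the sample data: Array((a,4), (a,3), (b,4), (b,3)), in some order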

Since version 1.4, there is a built-in way to do this using MLlib: https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala

import org.apache.spark.mllib.rdd.MLPairRDDFunctions.fromPairRDD
scores.topByKey(2)
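
For reference, a quick sketch of what this looks like on the question's data (my addition; it assumes spark-mllib is on the classpath, and, as far as I recall, topByKey returns an RDD[(K, Array[V])] with each key's values in descending order):

val top2ByKey = scores.topByKey(2)                        // RDD[(String, Array[Int])]
top2ByKey.mapValues(_.toList).collect().foreach(println)  // e.g. (a,List(4, 3)) and (b,List(4, 3))

// Flatten to (key, score) pairs to match the format asked for in the question:
top2ByKey.flatMap { case (k, vs) => vs.map(k -> _) }.collect()
// e.g. Array((a,4), (a,3), (b,4), (b,3)), in some order
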
jbochi

I slightly modified your input data:

val scores = sc.parallelize(Array(
  ("a", 1),
  ("a", 2),
  ("a", 3),
  ("b", 3),
  ("b", 1),
  ("a", 4),
  ("b", 4),
  ("b", 2),
  ("a", 6),
  ("b", 8)
))

I'll explain how to do it step by step:

1. Group by key to create an array

scores.groupByKey().foreach(println)  

Result:

(b,CompactBuffer(3, 1, 4, 2, 8))
(a,CompactBuffer(1, 2, 3, 4, 6))

As you can see, each value is itself an array of numbers. A CompactBuffer is just an optimised array.

2. For each key, reverse-sort the list of numbers in its value

scores.groupByKey().map({ case (k, numbers) => k -> numbers.toList.sorted(Ordering[Int].reverse)} ).foreach(println)

Result:

(b,List(8, 4, 3, 2, 1))
(a,List(6, 4, 3, 2, 1))

3. Keep only the first 2 elements from step 2; they are the top 2 scores in the list

scores.groupByKey().map({ case (k, numbers) => k -> numbers.toList.sorted(Ordering[Int].reverse).take(2)} ).foreach(println)

Result:

(a,List(6, 4))
(b,List(8, 4))

4. Flat-map to create a new pair RDD with an entry for each key and top score

scores.groupByKey().map({ case (k, numbers) => k -> numbers.toList.sorted(Ordering[Int].reverse).take(2)} ).flatMap({case (k, numbers) => numbers.map(k -> _)}).foreach(println)

Result:

(b,8)
(b,4)
(a,6)
(a,4)

5. Optional step: sort by key if you want

scores.groupByKey().map({ case (k, numbers) => k -> numbers.toList.sorted(Ordering[Int].reverse).take(2)} ).flatMap({case (k, numbers) => numbers.map(k -> _)}).sortByKey(false).foreach(println)

Result:

(a,6)
(a,4)
(b,8)
(b,4)

I hope this explanation helped you understand the logic.

 scores.reduceByKey(_ + _).map(x => x._2 -> x._1).sortByKey(false).map(x => x._2 -> x._1).take(2).foreach(println)
Ning Guo
  • Hi, welcome to Stack Overflow. Please don't just dump code as your answer. Explain your train of thought so we can better understand it. Read this if you have any doubts: http://stackoverflow.com/help/how-to-answer Thanks. – Cthulhu Jan 05 '16 at 10:01
  • 1
    I believe scores.reduceByKey(_ + _) would collapse all pairs with the same key so you would end up with a single (a, N) and a single (b, M) where N and M are the sum of all a values and b values, respectively. At that point you only a single (a, N) and no amount of decomposition would get back (a, i) and (a, j) where i and j are the two highest values for all a pairs. – Thomas Nguyen Oct 13 '17 at 07:35
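
To illustrate the point in the comment above, a quick sketch (my addition) of what the first step actually produces on the question's sample data:

scores.reduceByKey(_ + _).collect()
// yields the per-key sums, e.g. Array((a,10), (b,10)) for the sample data,
// so the individual top-2 scores are already lost before the sort and take(2)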