I'm stuck on a use case similar to the one in "SPARK DataFrame: select the first row of each group". The only difference is that I need to select the first 3 rows of each group. The agg function lets me select the top value, either with the max function or by sorting first and then using the first function.

Is there a way to achieve this using the agg function after a groupBy? If not, what is the best way to do it?

ZygD
DhiwaTdG
    Please illustrate your question with example dataset, attempted code and expected output. – mtoto Dec 22 '16 at 18:25

3 Answers

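import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window
import spark.implicits._  // enables the $"column" syntax; assumes a SparkSession named spark

// df: an existing DataFrame

// Number the rows within each group, highest AnyColumn value first.
val w = Window.partitionBy($"groupColumn").orderBy($"AnyColumn".desc)

// Keep rows numbered 1 to 3 in each group, then drop the helper column.
val dfTop = df.withColumn("rn", row_number().over(w)).where($"rn" <= 3).drop("rn")
dfTop.show()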
Qriss
Beyhan Gül

Use window functions with row_number as in the linked question but replace:

.where($"rn" === 1)

with

.where($"rn" <= 3)
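
For reference, here is a minimal, self-contained sketch of that replacement applied to a small DataFrame. The object name `TopThreePerGroup`, the helper `topN`, the column names `group` and `score`, and the sample rows are all made up for illustration, and the local SparkSession is only there so the snippet can run on its own.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

object TopThreePerGroup {
  // Hypothetical helper: keep the n highest-scoring rows of each group.
  // Assumes the DataFrame has columns named "group" and "score".
  def topN(df: DataFrame, n: Int): DataFrame = {
    val w = Window.partitionBy(col("group")).orderBy(col("score").desc)
    df.withColumn("rn", row_number().over(w)) // 1, 2, 3, ... within each group
      .where(col("rn") <= n)                  // the only change from "=== 1"
      .drop("rn")
  }

  def main(args: Array[String]): Unit = {
    // Local session for illustration only.
    val spark = SparkSession.builder().master("local[*]").appName("top3").getOrCreate()
    import spark.implicits._

    // Made-up sample data: (group, score).
    val df = Seq(
      ("a", 10), ("a", 40), ("a", 30), ("a", 20),
      ("b", 5), ("b", 7)
    ).toDF("group", "score")

    topN(df, 3).orderBy(col("group"), col("score").desc).show()
    spark.stop()
  }
}
```

With this sample, group a keeps the rows scored 40, 30 and 20, and group b keeps both of its rows because it has fewer than three. Since row_number assigns distinct numbers even to tied rows, each group yields at most three rows; which of several tied rows survives is not deterministic.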
  • This is the functionality and code provided in Apache DataFu's [dedupTopN method](https://github.com/apache/datafu/blob/master/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala#L164) – Eyal Jul 26 '21 at 10:36

One solution is to iterate through the values that groupByKey() produces for each key, extract the top N records, and append them to a new list. The following working example can be run on the Cloudera VM, since it uses the Cloudera sample data set. Before running it, make sure you have generated the products RDD from the products table in the MySQL retail_db database.

getTopN function ->

def getTopN(rec: (String, Iterable[String]), topN: Int): Iterable[String] = {
  // The price is the fifth comma-separated field of each record.
  def price(r: String): Float = r.split(",")(4).toFloat

  val records = rec._2.toList
  // The topN highest distinct prices in this group.
  val topNPrices = records.map(price).distinct.sortBy(p => -p).take(topN)
  // Keep every record whose price is one of them, highest price first.
  // Records that tie on price are all kept, so more than topN may be returned.
  records.sortBy(r => -price(r)).filter(r => topNPrices.contains(price(r)))
}

Main code ->

// code to generate the products RDD goes here

val productsMap = products.
  map(rec => (rec.split(",")(1), rec))
productsMap.
  groupByKey().
  flatMap(x => getTopN(x, 3)).
  collect().
  foreach(println)
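
To see how getTopN treats ties, the following sketch runs the same logic on a plain Scala collection instead of an RDD. The object name `GetTopNDemo` and the sample records are made up (they only mimic the id,category,name,description,price,image layout), `getTopN` is restated compactly so the snippet is self-contained, and the local `groupBy` stands in for `map` plus `groupByKey()`.

```scala
object GetTopNDemo {
  // Compact restatement of getTopN: keep every record whose price (field 4)
  // is among the topN highest distinct prices of its group.
  def getTopN(rec: (String, Iterable[String]), topN: Int): Iterable[String] = {
    def price(r: String): Float = r.split(",")(4).toFloat
    val records = rec._2.toList
    val topPrices = records.map(price).distinct.sortBy(p => -p).take(topN).toSet
    records.sortBy(r => -price(r)).filter(r => topPrices.contains(price(r)))
  }

  // Made-up records, all in category "2"; two of them share the price 30.0.
  val sample: List[String] = List(
    "1,2,Ball,,10.0,img1",
    "2,2,Bat,,30.0,img2",
    "3,2,Glove,,30.0,img3",
    "4,2,Helmet,,20.0,img4",
    "5,2,Net,,5.0,img5"
  )

  // Local stand-in for map(rec => (category, rec)) followed by groupByKey().
  val grouped: Map[String, List[String]] = sample.groupBy(_.split(",")(1))

  def main(args: Array[String]): Unit =
    grouped.toSeq.flatMap(g => getTopN(g, 3)).foreach(println)
}
```

Here the top three distinct prices are 30.0, 20.0 and 10.0, so four records are printed (Bat, Glove, Helmet, Ball) and only Net is dropped. This behaves like a dense rank rather than a row number, so it can return more than N rows per group when prices tie.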
Nikhil Bhide