
I have a data frame like below.

    scala> ds.show
    +----+----------+----------+-----+
    | key|attribute1|attribute2|value|
    +----+----------+----------+-----+
    |mac1|        A1|        B1|   10|
    |mac2|        A2|        B1|   10|
    |mac3|        A2|        B1|   10|
    |mac1|        A1|        B2|   10|
    |mac1|        A1|        B2|   10|
    |mac3|        A1|        B1|   10|
    |mac2|        A2|        B1|   10|
    +----+----------+----------+-----+
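
For reference, this sample data can be rebuilt with something like the following (a sketch; it assumes spark-shell, where `spark.implicits._` is already imported for `.toDF`):

    // sketch: reproduce the sample data shown above
    val ds = Seq(
      ("mac1", "A1", "B1", 10),
      ("mac2", "A2", "B1", 10),
      ("mac3", "A2", "B1", 10),
      ("mac1", "A1", "B2", 10),
      ("mac1", "A1", "B2", 10),
      ("mac3", "A1", "B1", 10),
      ("mac2", "A2", "B1", 10)
    ).toDF("key", "attribute1", "attribute2", "value")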

For each value in attribute1, I want to find the top N keys by aggregated value, along with that aggregated value. The aggregated value per (key, attribute1) pair would be:

    +----+----------+-----+
    | key|attribute1|value|
    +----+----------+-----+
    |mac1|        A1|   30|
    |mac2|        A2|   20|
    |mac3|        A2|   10|
    |mac3|        A1|   10|
    +----+----------+-----+

Now if N = 1, the output should be: A1 - (mac1, 30), A2 - (mac2, 20).

How can I achieve this with the DataFrame/Dataset API? I want to do this for every attribute; in the example above that means both attribute1 and attribute2.

1 Answer

Given the input dataframe as

    +----+----------+----------+-----+
    |key |attribute1|attribute2|value|
    +----+----------+----------+-----+
    |mac1|A1        |B1        |10   |
    |mac2|A2        |B1        |10   |
    |mac3|A2        |B1        |10   |
    |mac1|A1        |B2        |10   |
    |mac1|A1        |B2        |10   |
    |mac3|A1        |B1        |10   |
    |mac2|A2        |B1        |10   |
    +----+----------+----------+-----+

aggregating it by key and attribute1 as

    import org.apache.spark.sql.functions._

    // sum the values for each (key, attribute1) pair
    val groupeddf = df.groupBy("key", "attribute1").agg(sum("value").as("value"))

should give you

    +----+----------+-----+
    |key |attribute1|value|
    +----+----------+-----+
    |mac1|A1        |30.0 |
    |mac3|A1        |10.0 |
    |mac3|A2        |10.0 |
    |mac2|A2        |20.0 |
    +----+----------+-----+

Now you can use a window function to rank the rows within each attribute1 partition and keep only the rows with rank <= N:

    import org.apache.spark.sql.expressions.Window
    // the $ column syntax below needs spark.implicits._ outside of spark-shell

    val N = 1

    // rank keys within each attribute1 partition by descending aggregated value
    val windowSpec = Window.partitionBy("attribute1").orderBy($"value".desc)

    groupeddf.withColumn("rank", rank().over(windowSpec))
      .filter($"rank" <= N)
      .drop("rank")

which should give you the dataframe you desire.

    +----+----------+-----+
    |key |attribute1|value|
    +----+----------+-----+
    |mac2|A2        |20.0 |
    |mac1|A1        |30.0 |
    +----+----------+-----+
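
The question also asks for the same result for every attribute column. One way to do that (a sketch, not part of the original answer; the helper name `topNPerAttribute` is hypothetical) is to wrap the aggregate-rank-filter pipeline in a function and run it per column:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // Hypothetical helper: top-N keys by aggregated value for any attribute column.
    def topNPerAttribute(df: DataFrame, attribute: String, n: Int): DataFrame = {
      val grouped = df.groupBy("key", attribute).agg(sum("value").as("value"))
      val w = Window.partitionBy(attribute).orderBy(col("value").desc)
      grouped.withColumn("rank", rank().over(w))
        .filter(col("rank") <= n)
        .drop("rank")
    }

    // run it for both attribute columns from the question
    Seq("attribute1", "attribute2").foreach { attr =>
      topNPerAttribute(df, attr, 1).show()
    }

Note that `rank()` can return more than N rows per partition when aggregated values tie; use `row_number()` instead if you need exactly N rows.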