
I have an input data set like this:

id     operation          value
1      null                1
1      discard             0
2      null                1
2      null                2
2      max                 0
3      null                1
3      null                1
3      list                0

I want to group the input by id and produce rows according to the "operation" column.

For group 1, operation="discard", so the output is empty (no rows).

For group 2, operation="max", so the output is:

2      null                2

For group 3, operation="list", so the output is:

3      null                1
3      null                1

So the final output is:

id     operation          value
2      null                2
3      null                1
3      null                1

Is there a solution for this?

I know there is a similar question, how-to-iterate-grouped-data-in-spark, but mine differs from it in two ways:

    1. I want to produce more than one row for each group. Is that possible, and how?
    2. I want my logic to be easy to extend as more operations are added in the future. Are user-defined aggregate functions (UDAFs) the only possible solution?
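Both points can be illustrated without a UDAF. The following is a minimal sketch on plain Scala collections, assuming a hypothetical case class `Rec` that mirrors the three columns and a hypothetical `handlers` table keyed by operation name: each group's rows go to one function that may return zero, one, or many rows, and supporting a new operation means adding one map entry. `run` only imitates the grouping locally; it is not Spark code.

```scala
// Hypothetical names throughout (Rec, handlers, processGroup, run); plain Scala, no Spark.
object GroupOps {
  // One input row: operation is None for data rows and Some(name) for the marker row.
  final case class Rec(id: Int, operation: Option[String], value: Int)

  // Each operation turns a group's data rows into zero or more output rows.
  // A new operation is one more entry in this map.
  val handlers: Map[String, Seq[Rec] => Seq[Rec]] = Map(
    "discard" -> ((_: Seq[Rec]) => Seq.empty[Rec]),
    "max"     -> ((rows: Seq[Rec]) => if (rows.isEmpty) Seq.empty[Rec] else Seq(rows.maxBy(_.value))),
    "list"    -> ((rows: Seq[Rec]) => rows)
  )

  // Same shape as the function Spark's flatMapGroups takes: (key, rows of the group) => rows to emit.
  // The key is unused here because every row already carries its id.
  def processGroup(id: Int, rows: Iterator[Rec]): Iterator[Rec] = {
    val (data, markers) = rows.toList.partition(_.operation.isEmpty)
    // Groups without a known operation emit nothing (an assumption of this sketch).
    val handler = markers.flatMap(_.operation).headOption.flatMap(handlers.get)
    handler.map(h => h(data)).getOrElse(Seq.empty[Rec]).iterator
  }

  // Local stand-in for grouping by id and calling processGroup once per group.
  def run(input: Seq[Rec]): Seq[Rec] =
    input.groupBy(_.id).toSeq.sortBy(_._1).flatMap { case (id, rows) => processGroup(id, rows.iterator) }
}
```

In Spark, something like `df.as[GroupOps.Rec].groupByKey(_.id).flatMapGroups(GroupOps.processGroup)` (with `spark.implicits._` imported) should plug the same function in, but that wiring is untested here. Note that this sketch materializes each group with `toList`, so a very large group would have to fit in one task's memory.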

Update 1:

Thanks, stack0114106. Here are more details, based on your answer: for example, for id=2 with operation="max", I want to iterate over all the rows with id=2 and find the max value, rather than assign a hard-coded value. That's why I want to iterate over the rows in each group. Below is an updated example:

The input:

scala> val df = Seq((0,null,1),(0,"discard",0),(1,null,1),(1,null,2),(1,"max",0),(2,null,1),(2,null,3),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]

scala> df.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|0  |null     |1    |
|0  |discard  |0    |
|1  |null     |1    |
|1  |null     |2    |
|1  |max      |0    |
|2  |null     |1    |
|2  |null     |3    |
|2  |max      |0    |
|3  |null     |1    |
|3  |null     |1    |
|3  |list     |0    |
+---+---------+-----+

The expected output:

+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1  |null     |2    |
|2  |null     |3    |
|3  |null     |1    |
|3  |null     |1    |
+---+---------+-----+
Michael

2 Answers


Group everything, collecting the values, then write the logic for each operation:

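import org.apache.spark.sql.functions._
import spark.implicits._ // enables $"col"; already imported in spark-shell

// one row per id: the operation name (max ignores nulls) and all of the group's values
val grouped = df.groupBy($"id").agg(max($"operation").as("op"), collect_list($"value").as("vals"))
// "max": explode the collected values and keep the largest one per id
val maxs = grouped.filter($"op" === "max").withColumn("val", explode($"vals")).groupBy($"id").agg(max("val").as("value"))
// "list": explode the values and drop the 0 placeholder from the operation row (assumes real values are never 0)
val lists = grouped.filter($"op" === "list").withColumn("value", explode($"vals")).filter($"value" =!= 0).select($"id", $"value")
// we don't collect the "discard" groups,
// and we can add additional subsets for new "operations"
val result = maxs.union(lists)
// if you need the null "operation" column, add it with withColumn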
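To make the pipeline above concrete, here is a plain-Scala model of what each step computes. It is not Spark code: tuples of a hypothetical `In` type stand in for rows, and the method names are hypothetical. It also makes two assumptions of the answer visible: `max` sees the 0 placeholder from the operation row, which is harmless only for non-negative values, and the `list` branch drops every 0, so it assumes real values are never 0.

```scala
// Plain-Scala model of the DataFrame pipeline (no Spark involved); names are hypothetical.
object CollectThenDispatch {
  // (id, operation, value), with None standing in for SQL null
  type In = (Int, Option[String], Int)
  type Grouped = (Int, Option[String], Seq[Int])

  // groupBy(id).agg(max(operation), collect_list(value)): the largest of the
  // non-null operation strings is the group's operation name.
  def grouped(df: Seq[In]): Seq[Grouped] =
    df.groupBy(_._1).toSeq.sortBy(_._1).map { case (id, rows) =>
      (id, rows.flatMap(_._2).sorted.lastOption, rows.map(_._3))
    }

  // "max": the largest collected value; the 0 placeholder is included,
  // which only matters if real values can be negative.
  def maxs(g: Seq[Grouped]): Seq[(Int, Int)] =
    g.collect { case (id, Some("max"), vals) => (id, vals.max) }

  // "list": every collected value except 0, so this assumes real values are never 0.
  def lists(g: Seq[Grouped]): Seq[(Int, Int)] =
    g.collect { case (id, Some("list"), vals) => vals.filter(_ != 0).map(v => (id, v)) }.flatten

  // "discard" groups are never selected; union is plain concatenation here.
  def result(df: Seq[In]): Seq[(Int, Int)] = {
    val g = grouped(df)
    maxs(g) ++ lists(g)
  }
}
```

Adding an operation in this style means writing one more subset like `maxs` or `lists` and concatenating it into `result`, which mirrors adding one more `union` in the DataFrame version.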
Arnon Rotem-Gal-Oz

You can use the flatMap operation on the DataFrame and generate the required rows based on the conditions you mentioned. Check this out:

scala> val df = Seq((1,null,1),(1,"discard",0),(2,null,1),(2,null,2),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]

scala> df.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1  |null     |1    |
|1  |discard  |0    |
|2  |null     |1    |
|2  |null     |2    |
|2  |max      |0    |
|3  |null     |1    |
|3  |null     |1    |
|3  |list     |0    |
+---+---------+-----+


scala> df.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => (0,0) case "max" => (1,2) case "list" => (2,1) } ; (0 until s._1).map( i => (r.getInt(0),null,s._2) ) }).show(false)
+---+----+---+
|_1 |_2  |_3 |
+---+----+---+
|2  |null|2  |
|3  |null|1  |
|3  |null|1  |
+---+----+---+

Spark assigns the default column names _1, _2, etc., so you can rename them to the actual names with toDF, as below:

scala> val df2 = df.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => (0,0) case "max" => (1,2) case "list" => (2,1) } ; (0 until s._1).map( i => (r.getInt(0),null,s._2) ) }).toDF("id","operation","value")
df2: org.apache.spark.sql.DataFrame = [id: int, operation: null ... 1 more field]

scala> df2.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|2  |null     |2    |
|3  |null     |1    |
|3  |null     |1    |
+---+---------+-----+


scala>

EDIT1:

Since you need max(value) for each id, you can use a window function to get the max value in a new column, then use the same flatMap technique to get the results. Check this out:

scala> val df = Seq((0,null,1),(0,"discard",0),(1,null,1),(1,null,2),(1,"max",0),(2,null,1),(2,null,3),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]

scala> df.createOrReplaceTempView("michael")

scala> val df2 = spark.sql(""" select *, max(value) over(partition by id) mx from michael """)
df2: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 2 more fields]

scala> df2.show(false)
+---+---------+-----+---+
|id |operation|value|mx |
+---+---------+-----+---+
|1  |null     |1    |2  |
|1  |null     |2    |2  |
|1  |max      |0    |2  |
|3  |null     |1    |1  |
|3  |null     |1    |1  |
|3  |list     |0    |1  |
|2  |null     |1    |3  |
|2  |null     |3    |3  |
|2  |max      |0    |3  |
|0  |null     |1    |1  |
|0  |discard  |0    |1  |
+---+---------+-----+---+


scala> val df3 = df2.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => 0 case "max" => 1 case "list" => 2 } ; (0 until s).map( i => (r.getInt(0),null,r.getInt(3) )) }).toDF("id","operation","value")
df3: org.apache.spark.sql.DataFrame = [id: int, operation: null ... 1 more field]


scala> df3.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1  |null     |2    |
|3  |null     |1    |
|3  |null     |1    |
|2  |null     |3    |
+---+---------+-----+


scala>
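The window step can be modelled in plain Scala to show why it works. Again this is not Spark code, and `In`, `withGroupMax`, `copies` and `expand` are hypothetical names. `withGroupMax` keeps every row and appends its group's maximum, which is what `max(value) over (partition by id)` does, and `expand` then emits a fixed number of copies per operation row, like the `flatMap` above. One consequence of reusing the hard-coded counts: a "list" group emits two copies of the group maximum rather than its actual values, which matches the example only because both of group 3's values are 1.

```scala
// Plain-Scala model of the window + flatMap approach (no Spark involved); names are hypothetical.
object WindowThenFlatMap {
  // (id, operation, value), with None standing in for SQL null
  type In = (Int, Option[String], Int)

  // Like max(value) over (partition by id): every row is kept and gains its group's max.
  def withGroupMax(rows: Seq[In]): Seq[(Int, Option[String], Int, Int)] = {
    val mx: Map[Int, Int] = rows.groupBy(_._1).map { case (id, rs) => id -> rs.map(_._3).max }
    rows.map { case (id, op, v) => (id, op, v, mx(id)) }
  }

  // Copy counts per operation row, as hard-coded in the answer's match expression.
  val copies: Map[String, Int] = Map("discard" -> 0, "max" -> 1, "list" -> 2)

  // Like the flatMap step: each operation row becomes 0..n rows carrying the group max.
  // Unknown operations emit nothing here (the original match would throw a MatchError).
  def expand(rows: Seq[In]): Seq[(Int, Option[String], Int)] =
    withGroupMax(rows).collect {
      case (id, Some(op), _, m) => Seq.fill(copies.getOrElse(op, 0))((id, Option.empty[String], m))
    }.flatten
}
```

For a "list" group whose values differ, such as ids of 4 with values 5 and 7, `expand` yields two rows with value 7; emitting the group's own values would need the rows themselves, for example via the collect-then-explode approach or a per-group function.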
stack0114106
  • Thanks for your quick reply. But there are still some issues: for example, for operation="max", I want to iterate over all the rows with id=2 and find the max value, rather than assign the hard-coded value "2". That's why I want to iterate over the rows in each group. Please see "Update 1" in my question. – Michael Dec 31 '18 at 22:22
  • Please check my EDIT1. – stack0114106 Dec 31 '18 at 23:23
  • Thanks, stack0114106! This works for me, although I need more time to digest it. – Michael Jan 01 '19 at 00:29
  • Glad that it works for you. With window functions you get the aggregate values as an extra column on every row, as opposed to a traditional group-by query, which collapses each group into a single row. Window functions have their own syntax, so it might take some time to become comfortable with them. – stack0114106 Jan 01 '19 at 00:33