
Given a table like the following:

+--+------------------+-----------+
|id|     diagnosis_age|  diagnosis|
+--+------------------+-----------+
| 1|2.1843037179180302| 315.320000|
| 1|  2.80033330216659| 315.320000|
| 1|   2.8222365762732| 315.320000|
| 1|  5.64822705794013| 325.320000|
| 1| 5.686557787521759| 335.320000|
| 2|  5.70572315231258| 315.320000|
| 2| 5.724888517103389| 315.320000|
| 3| 5.744053881894209| 315.320000|
| 3|5.7604813374292005| 315.320000|
| 3|  5.77993740687426| 315.320000|
+--+------------------+-----------+

I'm trying to reduce the records per id to just one by taking the most frequent diagnosis for that id.

If it were an RDD, something like this would do it:

(rdd.map(lambda x: (x["id"], [x["diagnosis"]]))       # (id, [diagnosis])
    .reduceByKey(lambda x, y: x + y)                    # id -> every diagnosis for that id
    .mapValues(lambda ds: max(set(ds), key=ds.count)))  # id -> most frequent diagnosis
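To check the intended semantics locally, here is a plain-Python sketch of what the grouping above computes on the sample rows, extended to report the earliest age of the winning diagnosis as in the desired output. It is only an illustration, not Spark code: the helper name `most_frequent_per_id` is hypothetical, and diagnosis codes are assumed to be strings exactly as printed in the table.

```python
from collections import defaultdict

# Sample rows from the question as (id, diagnosis_age, diagnosis).
# Assumption: diagnosis codes are treated as strings, exactly as printed.
ROWS = [
    (1, 2.1843037179180302, "315.320000"),
    (1, 2.80033330216659, "315.320000"),
    (1, 2.8222365762732, "315.320000"),
    (1, 5.64822705794013, "325.320000"),
    (1, 5.686557787521759, "335.320000"),
    (2, 5.70572315231258, "315.320000"),
    (2, 5.724888517103389, "315.320000"),
    (3, 5.744053881894209, "315.320000"),
    (3, 5.7604813374292005, "315.320000"),
    (3, 5.77993740687426, "315.320000"),
]


def most_frequent_per_id(rows):
    """Return {id: (earliest age of the modal diagnosis, modal diagnosis)}."""
    grouped = defaultdict(list)  # local analogue of reduceByKey(lambda x, y: x + y)
    for id_, age, diag in rows:
        grouped[id_].append((age, diag))

    result = {}
    for id_, pairs in grouped.items():
        diags = [d for _, d in pairs]
        # dict.fromkeys keeps first-seen order and max() returns the first maximal
        # element, so a tie in count goes to the diagnosis that appears first.
        mode = max(dict.fromkeys(diags), key=diags.count)
        earliest = min(age for age, d in pairs if d == mode)
        result[id_] = (earliest, mode)
    return result


for id_, (age, diag) in sorted(most_frequent_per_id(ROWS).items()):
    print(id_, age, diag)
```

The tie-breaking rule (first diagnosis seen wins) is a choice made for this sketch; the question does not say how ties should be resolved.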

In SQL:

select id, diagnosis, diagnosis_age
from (select id, diagnosis, diagnosis_age, count(*) as cnt,
             row_number() over (partition by id order by count(*) desc) as seqnum
      from t
      group by id, diagnosis, diagnosis_age
     ) da
where seqnum = 1;

Desired output:

+--+------------------+-----------+
|id|     diagnosis_age|  diagnosis|
+--+------------------+-----------+
| 1|2.1843037179180302| 315.320000|
| 2|  5.70572315231258| 315.320000|
| 3| 5.744053881894209| 315.320000|
+--+------------------+-----------+
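For reference: because the SQL above also groups by the age column, every `count(*)` is 1. A sketch that groups by `id` and `diagnosis` only and keeps the earliest age reproduces this output. It runs against an in-memory SQLite database through Python's `sqlite3` module as a stand-in for Spark SQL (assumption: SQLite 3.25 or newer, which added window functions). The CTE names `counts` and `ranked` and the `diagnosis_age` tiebreak in the `ORDER BY` are illustrative choices, not part of the original query.

```python
import sqlite3

# Sample rows from the question as (id, diagnosis_age, diagnosis).
ROWS = [
    (1, 2.1843037179180302, "315.320000"),
    (1, 2.80033330216659, "315.320000"),
    (1, 2.8222365762732, "315.320000"),
    (1, 5.64822705794013, "325.320000"),
    (1, 5.686557787521759, "335.320000"),
    (2, 5.70572315231258, "315.320000"),
    (2, 5.724888517103389, "315.320000"),
    (3, 5.744053881894209, "315.320000"),
    (3, 5.7604813374292005, "315.320000"),
    (3, 5.77993740687426, "315.320000"),
]

QUERY = """
WITH counts AS (
    -- group by id and diagnosis only, so cnt is how often each diagnosis occurs per id
    SELECT id, diagnosis, MIN(diagnosis_age) AS diagnosis_age, COUNT(*) AS cnt
    FROM t
    GROUP BY id, diagnosis
),
ranked AS (
    -- rank diagnoses per id by frequency; ties go to the earlier age
    SELECT id, diagnosis, diagnosis_age,
           ROW_NUMBER() OVER (PARTITION BY id ORDER BY cnt DESC, diagnosis_age) AS seqnum
    FROM counts
)
SELECT id, diagnosis_age, diagnosis
FROM ranked
WHERE seqnum = 1
ORDER BY id
"""


def most_frequent_sql(rows):
    """Load rows into a throwaway SQLite table and run the corrected query."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE t (id INTEGER, diagnosis_age REAL, diagnosis TEXT)")
        conn.executemany("INSERT INTO t VALUES (?, ?, ?)", rows)
        return conn.execute(QUERY).fetchall()
    finally:
        conn.close()


for row in most_frequent_sql(ROWS):
    print(row)
```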

How can I achieve the same using only Spark DataFrame operations, if possible? Specifically, without using any RDD actions or SQL.

Thanks

mad-a
  • Correct me if I am wrong: you want the least value of diagnosis age per id and the most frequent diagnosis age per id? – murtihash Mar 25 '20 at 14:42
  • @Mohammad Murtaza Hashmi I just want the most frequent diagnosis per id, regardless of diagnosis age. I just assumed that, in the example table, this would also return the record with the least diagnosis age. – mad-a Mar 25 '20 at 15:12
  • Does this answer your question? [How to select the first row of each group?](https://stackoverflow.com/questions/33878370/how-to-select-the-first-row-of-each-group) – user10938362 Mar 25 '20 at 18:02

2 Answers


Python: here is the conversion of my Scala code.

from pyspark.sql.functions import col, first, count, row_number
from pyspark.sql import Window

# Count how often each diagnosis occurs per id, then keep the top-ranked diagnosis per id.
# Note: first() is not guaranteed to pick any particular row after a groupBy; use min() if the earliest age is required.
df.groupBy("id", "diagnosis").agg(first(col("diagnosis_age")).alias("diagnosis_age"), count(col("diagnosis_age")).alias("cnt")) \
  .withColumn("seqnum", row_number().over(Window.partitionBy("id").orderBy(col("cnt").desc()))) \
  .where("seqnum = 1") \
  .select("id", "diagnosis_age", "diagnosis", "cnt") \
  .orderBy("id") \
  .show(10, False)

Scala: Your SQL query does not quite make sense to me. Because it groups by the age as well, the count for each record is always 1. I have modified the grouping a bit in the DataFrame expression:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count, first, row_number}

// Count each diagnosis per id, rank by that count, and keep the top row per id
df.groupBy("id", "diagnosis").agg(first(col("diagnosis_age")).as("diagnosis_age"), count(col("diagnosis_age")).as("cnt"))
  .withColumn("seqnum", row_number().over(Window.partitionBy("id").orderBy(col("cnt").desc)))
  .where("seqnum = 1")
  .select("id", "diagnosis_age", "diagnosis", "cnt")
  .orderBy("id")
  .show(false)

where the result is:

+---+------------------+---------+---+
|id |diagnosis_age     |diagnosis|cnt|
+---+------------------+---------+---+
|1  |2.1843037179180302|315.32   |3  |
|2  |5.70572315231258  |315.32   |2  |
|3  |5.744053881894209 |315.32   |3  |
+---+------------------+---------+---+
Lamanus
  • I can't actually run your code. I changed .as to .alias and added \ where the code continues on a new line, but I get an error relating to row_number: NameError: name 'row_number' is not defined. When I change row_number to F.row_number (with import pyspark.sql.functions as F), I get: AttributeError: 'function' object has no attribute 'over'. Is this something to do with different versions, as I'm using 1.6? – mad-a Mar 25 '20 at 15:06
  • @mad-a, sorry, this is Scala code; I will add the Python code. – Lamanus Mar 25 '20 at 15:20

You can use count and max with window functions, filter on count=max, and then aggregate per id.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# No orderBy on w: with an ordering, the default frame would turn count into a running count
w=Window().partitionBy("id","diagnosis")
w2=Window().partitionBy("id")
df.withColumn("count", F.count("diagnosis").over(w))\
  .withColumn("max", F.max("count").over(w2))\
  .filter("count=max")\
  .groupBy("id").agg(F.min("diagnosis_age").alias("diagnosis_age"),F.first("diagnosis").alias("diagnosis"))\
  .orderBy("id").show()

+---+------------------+---------+
| id|     diagnosis_age|diagnosis|
+---+------------------+---------+
|  1|2.1843037179180302|   315.32|
|  2|  5.70572315231258|   315.32|
|  3| 5.744053881894209|   315.32|
+---+------------------+---------+
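One subtlety with the count = max filter: in Spark, a window that has `orderBy` but no explicit frame uses a frame from the start of the partition to the current row, so `count(...).over(...)` becomes a running count rather than the partition total, and the filter then keeps only the last row of the most frequent diagnosis. The plain-Python sketch below emulates both cases on the sample rows. The helper names are hypothetical, and it assumes no tied ages (Spark's default range frame would give tied ages the same running count).

```python
from collections import defaultdict

# Sample rows from the question as (id, diagnosis_age, diagnosis).
ROWS = [
    (1, 2.1843037179180302, "315.320000"),
    (1, 2.80033330216659, "315.320000"),
    (1, 2.8222365762732, "315.320000"),
    (1, 5.64822705794013, "325.320000"),
    (1, 5.686557787521759, "335.320000"),
    (2, 5.70572315231258, "315.320000"),
    (2, 5.724888517103389, "315.320000"),
    (3, 5.744053881894209, "315.320000"),
    (3, 5.7604813374292005, "315.320000"),
    (3, 5.77993740687426, "315.320000"),
]


def keep_count_equals_max(rows, running):
    """Emulate count over partitionBy(id, diagnosis) and max over partitionBy(id),
    keeping rows where count == max.

    running=False: count is the partition total (unordered window).
    running=True: count is a running count ordered by diagnosis_age, mimicking
    the default frame of an ordered window.
    """
    partitions = defaultdict(list)
    for row in rows:
        partitions[(row[0], row[2])].append(row)

    counted = []
    for group in partitions.values():
        group = sorted(group, key=lambda r: r[1])  # order by diagnosis_age
        for position, row in enumerate(group, start=1):
            counted.append((row, position if running else len(group)))

    max_per_id = defaultdict(int)
    for row, cnt in counted:
        max_per_id[row[0]] = max(max_per_id[row[0]], cnt)

    return [row for row, cnt in counted if cnt == max_per_id[row[0]]]


def earliest_age_per_id(rows):
    """Emulate groupBy("id") with min(diagnosis_age), returned as {id: age}."""
    result = {}
    for id_, age, _ in rows:
        result[id_] = min(age, result.get(id_, age))
    return result


print(earliest_age_per_id(keep_count_equals_max(ROWS, running=False)))
print(earliest_age_per_id(keep_count_equals_max(ROWS, running=True)))
```

With the unordered window every row of the most frequent diagnosis survives the filter, so taking the minimum age per id gives the desired table; with the running count only the latest row per id survives.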
murtihash
  • Although your code runs, I don't think it reduces the records per id to one, making each id distinct. When I run df.select("id").distinct().count() I get 154957, but when I run count() on your output I get 240438. – mad-a Mar 25 '20 at 15:11
  • @mad-a I see, I have updated the solution per your feedback. Let me know if you try it. – murtihash Mar 25 '20 at 15:18