Expanding on my comments, here's a solution which demonstrates how you can technically achieve this in two passes of the data - one to count, and one to reduce and find the (possibly multiple) modes. I've implemented the second part with the RDD API - translating it into the DataFrame API is left to the reader ;) (tbh I don't know whether it's even possible to do custom aggregations with multiple output rows like this):
from pyspark.sql import types
# Example data
data = [
    (0, '12.2.25.68'),
    (0, '12.2.25.68'),
    (0, '12.2.25.43'),
    (1, '62.251.0.149'),  # This ID has two modes
    (1, '62.251.0.140'),
]
schema = types.StructType([
    types.StructField('id', types.IntegerType()),
    types.StructField('ip', types.StringType()),
])
df = spark.createDataFrame(data, schema)

# Count id/ip pairs
df = df.groupBy('id', 'ip').count()
def find_modes(a, b):
    """
    Reducing function to find modes (can return multiple).
    a and b are lists of Row holding the current mode candidates.
    """
    if a[0]['count'] > b[0]['count']:
        return a
    if a[0]['count'] < b[0]['count']:
        return b
    # Equal counts: keep both sets of tied candidates
    return a + b
result = (
    df.rdd
    .map(lambda row: (row['id'], [row]))  # key by id, wrap each Row in a one-element list
    .reduceByKey(find_modes)
    .collectAsMap()
)
Result:
{0: [Row(id=0, ip='12.2.25.68', count=2)],
 1: [Row(id=1, ip='62.251.0.149', count=1),
     Row(id=1, ip='62.251.0.140', count=1)]}
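Since collectAsMap() pulls everything back to the driver, getting the mode(s) for a given ID afterwards is just a dict lookup:

result[1]
# [Row(id=1, ip='62.251.0.149', count=1), Row(id=1, ip='62.251.0.140', count=1)]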
Small caveat to this approach: because I aggregate repeated modes in-memory, if you have many different IPs with the same count for a single ID, you do risk OOM issues. For this particular application, I'd say it's very unlikely (e.g. a single user probably won't have 1 million different IPs, all with 1 event).
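As an aside on the DataFrame API question from the top: something like the (untested) sketch below might work on Spark 3.0+ with applyInPandas, since grouped-map functions are allowed to return multiple rows per group. Here raw_df is a stand-in for the DataFrame before the groupBy('id', 'ip').count() step, so treat this as an assumption rather than a drop-in replacement:

import pandas as pd

def modes_per_id(pdf: pd.DataFrame) -> pd.DataFrame:
    # Count ip occurrences within a single id and keep every row tied for the max
    counts = pdf.groupby(['id', 'ip']).size().reset_index(name='count')
    return counts[counts['count'] == counts['count'].max()]

modes_df = (
    raw_df  # stand-in: the id/ip DataFrame before the count step
    .groupBy('id')
    .applyInPandas(modes_per_id, schema='id long, ip string, count long')
)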
But I tend to agree with @absolutelydevastated: the simplest solution is probably the one you already have, even if it makes an extra pass over the data. You should, however, avoid doing a sort/rank and instead just take the max count within the window, if possible.
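For reference, a minimal sketch of that window-based variant, assuming df is the id/ip counts from the groupBy above. Like the RDD version, the max-over-window comparison also keeps tied modes:

from pyspark.sql import functions as F, Window

# df here is the result of df.groupBy('id', 'ip').count() from earlier
w = Window.partitionBy('id')
modes = (
    df.withColumn('max_count', F.max('count').over(w))
      .where(F.col('count') == F.col('max_count'))
      .drop('max_count')
)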