81

How can I use collect_set or collect_list on a dataframe after a groupby? For example: df.groupby('key').collect_set('values'). I get an error: AttributeError: 'GroupedData' object has no attribute 'collect_set'

Hanan Shteingart
  • Can you post some sample data that will throw this error so that we can debug your issue? – Katya Willard Jun 02 '16 at 13:24
  • in pyspark it works fine, btw I am trying precisely to translate this work to Scala Spark https://johnpaton.net/posts/forward-fill-spark/ (I mean the scope of the job is backfilling and forward filling, and this is how it works in pyspark) – Olfa2 Oct 12 '22 at 07:10

2 Answers

156

You need to use agg. Example:

from pyspark import SparkContext
from pyspark.sql import HiveContext
from pyspark.sql import functions as F

sc = SparkContext("local")

sqlContext = HiveContext(sc)

df = sqlContext.createDataFrame([
    ("a", None, None),
    ("a", "code1", None),
    ("a", "code2", "name2"),
], ["id", "code", "name"])

df.show()

+---+-----+-----+
| id| code| name|
+---+-----+-----+
|  a| null| null|
|  a|code1| null|
|  a|code2|name2|
+---+-----+-----+

Note that in the example above you have to create a HiveContext. See https://stackoverflow.com/a/35529093/690430 for dealing with different Spark versions.
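On Spark 2.0 and later you can build the same example from a SparkSession instead, which replaces SQLContext/HiveContext as the entry point. A minimal sketch (the app name is arbitrary):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# SparkSession is the unified entry point in Spark 2.0+
spark = SparkSession.builder.master("local").appName("collect-example").getOrCreate()

df = spark.createDataFrame([
    ("a", None, None),
    ("a", "code1", None),
    ("a", "code2", "name2"),
], ["id", "code", "name"])

The groupby/agg calls below work unchanged on a dataframe created this way.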

(df
  .groupby("id")
  .agg(F.collect_set("code"),
       F.collect_list("name"))
  .show())

+---+-----------------+------------------+
| id|collect_set(code)|collect_list(name)|
+---+-----------------+------------------+
|  a|   [code1, code2]|           [name2]|
+---+-----------------+------------------+
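If you prefer cleaner column names than collect_set(code), you can alias the aggregated columns; the names codes and names below are just illustrative:

(df
  .groupby("id")
  .agg(F.collect_set("code").alias("codes"),
       F.collect_list("name").alias("names"))
  .show())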
pault
Kamil Sindi
  • collect_set() contains distinct elements and collect_list() contains all elements (except nulls) – Grant Shannon May 03 '18 at 11:06
  • The size function on collect_set or collect_list will be better to calculate the count value, or use the plain count function. I am using a window to get the count of transactions attached to an account. – user3858193 May 06 '18 at 15:14
  • How to have the output of collect_list as a dict when I have multiple columns inside the list, e.g. agg(collect_list(struct(df.f1,df.f2,df.f3)))? The output should be [f1:value,f2:value,f3:value] for each group. – Immanuel Fredrick Mar 12 '19 at 13:59
  • While performing this on large dataframe, collect_set does not seem to get me correct values of a group. Any thoughts? – haneulkim Jan 25 '22 at 23:44
-4

If your dataframe is large, you can try using a pandas UDF (GROUPED_AGG) to avoid memory errors. It is also much faster.

Grouped aggregate Pandas UDFs are similar to Spark aggregate functions. Grouped aggregate Pandas UDFs are used with groupBy().agg() and pyspark.sql.Window. It defines an aggregation from one or more pandas.Series to a scalar value, where each pandas.Series represents a column within the group or window. pandas udf

example:

import pyspark.sql.functions as F

# Grouped-aggregate pandas UDF: each group's "name" values arrive as a pandas Series,
# and the function must return a single scalar (here a comma-separated string).
@F.pandas_udf('string', F.PandasUDFType.GROUPED_AGG)
def collect_list(name):
    # drop nulls before joining, otherwise ', '.join fails on None values
    return ', '.join(name.dropna())

grouped_df = df.groupby('id').agg(collect_list(df["name"]).alias('names'))
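With the sample dataframe from the first answer (null names dropped before joining), grouped_df.show() should produce something like:

+---+-----+
| id|names|
+---+-----+
|  a|name2|
+---+-----+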
Allen
  • I do not think a custom UDF is faster than a spark builtin – jwdink Oct 18 '19 at 17:22
  • I _know_ that a pandas UDF is way slower than a spark builtin (and also, that a pandas UDF requires more memory from your cluster)! What's faster, pure java/scala, or java that has to call python on a data structure that also has to be serialized via arrow into a pandas DF? – Marco May 09 '20 at 07:41