
How can I use the 'groupby(key).agg(' with a user defined functions? Specifically I need a list of all unique values per key [not count].

Hanan Shteingart
  • As far as I know, UDAFs (user-defined aggregate functions) are not supported by pyspark. If you can't move your logic to Scala, [here](http://stackoverflow.com/questions/33233737/) is a question that may help. – Daniel de Paula May 19 '16 at 22:55

2 Answers


The collect_set and collect_list functions (for unordered and ordered results, respectively) can be used to post-process groupby results. Starting with a simple Spark DataFrame:

    df = sqlContext.createDataFrame(
        [('first-neuron', 1, [0.0, 1.0, 2.0]),
         ('first-neuron', 2, [1.0, 2.0, 3.0, 4.0])],
        ("neuron_id", "time", "V"))

Let's say the goal is to return the length of the longest V list for each neuron (grouped by neuron_id):

    from pyspark.sql import functions as F
    grouped_df = df.groupby('neuron_id').agg(F.collect_list('V'))

We have now grouped the V lists into a list of lists. Since we want the length of the longest one, we can run a UDF:

    import numpy as np
    import pyspark.sql.types as sq_types
    # UDF that returns the length of the longest inner list
    len_udf = F.udf(lambda v_list: int(np.max([len(v) for v in v_list])),
                    returnType=sq_types.IntegerType())
    max_len_df = grouped_df.withColumn('max_len', len_udf('collect_list(V)'))

This adds a max_len column containing the maximum length of the V lists for each neuron.
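
For the sample DataFrame above, showing the result should produce something roughly like the sketch below (the exact truncation of the collect_list(V) column will vary, but max_len should be 4, since the longest V list has four elements):

    max_len_df.show()
    # +------------+--------------------+-------+
    # |   neuron_id|     collect_list(V)|max_len|
    # +------------+--------------------+-------+
    # |first-neuron|[[0.0, 1.0, 2.0],...|      4|
    # +------------+--------------------+-------+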

kmader

I found pyspark.sql.functions.collect_set(col) which does the job I wanted.
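
For example, with the df from the other answer, collecting the unique time values per key could look like this (a quick sketch; the unique_times alias is just illustrative):

    from pyspark.sql import functions as F
    # one row per neuron_id, with the set of distinct time values
    unique_times_df = df.groupby('neuron_id').agg(
        F.collect_set('time').alias('unique_times'))
    unique_times_df.show()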

Hanan Shteingart